Skip to content

Commit 359b665

Browse files
committed
[TorchToArith] Implement conversion patterns for AtenNegFloatOp.
Implement conversion patterns for `AtenNegFloatOp`: arith::subf(0.0, a);
1 parent e3200d9 commit 359b665

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,27 @@ class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
110110
};
111111
} // namespace
112112

113+
namespace {
114+
class ConvertAtenNegFloatOp : public OpConversionPattern<AtenNegFloatOp> {
115+
public:
116+
using OpConversionPattern<AtenNegFloatOp>::OpConversionPattern;
117+
LogicalResult matchAndRewrite(
118+
AtenNegFloatOp op,
119+
typename OpConversionPattern<AtenNegFloatOp>::OpAdaptor adaptor,
120+
ConversionPatternRewriter &rewriter) const override {
121+
Value a = adaptor.getA();
122+
Type inputDtype = a.getType();
123+
rewriter.replaceOpWithNewOp<arith::SubFOp>(
124+
op,
125+
arith::ConstantOp::create(
126+
rewriter, op.getLoc(),
127+
rewriter.getFloatAttr(inputDtype, /*value=*/0.0)),
128+
a);
129+
return success();
130+
}
131+
};
132+
} // namespace
133+
113134
namespace {
114135
template <typename AtenOp, typename UnaryOp>
115136
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
@@ -513,6 +534,9 @@ class ConvertTorchToArith
513534

514535
target.addIllegalOp<AtenNegIntOp>();
515536
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
537+
target.addIllegalOp<AtenNegFloatOp>();
538+
patterns.add<ConvertAtenNegFloatOp>(typeConverter, context);
539+
516540
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
517541
AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp,
518542
AtenMulFloatIntOp>();

test/Conversion/TorchToArith/basic.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,15 @@ func.func @torch.aten.div(%arg0: !torch.float, %arg1: !torch.float) -> !torch.fl
462462
%0 = torch.aten.div %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
463463
return %0 : !torch.float
464464
}
465+
466+
// CHECK-LABEL: func.func @torch.aten.neg.float(
467+
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.float {
468+
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
469+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f64
470+
// CHECK: %[[SUB:.*]] = arith.subf %[[CST:.*]], [[ARG_F64:.*]] : f64
471+
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
472+
// CHECK: return %[[OUT:.*]] : !torch.float
473+
func.func @torch.aten.neg.float(%arg0: !torch.float) -> !torch.float {
474+
%0 = torch.aten.neg.float %arg0 : !torch.float -> !torch.float
475+
return %0 : !torch.float
476+
}

0 commit comments

Comments
 (0)