Skip to content

Commit 791debb

Browse files
authored
[TorchToArith] Implement conversion for AtenNegFloatOp. (#4397)
Adds a conversion from `AtenNegFloatOp` to `arith::NegFOp`.
1 parent 3cebce2 commit 791debb

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
112112

113113
namespace {
114114
template <typename AtenOp, typename UnaryOp>
115-
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
115+
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOp> {
116116
public:
117117
using OpConversionPattern<AtenOp>::OpConversionPattern;
118118
LogicalResult
@@ -513,6 +513,10 @@ class ConvertTorchToArith
513513

514514
target.addIllegalOp<AtenNegIntOp>();
515515
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
516+
target.addIllegalOp<AtenNegFloatOp>();
517+
patterns.add<ConvertAtenUnaryOp<AtenNegFloatOp, arith::NegFOp>>(
518+
typeConverter, context);
519+
516520
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
517521
AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp,
518522
AtenMulFloatIntOp>();
@@ -558,12 +562,11 @@ class ConvertTorchToArith
558562
typeConverter, context);
559563
patterns.add<ConvertAtenBinaryOp<AtenNeBoolOp, arith::XOrIOp>>(
560564
typeConverter, context);
561-
patterns
562-
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
563-
typeConverter, context);
564-
target.addIllegalOp<AtenSqrtIntOp>();
565-
patterns.add<ConvertAtenUnaryOpToFloatMathOp<AtenSqrtIntOp, math::SqrtOp>>(
565+
patterns.add<ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
566566
typeConverter, context);
567+
target.addIllegalOp<AtenSqrtIntOp>();
568+
patterns.add<ConvertAtenUnaryOp<AtenSqrtIntOp, math::SqrtOp>>(typeConverter,
569+
context);
567570
target.addIllegalOp<AtenAnyBoolOp, AtenAllBoolOp>();
568571
patterns.add<ConvertAtenAnyOp>(typeConverter, context);
569572
patterns.add<ConvertAtenAllOp>(typeConverter, context);

test/Conversion/TorchToArith/basic.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,26 @@ func.func @torch.aten.div(%arg0: !torch.float, %arg1: !torch.float) -> !torch.fl
462462
%0 = torch.aten.div %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
463463
return %0 : !torch.float
464464
}
465+
466+
// CHECK-LABEL: func.func @torch.aten.neg.int(
467+
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
468+
// CHECK: %[[ARG_I64:.*]] = torch_c.to_i64 %[[ARG]]
469+
// CHECK: %[[CST:.*]] = arith.constant 0 : i64
470+
// CHECK: %[[SUB:.*]] = arith.subi %[[CST:.*]], [[ARG_I64:.*]] : i64
471+
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[SUB:.*]]
472+
// CHECK: return %[[OUT:.*]] : !torch.int
473+
func.func @torch.aten.neg.int(%arg0: !torch.int) -> !torch.int {
474+
%0 = torch.aten.neg.int %arg0 : !torch.int -> !torch.int
475+
return %0 : !torch.int
476+
}
477+
478+
// CHECK-LABEL: func.func @torch.aten.neg.float(
479+
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.float {
480+
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
481+
// CHECK: %[[NEG:.*]] = arith.negf %[[ARG_F64]] : f64
482+
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[NEG]]
483+
// CHECK: return %[[OUT]] : !torch.float
484+
func.func @torch.aten.neg.float(%arg0: !torch.float) -> !torch.float {
485+
%0 = torch.aten.neg.float %arg0 : !torch.float -> !torch.float
486+
return %0 : !torch.float
487+
}

0 commit comments

Comments
 (0)