@@ -112,7 +112,7 @@ class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
112112
113113namespace {
114114template <typename AtenOp, typename UnaryOp>
115- class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern <AtenOp> {
115+ class ConvertAtenUnaryOp : public OpConversionPattern <AtenOp> {
116116public:
117117 using OpConversionPattern<AtenOp>::OpConversionPattern;
118118 LogicalResult
@@ -513,6 +513,10 @@ class ConvertTorchToArith
513513
514514 target.addIllegalOp <AtenNegIntOp>();
515515 patterns.add <ConvertAtenNegIntOp>(typeConverter, context);
516+ target.addIllegalOp <AtenNegFloatOp>();
517+ patterns.add <ConvertAtenUnaryOp<AtenNegFloatOp, arith::NegFOp>>(
518+ typeConverter, context);
519+
516520 target.addIllegalOp <AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
517521 AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp,
518522 AtenMulFloatIntOp>();
@@ -558,12 +562,11 @@ class ConvertTorchToArith
558562 typeConverter, context);
559563 patterns.add <ConvertAtenBinaryOp<AtenNeBoolOp, arith::XOrIOp>>(
560564 typeConverter, context);
561- patterns
562- .add <ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
563- typeConverter, context);
564- target.addIllegalOp <AtenSqrtIntOp>();
565- patterns.add <ConvertAtenUnaryOpToFloatMathOp<AtenSqrtIntOp, math::SqrtOp>>(
565+ patterns.add <ConvertAtenUnaryOp<AtenCeilFloatOp, math::CeilOp>>(
566566 typeConverter, context);
567+ target.addIllegalOp <AtenSqrtIntOp>();
568+ patterns.add <ConvertAtenUnaryOp<AtenSqrtIntOp, math::SqrtOp>>(typeConverter,
569+ context);
567570 target.addIllegalOp <AtenAnyBoolOp, AtenAllBoolOp>();
568571 patterns.add <ConvertAtenAnyOp>(typeConverter, context);
569572 patterns.add <ConvertAtenAllOp>(typeConverter, context);
0 commit comments