Commit 68e74f1
[TorchToLinalg] Upcast low precision dtypes for direct backward conv lowering (#4408)
The newer direct lowering for backward conv accumulates directly in low-precision types such as bf16. This patch adds a check for the default accumulator type; if that type doesn't match the op's result types, the lowering also introduces a downcasting elementwise op (post-convolution and, for grouped convs, pre-collapsing).

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
1 parent 06eca3f commit 68e74f1
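For illustration, a minimal MLIR sketch of the pattern this produces (the function name @downcast_acc and the shapes are hypothetical, chosen only to keep the example small, not taken from the patch): the convolution accumulates into an f32 init tensor, and a follow-up elementwise linalg.generic truncates back to the bf16 result type before any grouped dimensions are collapsed.

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @downcast_acc(%acc: tensor<4x4xf32>) -> tensor<16xbf16> {
  // Truncate the f32 accumulator element-wise to the op's bf16 result type.
  %empty = tensor.empty() : tensor<4x4xbf16>
  %cast = linalg.generic {indexing_maps = [#map, #map],
                          iterator_types = ["parallel", "parallel"]}
      ins(%acc : tensor<4x4xf32>) outs(%empty : tensor<4x4xbf16>) {
  ^bb0(%in: f32, %out: bf16):
    %t = arith.truncf %in : f32 to bf16
    linalg.yield %t : bf16
  } -> tensor<4x4xbf16>
  // The group collapse runs after the downcast, so it operates on the final
  // element type rather than on the accumulator type.
  %res = tensor.collapse_shape %cast [[0, 1]] : tensor<4x4xbf16> into tensor<16xbf16>
  return %res : tensor<16xbf16>
}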

File tree

2 files changed (+188, −15 lines)

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 57 additions & 15 deletions
@@ -1695,7 +1695,15 @@ class ConvertAtenConvolutionBackwardOp
     auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
     if (!isa<mlir::FloatType>(gradOutputDTy) ||
         !isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy))
-      return op.emitError("unimplemented: only fp convolution bwd supported");
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only fp convolution bwd supported");
+
+    // TODO: support this.
+    if (!llvm::all_equal({inputDTy, weightDTy, gradOutputDTy}))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: mixed-precision fp types.");
+
+    auto accumulatorDTy = getDefaultAccType(rewriter, inputDTy);
 
     size_t gradRank = cast<RankedTensorType>(gradOutput.getType()).getRank();
     size_t numSpatialDims = gradRank - 2;
@@ -1833,6 +1841,22 @@ class ConvertAtenConvolutionBackwardOp
       return createZeroInitTensor(rewriter, loc, expandedSizes, type);
     };
 
+    auto convertFloatAccDtype = [&](Value accumulator, Type targetDTy) {
+      auto accDTy =
+          cast<RankedTensorType>(accumulator.getType()).getElementType();
+      auto floatAccDTy = dyn_cast<mlir::FloatType>(accDTy);
+      auto floatTargetDTy = dyn_cast<mlir::FloatType>(targetDTy);
+
+      assert(floatAccDTy && "Dtype conversion expects float dtypes only.");
+      assert(floatTargetDTy && "Dtype conversion expects float dtypes only.");
+
+      if (floatAccDTy == floatTargetDTy)
+        return accumulator;
+
+      return torch_to_linalg::convertTensorToElementType(
+          rewriter, loc, accumulator, targetDTy);
+    };
+
     SmallVector<Value> newResults(op->getNumResults());
 
     // Computing Backward-Input Convolution.
@@ -1945,11 +1969,11 @@ class ConvertAtenConvolutionBackwardOp
     // [N, G, C/G, D*] tensor and collapse back to the original input shape.
     SmallVector<ReassociationIndices> gradInputCollapseIndices;
     Value gradInputInit =
-        isGroupedConvBwd
-            ? createZeroInitExpandedGroupsTensor(rewriter, loc,
-                                                 gradInputSizes, inputDTy, 1,
-                                                 gradInputCollapseIndices)
-            : createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy);
+        isGroupedConvBwd ? createZeroInitExpandedGroupsTensor(
+                               rewriter, loc, gradInputSizes, accumulatorDTy,
+                               1, gradInputCollapseIndices)
+                         : createZeroInitTensor(rewriter, loc, gradInputSizes,
+                                                accumulatorDTy);
 
     // Create convolution for data gradient
     auto convRes = createConvInputGradient(rewriter, loc, context,
@@ -1958,11 +1982,16 @@ class ConvertAtenConvolutionBackwardOp
                                            weightExpanded, gradInputInit)
                        .getResult(0);
 
+    auto returnTensorTy = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
+    auto returnDTy = returnTensorTy.getElementType();
+    convRes = convertFloatAccDtype(convRes, returnDTy);
+
     // Collapse [N, G, C/G, D] to [N, C, D] the result of the conv
     // if it is grouped.
     if (isGroupedConvBwd) {
       convRes = tensor::CollapseShapeOp::create(
-          rewriter, loc, input.getType(), convRes, gradInputCollapseIndices);
+          rewriter, loc, returnTensorTy, convRes, gradInputCollapseIndices);
     }
 
     // Cast to the final result type expected by the type converter.
@@ -1998,10 +2027,11 @@ class ConvertAtenConvolutionBackwardOp
     SmallVector<ReassociationIndices> gradWeightCollapseIndices;
     Value gradWeightInit =
         isGroupedConvBwd
-            ? createZeroInitExpandedGroupsTensor(rewriter, loc,
-                                                 gradWeightSizes, weightDTy,
-                                                 0, gradWeightCollapseIndices)
-            : createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
+            ? createZeroInitExpandedGroupsTensor(
+                  rewriter, loc, gradWeightSizes, accumulatorDTy, 0,
+                  gradWeightCollapseIndices)
+            : createZeroInitTensor(rewriter, loc, gradWeightSizes,
+                                   accumulatorDTy);
 
     // Create convolution for weight gradient
     auto convResult = createConvWeightGradient(
@@ -2010,12 +2040,17 @@ class ConvertAtenConvolutionBackwardOp
                           paddedInput, gradOutputExpanded, gradWeightInit)
                           .getResult(0);
 
+    auto returnTensorTy = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(1).getType()));
+    auto returnDTy = returnTensorTy.getElementType();
+    convResult = convertFloatAccDtype(convResult, returnDTy);
+
     // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the conv
     // if it is grouped.
     if (isGroupedConvBwd) {
-      convResult = tensor::CollapseShapeOp::create(
-          rewriter, loc, weight.getType(), convResult,
-          gradWeightCollapseIndices);
+      convResult = tensor::CollapseShapeOp::create(rewriter, loc,
+                                                   returnTensorTy, convResult,
+                                                   gradWeightCollapseIndices);
     }
 
     // Cast to the final result type expected by the type converter.
@@ -2038,10 +2073,12 @@ class ConvertAtenConvolutionBackwardOp
 
     // Zero init for the element type (arith.constant expects a scalar attr).
     Value initSum = arith::ConstantOp::create(
-        rewriter, loc, rewriter.getZeroAttr(gradOutputDTy));
+        rewriter, loc, rewriter.getZeroAttr(accumulatorDTy));
 
     auto reductionBody = [&](OpBuilder &b, Location loc, ValueRange args) {
       Value x = args[0];
+      if (gradOutputDTy != accumulatorDTy)
+        x = arith::ExtFOp::create(b, loc, accumulatorDTy, x);
       Value acc = args[1];
       Value sum = arith::AddFOp::create(b, loc, x, acc);
       linalg::YieldOp::create(b, loc, sum);
@@ -2050,6 +2087,11 @@ class ConvertAtenConvolutionBackwardOp
     Value gradBias = torch_to_linalg::createReductionLinalgGeneric(
         rewriter, loc, opInfo, initSum, reductionBody);
 
+    auto resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(2).getType()));
+    auto resultDTy = resultType.getElementType();
+    gradBias = convertFloatAccDtype(gradBias, resultDTy);
+
     newResults[2] = tensor::CastOp::create(rewriter, loc,
                                            getTypeConverter()->convertType(
                                                op->getResult(2).getType()),

test/Conversion/TorchToLinalg/convolution_bwd.mlir

Lines changed: 131 additions & 0 deletions
@@ -415,3 +415,134 @@ func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(%arg0: !torch.vten
 }
 
 // -----
+
+// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
+func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
+// CHECK-DAG: %[[CST_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST_BF16:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],bf16> -> tensor<2x128x64x64xbf16>
+// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],bf16> -> tensor<2x16x33x33xbf16>
+// CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xbf16> into tensor<2x4x4x33x33xbf16>
+// CHECK: %[[T1_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xbf16> into tensor<2x4x32x64x64xbf16>
+// CHECK: %[[PAD:.*]] = tensor.pad %[[T1_EXP]] low[0, 0, 0, 2, 2] high[0, 0, 0, 2, 2]
+// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+// CHECK-NEXT: tensor.yield %[[CST_BF16]] : bf16
+// CHECK-NEXT: } : tensor<2x4x32x64x64xbf16> to tensor<2x4x32x68x68xbf16>
+// CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32>
+// CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[OUT0_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32>
+// CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xbf16>, tensor<2x4x4x33x33xbf16>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: bf16, %[[IN1:.*]]: bf16, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXT0:.*]] = arith.extf %[[IN]] : bf16 to f32
+// CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[IN1]] : bf16 to f32
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+// CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+// CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32
+// CHECK-NEXT: } -> tensor<4x4x32x2x2xf32>
+// CHECK: %[[DOWNCAST0:.*]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[CONV]] : tensor<4x4x32x2x2xf32>) outs(%[[ZERO_BF16_INIT:.*]] : tensor<4x4x32x2x2xbf16>) {
+// CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
+// CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
+// CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
+// CHECK-NEXT: } -> tensor<4x4x32x2x2xbf16>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[DOWNCAST0]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} : tensor<4x4x32x2x2xbf16> into tensor<16x32x2x2xbf16>
+// CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<16x32x2x2xbf16> -> !torch.vtensor<[16,32,2,2],bf16>
+// CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+// CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+// CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xbf16>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN_B:.*]]: bf16, %[[ACC_B:.*]]: f32):
+// CHECK-NEXT: %[[B_EXT:.*]] = arith.extf %[[IN_B]] : bf16 to f32
+// CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[B_EXT]], %[[ACC_B]] : f32
+// CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+// CHECK-NEXT: } -> tensor<16xf32>
+// CHECK: %[[DOWNCAST1:.*]] = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+// CHECK-SAME: iterator_types = ["parallel"]} ins(%[[SUM_GEN]] : tensor<16xf32>) outs(%[[ZERO_BF16_INIT_1:.*]] : tensor<16xbf16>) {
+// CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
+// CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
+// CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
+// CHECK-NEXT: } -> tensor<16xbf16>
+// CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[DOWNCAST1]] : tensor<16xbf16> -> !torch.vtensor<[16],bf16>
+// CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+  return %result1, %result2 : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g_bf16(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> !torch.vtensor<[2,128,64,64],bf16> {
+func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> !torch.vtensor<[2,128,64,64],bf16> {
+// CHECK: %[[CST1:.*]] = arith.constant 1 : index
+// CHECK: %[[CST0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CST0_BF16:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,32,2,2],bf16> -> tensor<16x32x2x2xbf16>
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],bf16> -> tensor<2x16x33x33xbf16>
+// CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xbf16> into tensor<2x4x4x33x33xbf16>
+// CHECK: %[[W_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xbf16> into tensor<4x4x32x2x2xbf16>
+// CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xbf16>
+// CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0_BF16]] : bf16) outs(%[[W_EMPTY]] : tensor<4x4x32x2x2xbf16>) -> tensor<4x4x32x2x2xbf16>
+// CHECK: %[[W_REV:.*]] = linalg.generic {{.*}} ins(%[[W_EXP]] : tensor<4x4x32x2x2xbf16>) outs(%[[W_FILLED]] : tensor<4x4x32x2x2xbf16>) {
+// CHECK-NEXT: ^bb0(%[[IN_W:.*]]: bf16, %[[OUT_W:.*]]: bf16):
+// CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index
+// CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index
+// CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index
+// CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index
+// CHECK-NEXT: %[[I4:.*]] = linalg.index 4 : index
+// CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index
+// CHECK-NEXT: %[[R4:.*]] = arith.subi %[[CST1]], %[[I4]] : index
+// CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[W_EXP]][%[[I0]], %[[I1]], %[[I2]], %[[R3]], %[[R4]]] : tensor<4x4x32x2x2xbf16>
+// CHECK-NEXT: linalg.yield %[[EX]] : bf16
+// CHECK-NEXT: } -> tensor<4x4x32x2x2xbf16>
+// CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x4x4x66x66xbf16>
+// CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%[[CST0_BF16]] : bf16) outs(%[[SLICE_EMPTY]] : tensor<2x4x4x66x66xbf16>) -> tensor<2x4x4x66x66xbf16>
+// CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0_EXP]] into %[[SLICE_FILLED]][0, 0, 0, 0, 0] [2, 4, 4, 33, 33] [1, 1, 1, 2, 2] : tensor<2x4x4x33x33xbf16> into tensor<2x4x4x66x66xbf16>
+// CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x4x32x64x64xf32>
+// CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0_F32]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32>
+// CHECK: %[[CONV_F32:.*]] = linalg.generic {{.*}} ins(%[[SLICE]], %[[W_REV]] : tensor<2x4x4x66x66xbf16>, tensor<4x4x32x2x2xbf16>) outs(%[[OUT_FILLED]] : tensor<2x4x32x64x64xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: bf16, %[[IN1:.*]]: bf16, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXT:.*]] = arith.extf %[[IN]] : bf16 to f32
+// CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[IN1]] : bf16 to f32
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT]], %[[EXT1]] : f32
+// CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+// CHECK-NEXT: linalg.yield %[[ACC]] : f32
+// CHECK-NEXT: } -> tensor<2x4x32x64x64xf32>
+// CHECK: %[[EMPTY_BF16:.*]] = tensor.empty() : tensor<2x4x32x64x64xbf16>
+// CHECK: %[[CONV_BF16:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[CONV_F32]] : tensor<2x4x32x64x64xf32>) outs(%[[EMPTY_BF16]] : tensor<2x4x32x64x64xbf16>) {
+// CHECK: ^bb0(%[[IN_F32:.*]]: f32, %[[OUT_BF16:.*]]: bf16):
+// CHECK: %[[TRUNC_BF16:.*]] = arith.truncf %[[IN_F32]] : f32 to bf16
+// CHECK: linalg.yield %[[TRUNC_BF16]] : bf16
+// CHECK: } -> tensor<2x4x32x64x64xbf16>
+// CHECK: %[[CONV_COLLAPSED:.*]] = tensor.collapse_shape %[[CONV_BF16]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} : tensor<2x4x32x64x64xbf16> into tensor<2x128x64x64xbf16>
+// CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV_COLLAPSED]] : tensor<2x128x64x64xbf16> -> !torch.vtensor<[2,128,64,64],bf16>
+// CHECK: return %[[IGRAD]] : !torch.vtensor<[2,128,64,64],bf16>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %true, %false, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],bf16>, !torch.none, !torch.none
+  return %result0 : !torch.vtensor<[2,128,64,64],bf16>
+}
+
+// -----
