@@ -415,3 +415,134 @@ func.func @convolution_backward_input_1x1x1s_1x0x1p_1x1x1d_1g(%arg0: !torch.vten
 }
 
 // -----
+
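+// Weight- and bias-gradient path of torch.aten.convolution_backward for a
+// grouped bf16 convolution (groups=4, 2x2 stride, 2x2 padding, 2x2 dilation).
+// The checks expect the operands to be expanded into grouped form, the weight
+// gradient to be accumulated in f32 and truncated back to bf16, and the bias
+// gradient to be a reduction of grad_output over batch and spatial dims.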
+// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
+func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
+// CHECK-DAG: %[[CST_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST_BF16:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],bf16> -> tensor<2x128x64x64xbf16>
+// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],bf16> -> tensor<2x16x33x33xbf16>
+// CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xbf16> into tensor<2x4x4x33x33xbf16>
+// CHECK: %[[T1_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xbf16> into tensor<2x4x32x64x64xbf16>
+// CHECK: %[[PAD:.*]] = tensor.pad %[[T1_EXP]] low[0, 0, 0, 2, 2] high[0, 0, 0, 2, 2]
+// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
+// CHECK-NEXT: tensor.yield %[[CST_BF16]] : bf16
+// CHECK-NEXT: } : tensor<2x4x32x64x64xbf16> to tensor<2x4x32x68x68xbf16>
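+// Grouped weight gradient: a reduction over the batch and grad_output spatial
+// dimensions, accumulating into a zero-filled f32 init tensor.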
+// CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32>
+// CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[OUT0_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32>
+// CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xbf16>, tensor<2x4x4x33x33xbf16>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: bf16, %[[IN1:.*]]: bf16, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXT0:.*]] = arith.extf %[[IN]] : bf16 to f32
+// CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[IN1]] : bf16 to f32
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
+// CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+// CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32
+// CHECK-NEXT: } -> tensor<4x4x32x2x2xf32>
+// CHECK: %[[DOWNCAST0:.*]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>,
+// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[CONV]] : tensor<4x4x32x2x2xf32>) outs(%[[ZERO_BF16_INIT:.*]] : tensor<4x4x32x2x2xbf16>) {
+// CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
+// CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
+// CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
+// CHECK-NEXT: } -> tensor<4x4x32x2x2xbf16>
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[DOWNCAST0]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} : tensor<4x4x32x2x2xbf16> into tensor<16x32x2x2xbf16>
+// CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<16x32x2x2xbf16> -> !torch.vtensor<[16,32,2,2],bf16>
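+// Bias gradient: grad_output is summed over N/H/W in f32, then truncated to bf16.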
+// CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32>
+// CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST_F32]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32>
+// CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xbf16>) outs(%[[SUM_FILLED]] : tensor<16xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN_B:.*]]: bf16, %[[ACC_B:.*]]: f32):
+// CHECK-NEXT: %[[B_EXT:.*]] = arith.extf %[[IN_B]] : bf16 to f32
+// CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[B_EXT]], %[[ACC_B]] : f32
+// CHECK-NEXT: linalg.yield %[[B_RES]] : f32
+// CHECK-NEXT: } -> tensor<16xf32>
+// CHECK: %[[DOWNCAST1:.*]] = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+// CHECK-SAME: iterator_types = ["parallel"]} ins(%[[SUM_GEN]] : tensor<16xf32>) outs(%[[ZERO_BF16_INIT_1:.*]] : tensor<16xbf16>) {
+// CHECK-NEXT: ^bb0(%[[IN_BBARG:.*]]: f32, %[[OUT_BBARG:.*]]: bf16):
+// CHECK-NEXT: %[[TRUNC:.*]] = arith.truncf %[[IN_BBARG]] : f32 to bf16
+// CHECK-NEXT: linalg.yield %[[TRUNC]] : bf16
+// CHECK-NEXT: } -> tensor<16xbf16>
+// CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[DOWNCAST1]] : tensor<16xbf16> -> !torch.vtensor<[16],bf16>
+// CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+  return %result1, %result2 : !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
+}
+
+// -----
+
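+// Input-gradient path of the same grouped bf16 convolution backward. The
+// checks expect the weight to be expanded into groups and spatially flipped,
+// grad_output to be scattered with stride 2 into a zero-filled buffer
+// (transposed-convolution style), and the correlation to accumulate in f32
+// before truncation back to bf16.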
+// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g_bf16(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> !torch.vtensor<[2,128,64,64],bf16> {
+func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> !torch.vtensor<[2,128,64,64],bf16> {
+// CHECK: %[[CST1:.*]] = arith.constant 1 : index
+// CHECK: %[[CST0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[CST0_BF16:.*]] = arith.constant 0.000000e+00 : bf16
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,32,2,2],bf16> -> tensor<16x32x2x2xbf16>
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],bf16> -> tensor<2x16x33x33xbf16>
+// CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xbf16> into tensor<2x4x4x33x33xbf16>
+// CHECK: %[[W_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xbf16> into tensor<4x4x32x2x2xbf16>
+// CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xbf16>
+// CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0_BF16]] : bf16) outs(%[[W_EMPTY]] : tensor<4x4x32x2x2xbf16>) -> tensor<4x4x32x2x2xbf16>
+// CHECK: %[[W_REV:.*]] = linalg.generic {{.*}} ins(%[[W_EXP]] : tensor<4x4x32x2x2xbf16>) outs(%[[W_FILLED]] : tensor<4x4x32x2x2xbf16>) {
+// CHECK-NEXT: ^bb0(%[[IN_W:.*]]: bf16, %[[OUT_W:.*]]: bf16):
+// CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index
+// CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index
+// CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index
+// CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index
+// CHECK-NEXT: %[[I4:.*]] = linalg.index 4 : index
+// CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index
+// CHECK-NEXT: %[[R4:.*]] = arith.subi %[[CST1]], %[[I4]] : index
+// CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[W_EXP]][%[[I0]], %[[I1]], %[[I2]], %[[R3]], %[[R4]]] : tensor<4x4x32x2x2xbf16>
+// CHECK-NEXT: linalg.yield %[[EX]] : bf16
+// CHECK-NEXT: } -> tensor<4x4x32x2x2xbf16>
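+// grad_output is dilated by the stride: tensor.insert_slice writes it with
+// stride 2 into a zero-filled 66x66 buffer before the correlation with the
+// flipped weights.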
+// CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x4x4x66x66xbf16>
+// CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%[[CST0_BF16]] : bf16) outs(%[[SLICE_EMPTY]] : tensor<2x4x4x66x66xbf16>) -> tensor<2x4x4x66x66xbf16>
+// CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0_EXP]] into %[[SLICE_FILLED]][0, 0, 0, 0, 0] [2, 4, 4, 33, 33] [1, 1, 1, 2, 2] : tensor<2x4x4x33x33xbf16> into tensor<2x4x4x66x66xbf16>
+// CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x4x32x64x64xf32>
+// CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0_F32]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32>
+// CHECK: %[[CONV_F32:.*]] = linalg.generic {{.*}} ins(%[[SLICE]], %[[W_REV]] : tensor<2x4x4x66x66xbf16>, tensor<4x4x32x2x2xbf16>) outs(%[[OUT_FILLED]] : tensor<2x4x32x64x64xf32>) {
+// CHECK-NEXT: ^bb0(%[[IN:.*]]: bf16, %[[IN1:.*]]: bf16, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXT:.*]] = arith.extf %[[IN]] : bf16 to f32
+// CHECK-NEXT: %[[EXT1:.*]] = arith.extf %[[IN1]] : bf16 to f32
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[EXT]], %[[EXT1]] : f32
+// CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32
+// CHECK-NEXT: linalg.yield %[[ACC]] : f32
+// CHECK-NEXT: } -> tensor<2x4x32x64x64xf32>
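+// The f32 result is truncated to bf16 and collapsed back to the NCHW
+// input-gradient shape [2,128,64,64].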
+// CHECK: %[[EMPTY_BF16:.*]] = tensor.empty() : tensor<2x4x32x64x64xbf16>
+// CHECK: %[[CONV_BF16:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[CONV_F32]] : tensor<2x4x32x64x64xf32>) outs(%[[EMPTY_BF16]] : tensor<2x4x32x64x64xbf16>) {
+// CHECK: ^bb0(%[[IN_F32:.*]]: f32, %[[OUT_BF16:.*]]: bf16):
+// CHECK: %[[TRUNC_BF16:.*]] = arith.truncf %[[IN_F32]] : f32 to bf16
+// CHECK: linalg.yield %[[TRUNC_BF16]] : bf16
+// CHECK: } -> tensor<2x4x32x64x64xbf16>
+// CHECK: %[[CONV_COLLAPSED:.*]] = tensor.collapse_shape %[[CONV_BF16]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} : tensor<2x4x32x64x64xbf16> into tensor<2x128x64x64xbf16>
+// CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV_COLLAPSED]] : tensor<2x128x64x64xbf16> -> !torch.vtensor<[2,128,64,64],bf16>
+// CHECK: return %[[IGRAD]] : !torch.vtensor<[2,128,64,64],bf16>
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %true, %false, %false : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],bf16>, !torch.none, !torch.none
+  return %result0 : !torch.vtensor<[2,128,64,64],bf16>
+}
+
+// -----