
Commit 06eca3f

[Torch] Fix decomposition of matmul to bmm (#4404)
This change prevents decomposing `torch.matmul` to `torch.bmm` when the batch dimensions are broadcast, because `torch.bmm` does not support broadcasting. Before this change, the added test case would result in a compilation failure.

Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
1 parent: 0844d4d
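For context, a minimal eager-PyTorch sketch (not part of this commit) that reproduces the mismatch with the same shapes as the added test case:

import torch

lhs = torch.rand(4, 8, 5)
rhs = torch.rand(1, 5, 6)

# torch.matmul broadcasts the leading batch dimensions (1 -> 4) and succeeds.
out = torch.matmul(lhs, rhs)
print(out.shape)  # torch.Size([4, 8, 6])

# torch.bmm requires two 3-D inputs with identical batch sizes, so the same
# operands raise a RuntimeError instead of broadcasting.
try:
    torch.bmm(lhs, rhs)
except RuntimeError as err:
    print("bmm rejects broadcast batches:", err)

Rewriting the first call as the second turns a valid program into an invalid one, which is the failure mode the guard below removes.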

File tree: 3 files changed, +64 −4 lines


lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 21 additions & 1 deletion
@@ -3222,7 +3222,27 @@ class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
       // If both lhs and rhs ranks are 2 then map it to `aten.mm` op.
       rewriter.replaceOpWithNewOp<AtenMmOp>(op, op.getType(), lhs, rhs);
     } else if (lhsRank == 3 && rhsRank == 3) {
-      // If both lhs and rhs ranks are 3 then map it to `aten.bmm` op.
+      // If both lhs and rhs ranks are 3, we can only map it to `aten.bmm` op
+      // if the batch dimensions are equal (since bmm doesn't support
+      // broadcasting).
+      auto lhsType = cast<BaseTensorType>(lhs.getType());
+      auto rhsType = cast<BaseTensorType>(rhs.getType());
+
+      if (!lhsType.hasSizes() || !rhsType.hasSizes())
+        return failure();
+
+      ArrayRef<int64_t> lhsShape = lhsType.getSizes();
+      ArrayRef<int64_t> rhsShape = rhsType.getSizes();
+      int64_t lhsBatchDim = lhsShape[0];
+      int64_t rhsBatchDim = rhsShape[0];
+
+      // Batch dimensions must be statically known and equal for bmm.
+      // Dynamic dimensions (kUnknownSize) or unequal dimensions require the
+      // general matmul lowering which handles broadcasting.
+      if (lhsBatchDim == kUnknownSize || rhsBatchDim == kUnknownSize ||
+          lhsBatchDim != rhsBatchDim)
+        return failure();
+
       rewriter.replaceOpWithNewOp<AtenBmmOp>(op, op.getType(), lhs, rhs);
     } else {
       return failure();
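The guard declines the bmm decomposition unless both batch dimensions are statically known and equal. A hypothetical Python rendering of the same predicate, for illustration only (the helper name is invented here, and None stands in for kUnknownSize):

def can_decompose_to_bmm(lhs_shape, rhs_shape):
    # Mirrors the C++ check: both operands must be rank 3 with statically
    # known, equal batch dimensions; None models a dynamic dimension.
    if len(lhs_shape) != 3 or len(rhs_shape) != 3:
        return False
    lhs_batch, rhs_batch = lhs_shape[0], rhs_shape[0]
    if lhs_batch is None or rhs_batch is None:
        return False
    return lhs_batch == rhs_batch

assert can_decompose_to_bmm([4, 8, 5], [4, 5, 6])         # static, equal: bmm is safe
assert not can_decompose_to_bmm([4, 8, 5], [1, 5, 6])     # broadcast: keep matmul
assert not can_decompose_to_bmm([None, 8, 5], [4, 5, 6])  # dynamic: keep matmul

The check is deliberately conservative: two dynamic batch dimensions may turn out equal at runtime, but without static evidence the pattern returns failure() and leaves the general matmul lowering, which handles broadcasting, in place.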

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 24 additions & 0 deletions
@@ -156,6 +156,30 @@ def Matmul_3d(module, tu: TestUtils):
 # ==============================================================================
 
 
+class Matmul3DStaticBroadcast(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 8, 5], torch.float32, True),
+            ([1, 5, 6], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.matmul(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: Matmul3DStaticBroadcast())
+def Matmul3DStaticBroadcast_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 8, 5), tu.rand(1, 5, 6))
+
+
+# ==============================================================================
+
+
 class Matmul4d(torch.nn.Module):
     def __init__(self):
         super().__init__()
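The test shapes differ in their batch dimension but are broadcast-compatible, so the module must go through the general matmul lowering. As a quick sanity check of the intended semantics in stock PyTorch (independent of this repo's harness):

import torch

# Batch sizes 4 and 1 broadcast together; the 1 stretches to 4.
print(torch.broadcast_shapes((4,), (1,)))  # torch.Size([4])

out = torch.matmul(torch.rand(4, 8, 5), torch.rand(1, 5, 6))
assert out.shape == (4, 8, 6)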

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 19 additions & 3 deletions
@@ -18,13 +18,29 @@ func.func @matmul_decompose_2d(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
 }
 
 // -----
-// CHECK-LABEL: func.func @matmul_decompose_3d(
-// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
-func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
+// CHECK-LABEL: func.func @matmul_no_decompose_3d_dynamic(
+// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
+func.func @matmul_no_decompose_3d_dynamic(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
   %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32> -> !torch.tensor
   return %0 : !torch.tensor
 }
 
+// -----
+// CHECK-LABEL: func.func @matmul_decompose_3d_static(
+// CHECK: torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[4,?,?],f32> -> !torch.tensor
+func.func @matmul_decompose_3d_static(%arg0: !torch.vtensor<[4,?,?],f32>, %arg1: !torch.vtensor<[4,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[4,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
+// -----
+// CHECK-LABEL: func.func @matmul_no_decompose_3d_broadcast(
+// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.tensor
+func.func @matmul_no_decompose_3d_broadcast(%arg0: !torch.vtensor<[4,?,?],f32>, %arg1: !torch.vtensor<[1,?,?],f32>) -> !torch.tensor {
+  %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.tensor
+  return %0 : !torch.tensor
+}
+
 // -----
 // CHECK-LABEL: func.func @argmax_rank_1
 // CHECK: %[[I0:.*]] = torch.constant.int 0