Skip to content

Commit d20c5c4

Browse files
committed
fix load after write from buffer
1 parent 5bf5657 commit d20c5c4

File tree

6 files changed

+26
-7
lines changed

6 files changed

+26
-7
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
846846
LogicalResult matchAndRewrite(memref::LoadOp load,
847847
PatternRewriter &rewriter) const override {
848848
auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
849-
if (!toBuffer)
849+
if (!toBuffer || !toBuffer.getReadOnly())
850850
return failure();
851851

852852
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,29 @@ func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8
294294

295295
// -----
296296

297+
// Verify LoadOfToBuffer does not fire on to_buffer ops that lack the
// read_only attribute (the buffer may be written before the load).
298+
// CHECK-LABEL: func @load_after_write_from_buffer_cast(
299+
func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
300+
%arg2: tensor<?x?xf32>) -> f32 {
301+
%0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
302+
linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
303+
%1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
304+
return %1 : f32
305+
}
306+
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
307+
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
308+
// CHECK: %[[M:.+]] = bufferization.to_buffer %[[TENSOR]] : tensor<?x?xf32> to memref<?x?xf32>
309+
// CHECK: linalg.ceil ins(%[[M]] : memref<?x?xf32>) outs(%[[M]] : memref<?x?xf32>)
310+
// CHECK: %[[RES:.*]] = memref.load %[[M]][%[[IDX0]], %[[IDX1]]] : memref<?x?xf32>
311+
// CHECK: return %[[RES]] : f32
312+
313+
// -----
314+
297315
// Folding of memref.load(to_buffer(%v), %idxs) -> tensor.extract(%v, %idxs)
298316
// CHECK-LABEL: func @load_from_buffer_cast(
299317
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
300318
%arg2: tensor<?x?xf32>) -> f32 {
301-
%0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
319+
%0 = bufferization.to_buffer %arg2 read_only : tensor<?x?xf32> to memref<?x?xf32>
302320
%1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
303321
return %1 : f32
304322
}

mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
3232
// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
3333
// CHECK-DAG: %[[VAL_14:.*]] = bufferization.to_buffer %[[VAL_10]] :
34+
// CHECK-DAG: %[[WTF:.*]] = bufferization.to_buffer %[[VAL_1]] :
3435
// CHECK-DAG: linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
3536
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
3637
// CHECK: %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
@@ -49,7 +50,7 @@
4950
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
5051
// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
5152
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
52-
// CHECK: %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
53+
// CHECK: %[[VAL_29:.*]] = memref.load %[[WTF]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
5354
// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
5455
// CHECK: memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
5556
// CHECK: } {"Emitted from" = "linalg.generic"}

mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
// CHECK: } do {
3232
// CHECK: %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
3333
// CHECK: "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
34-
// CHECK: tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
34+
// CHECK: memref.load %{{.*}}{{\[}}%[[D2]], %[[D3]]]
3535
// CHECK: arith.muli
3636
// CHECK: arith.addi
3737
// CHECK: "subsect<trivial<compressed[0,1]>>.next

mlir/test/Dialect/SparseTensor/sparse_pack.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
2323
// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] lvl_sz at 0 with %[[VAL_4]]
2424
// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_3]]
25-
// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
25+
// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<2xindex>
2626
// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
2727
// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] crd_mem_sz at 0 with %[[VAL_17]]
2828
// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_4]]

mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
// CHECK-HIR-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
2828
// CHECK-HIR-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
2929
// CHECK-HIR-DAG: %[[VAL_10:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<f32> to memref<f32>
30-
// CHECK-HIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
30+
// CHECK-HIR: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
3131
// CHECK-HIR: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
3232
// CHECK-HIR: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index
3333
// CHECK-HIR: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
@@ -59,7 +59,7 @@
5959
// CHECK-MIR-DAG: %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
6060
// CHECK-MIR-DAG: %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr) -> memref<?xf32>
6161
// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_buffer %[[ARGX]] : tensor<f32> to memref<f32>
62-
// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
62+
// CHECK-MIR: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
6363
// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
6464
// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index
6565
// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {

0 commit comments

Comments
 (0)