Skip to content

Conversation

@batzor
Copy link
Contributor

@batzor batzor commented Dec 17, 2025

When we memref.load from a buffer, it folded to tensor.extract even when the buffer was writable, causing unexpected results. For example:

func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
                            %arg2: tensor<?x?xf32>) -> f32 {
  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
  linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
  return %1 : f32
}

would fold into

module {
  func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
    %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
    linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
    %extracted = tensor.extract %arg2[%arg0, %arg1] : tensor<?x?xf32>
    return %extracted : f32
  }
}

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir mlir:bufferization Bufferization infrastructure labels Dec 17, 2025
@llvmbot
Copy link
Member

llvmbot commented Dec 17, 2025

@llvm/pr-subscribers-mlir

Author: Batzorig Zorigoo (batzor)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/172595.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+19-1)
  • (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+2-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..6bf81a2727204 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -846,7 +846,7 @@ struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
   LogicalResult matchAndRewrite(memref::LoadOp load,
                                 PatternRewriter &rewriter) const override {
     auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
-    if (!toBuffer)
+    if (!toBuffer || !toBuffer.getReadOnly())
       return failure();
 
     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index ae1d1fcfc19dc..df07511798b91 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -294,11 +294,29 @@ func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8
 
 // -----
 
+// Verify LoadOfToBuffer skips writable buffers
+// CHECK-LABEL: func @load_after_write_from_buffer_cast(
+func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
+                            %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
+  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+//      CHECK: %[[M:.+]] = bufferization.to_buffer %[[TENSOR]] : tensor<?x?xf32> to memref<?x?xf32>
+//      CHECK: linalg.ceil ins(%[[M]] : memref<?x?xf32>) outs(%[[M]] : memref<?x?xf32>)
+//      CHECK: %[[RES:.*]] = memref.load %[[M]][%[[IDX0]], %[[IDX1]]] : memref<?x?xf32>
+//      CHECK: return %[[RES]] : f32
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
                             %arg2: tensor<?x?xf32>) -> f32 {
-  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  %0 = bufferization.to_buffer %arg2 read_only : tensor<?x?xf32> to memref<?x?xf32>
   %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
   return %1 : f32
 }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
index d828afe13c622..78a0d4fedf690 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -31,6 +31,7 @@
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_buffer %[[VAL_10]] :
+// CHECK-DAG:       %[[WTF:.*]] = bufferization.to_buffer %[[VAL_1]] :
 // CHECK-DAG:       linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
@@ -49,7 +50,7 @@
 // CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
 // CHECK:               %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
 // CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK:               %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK:               %[[VAL_29:.*]] = memref.load %[[WTF]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:               %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
 // CHECK:               memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:             } {"Emitted from" = "linalg.generic"}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index bf3473ead204e..912f78a0b81fc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -31,7 +31,7 @@
 // CHECK:                 } do {
 // CHECK:                   %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
 // CHECK:                   "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
-// CHECK:                   tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
+// CHECK:                   memref.load %{{.*}}{{\[}}%[[D2]], %[[D3]]]
 // CHECK:                   arith.muli
 // CHECK:                   arith.addi
 // CHECK:                   "subsect<trivial<compressed[0,1]>>.next
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 4546d3367b16d..ebbcc5fc7c7cf 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -22,7 +22,7 @@
 // CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
 // CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]]  lvl_sz at 0 with %[[VAL_4]]
 // CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_3]]
-// CHECK:           %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
 // CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_17]]
 // CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 1 with %[[VAL_4]]
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 4abaf03dff50f..e2c841b1ac7d5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -27,7 +27,7 @@
 // CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<f32> to memref<f32>
-// CHECK-HIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
+// CHECK-HIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-HIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-HIR:             %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index
 // CHECK-HIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
@@ -59,7 +59,7 @@
 // CHECK-MIR-DAG:       %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr) -> memref<?xf32>
 // CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[ARGX]] : tensor<f32> to memref<f32>
-// CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-MIR:             %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index
 // CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {

@llvmbot
Copy link
Member

llvmbot commented Dec 17, 2025

@llvm/pr-subscribers-mlir-sparse

Author: Batzorig Zorigoo (batzor)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/172595.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+19-1)
  • (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+2-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..6bf81a2727204 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -846,7 +846,7 @@ struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
   LogicalResult matchAndRewrite(memref::LoadOp load,
                                 PatternRewriter &rewriter) const override {
     auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
-    if (!toBuffer)
+    if (!toBuffer || !toBuffer.getReadOnly())
       return failure();
 
     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index ae1d1fcfc19dc..df07511798b91 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -294,11 +294,29 @@ func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8
 
 // -----
 
+// Verify LoadOfToBuffer skips writable buffers
+// CHECK-LABEL: func @load_after_write_from_buffer_cast(
+func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
+                            %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
+  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+//      CHECK: %[[M:.+]] = bufferization.to_buffer %[[TENSOR]] : tensor<?x?xf32> to memref<?x?xf32>
+//      CHECK: linalg.ceil ins(%[[M]] : memref<?x?xf32>) outs(%[[M]] : memref<?x?xf32>)
+//      CHECK: %[[RES:.*]] = memref.load %[[M]][%[[IDX0]], %[[IDX1]]] : memref<?x?xf32>
+//      CHECK: return %[[RES]] : f32
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
                             %arg2: tensor<?x?xf32>) -> f32 {
-  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  %0 = bufferization.to_buffer %arg2 read_only : tensor<?x?xf32> to memref<?x?xf32>
   %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
   return %1 : f32
 }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
index d828afe13c622..78a0d4fedf690 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -31,6 +31,7 @@
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_buffer %[[VAL_10]] :
+// CHECK-DAG:       %[[WTF:.*]] = bufferization.to_buffer %[[VAL_1]] :
 // CHECK-DAG:       linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
@@ -49,7 +50,7 @@
 // CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
 // CHECK:               %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
 // CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK:               %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK:               %[[VAL_29:.*]] = memref.load %[[WTF]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:               %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
 // CHECK:               memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:             } {"Emitted from" = "linalg.generic"}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index bf3473ead204e..912f78a0b81fc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -31,7 +31,7 @@
 // CHECK:                 } do {
 // CHECK:                   %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
 // CHECK:                   "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
-// CHECK:                   tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
+// CHECK:                   memref.load %{{.*}}{{\[}}%[[D2]], %[[D3]]]
 // CHECK:                   arith.muli
 // CHECK:                   arith.addi
 // CHECK:                   "subsect<trivial<compressed[0,1]>>.next
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 4546d3367b16d..ebbcc5fc7c7cf 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -22,7 +22,7 @@
 // CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
 // CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]]  lvl_sz at 0 with %[[VAL_4]]
 // CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_3]]
-// CHECK:           %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
 // CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_17]]
 // CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 1 with %[[VAL_4]]
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 4abaf03dff50f..e2c841b1ac7d5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -27,7 +27,7 @@
 // CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<f32> to memref<f32>
-// CHECK-HIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
+// CHECK-HIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-HIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-HIR:             %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index
 // CHECK-HIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
@@ -59,7 +59,7 @@
 // CHECK-MIR-DAG:       %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr) -> memref<?xf32>
 // CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[ARGX]] : tensor<f32> to memref<f32>
-// CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-MIR:             %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index
 // CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {

@llvmbot
Copy link
Member

llvmbot commented Dec 17, 2025

@llvm/pr-subscribers-mlir-bufferization

Author: Batzorig Zorigoo (batzor)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/172595.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+19-1)
  • (modified) mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir (+2-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir (+2-2)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..6bf81a2727204 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -846,7 +846,7 @@ struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
   LogicalResult matchAndRewrite(memref::LoadOp load,
                                 PatternRewriter &rewriter) const override {
     auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
-    if (!toBuffer)
+    if (!toBuffer || !toBuffer.getReadOnly())
       return failure();
 
     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index ae1d1fcfc19dc..df07511798b91 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -294,11 +294,29 @@ func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8
 
 // -----
 
+// Verify LoadOfToBuffer skips writable buffers
+// CHECK-LABEL: func @load_after_write_from_buffer_cast(
+func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
+                            %arg2: tensor<?x?xf32>) -> f32 {
+  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
+  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+//      CHECK: %[[M:.+]] = bufferization.to_buffer %[[TENSOR]] : tensor<?x?xf32> to memref<?x?xf32>
+//      CHECK: linalg.ceil ins(%[[M]] : memref<?x?xf32>) outs(%[[M]] : memref<?x?xf32>)
+//      CHECK: %[[RES:.*]] = memref.load %[[M]][%[[IDX0]], %[[IDX1]]] : memref<?x?xf32>
+//      CHECK: return %[[RES]] : f32
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
                             %arg2: tensor<?x?xf32>) -> f32 {
-  %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+  %0 = bufferization.to_buffer %arg2 read_only : tensor<?x?xf32> to memref<?x?xf32>
   %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
   return %1 : f32
 }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
index d828afe13c622..78a0d4fedf690 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -31,6 +31,7 @@
 // CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
 // CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
 // CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_buffer %[[VAL_10]] :
+// CHECK-DAG:       %[[WTF:.*]] = bufferization.to_buffer %[[VAL_1]] :
 // CHECK-DAG:       linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
 // CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
 // CHECK:             %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
@@ -49,7 +50,7 @@
 // CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
 // CHECK:               %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
 // CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK:               %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK:               %[[VAL_29:.*]] = memref.load %[[WTF]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:               %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
 // CHECK:               memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
 // CHECK:             } {"Emitted from" = "linalg.generic"}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index bf3473ead204e..912f78a0b81fc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -31,7 +31,7 @@
 // CHECK:                 } do {
 // CHECK:                   %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
 // CHECK:                   "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
-// CHECK:                   tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
+// CHECK:                   memref.load %{{.*}}{{\[}}%[[D2]], %[[D3]]]
 // CHECK:                   arith.muli
 // CHECK:                   arith.addi
 // CHECK:                   "subsect<trivial<compressed[0,1]>>.next
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 4546d3367b16d..ebbcc5fc7c7cf 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -22,7 +22,7 @@
 // CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
 // CHECK:           %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]]  lvl_sz at 0 with %[[VAL_4]]
 // CHECK:           %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]]  pos_mem_sz at 0 with %[[VAL_3]]
-// CHECK:           %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
 // CHECK:           %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]]  crd_mem_sz at 0 with %[[VAL_17]]
 // CHECK:           %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]]  lvl_sz at 1 with %[[VAL_4]]
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 4abaf03dff50f..e2c841b1ac7d5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -27,7 +27,7 @@
 // CHECK-HIR-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
 // CHECK-HIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<f32> to memref<f32>
-// CHECK-HIR:           %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
+// CHECK-HIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-HIR:           %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-HIR:             %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index
 // CHECK-HIR:             %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
@@ -59,7 +59,7 @@
 // CHECK-MIR-DAG:       %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
 // CHECK-MIR-DAG:       %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr) -> memref<?xf32>
 // CHECK-MIR-DAG:       %[[VAL_10:.*]] = bufferization.to_buffer %[[ARGX]] : tensor<f32> to memref<f32>
-// CHECK-MIR:           %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
 // CHECK-MIR:           %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
 // CHECK-MIR:             %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index
 // CHECK-MIR:             %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {

@batzor batzor changed the title [MLIR][Bufferization] Fold Loadonly when the buffer is read only [MLIR][Bufferization] Fold LoadOp only when the buffer is read only Dec 17, 2025
@batzor batzor force-pushed the fix/load-after-write-from-buffer branch from d20c5c4 to 60040bf Compare December 17, 2025 05:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:bufferization Bufferization infrastructure mlir:sparse Sparse compiler in MLIR mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants