[MLIR][Bufferization] Fold LoadOp only when the buffer is read only #172595
Conversation
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-bufferization

Author: Batzorig Zorigoo (batzor)

Full diff: https://github.com/llvm/llvm-project/pull/172595.diff

6 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..6bf81a2727204 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -846,7 +846,7 @@ struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
LogicalResult matchAndRewrite(memref::LoadOp load,
PatternRewriter &rewriter) const override {
auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
- if (!toBuffer)
+ if (!toBuffer || !toBuffer.getReadOnly())
return failure();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index ae1d1fcfc19dc..df07511798b91 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -294,11 +294,29 @@ func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8
// -----
+// Verify LoadOfToBuffer skips writable buffers
+// CHECK-LABEL: func @load_after_write_from_buffer_cast(
+func.func @load_after_write_from_buffer_cast(%arg0: index, %arg1: index,
+ %arg2: tensor<?x?xf32>) -> f32 {
+ %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+ linalg.ceil ins(%0 : memref<?x?xf32>) outs(%0 : memref<?x?xf32>)
+ %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
+ return %1 : f32
+}
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[TENSOR]] : tensor<?x?xf32> to memref<?x?xf32>
+// CHECK: linalg.ceil ins(%[[M]] : memref<?x?xf32>) outs(%[[M]] : memref<?x?xf32>)
+// CHECK: %[[RES:.*]] = memref.load %[[M]][%[[IDX0]], %[[IDX1]]] : memref<?x?xf32>
+// CHECK: return %[[RES]] : f32
+
+// -----
+
// Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
%arg2: tensor<?x?xf32>) -> f32 {
- %0 = bufferization.to_buffer %arg2 : tensor<?x?xf32> to memref<?x?xf32>
+ %0 = bufferization.to_buffer %arg2 read_only : tensor<?x?xf32> to memref<?x?xf32>
%1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
return %1 : f32
}
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
index d828afe13c622..78a0d4fedf690 100644
--- a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -31,6 +31,7 @@
// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
// CHECK-DAG: %[[VAL_14:.*]] = bufferization.to_buffer %[[VAL_10]] :
+// CHECK-DAG: %[[WTF:.*]] = bufferization.to_buffer %[[VAL_1]] :
// CHECK-DAG: linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
// CHECK: %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
@@ -49,7 +50,7 @@
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
-// CHECK: %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK: %[[VAL_29:.*]] = memref.load %[[WTF]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
// CHECK: memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
// CHECK: } {"Emitted from" = "linalg.generic"}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index bf3473ead204e..912f78a0b81fc 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -31,7 +31,7 @@
// CHECK: } do {
// CHECK: %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
// CHECK: "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
-// CHECK: tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
+// CHECK: memref.load %{{.*}}{{\[}}%[[D2]], %[[D3]]]
// CHECK: arith.muli
// CHECK: arith.addi
// CHECK: "subsect<trivial<compressed[0,1]>>.next
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
index 4546d3367b16d..ebbcc5fc7c7cf 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -22,7 +22,7 @@
// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.init
// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] lvl_sz at 0 with %[[VAL_4]]
// CHECK: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_3]]
-// CHECK: %[[VAL_16:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_5]]] : tensor<2xindex>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<2xindex>
// CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_16]], %[[VAL_3]] : index
// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] crd_mem_sz at 0 with %[[VAL_17]]
// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] lvl_sz at 1 with %[[VAL_4]]
diff --git a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
index 4abaf03dff50f..e2c841b1ac7d5 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_perm_lower.mlir
@@ -27,7 +27,7 @@
// CHECK-HIR-DAG: %[[VAL_7:.*]] = sparse_tensor.lvl %[[DEMAP]], %[[VAL_4]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
// CHECK-HIR-DAG: %[[VAL_8:.*]] = sparse_tensor.values %[[DEMAP]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
// CHECK-HIR-DAG: %[[VAL_10:.*]] = bufferization.to_buffer %[[VAL_1]] : tensor<f32> to memref<f32>
-// CHECK-HIR: %[[VAL_11:.*]] = tensor.extract %[[VAL_1]][] : tensor<f32>
+// CHECK-HIR: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK-HIR: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_2]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
// CHECK-HIR: %[[VAL_18:.*]] = arith.muli %[[VAL_13]], %[[VAL_6]] : index
// CHECK-HIR: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_2]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
@@ -59,7 +59,7 @@
// CHECK-MIR-DAG: %[[DimSize2:.*]] = call @sparseLvlSize(%[[ARGA]], %[[I2]])
// CHECK-MIR-DAG: %[[VAL_8:.*]] = call @sparseValuesF32(%[[ARGA]]) : (!llvm.ptr) -> memref<?xf32>
// CHECK-MIR-DAG: %[[VAL_10:.*]] = bufferization.to_buffer %[[ARGX]] : tensor<f32> to memref<f32>
-// CHECK-MIR: %[[VAL_11:.*]] = tensor.extract %[[ARGX]][] : tensor<f32>
+// CHECK-MIR: %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK-MIR: %[[VAL_12:.*]] = scf.for %[[D2:.*]] = %[[I0]] to %[[DimSize0]] step %[[I1]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) {
// CHECK-MIR: %[[VAL_18:.*]] = arith.muli %[[D2]], %[[DimSize1]] : index
// CHECK-MIR: %[[VAL_15:.*]] = scf.for %[[D0:.*]] = %[[I0]] to %[[DimSize1]] step %[[I1]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) {
Force-pushed from d20c5c4 to 60040bf
When we `memref.load` from a buffer produced by `bufferization.to_buffer`, the `LoadOfToBuffer` pattern folded the load into `tensor.extract` on the source tensor even when the buffer was writable, causing unexpected results: a write to the buffer between the cast and the load was silently dropped, because the folded extract reads the original tensor instead of the updated buffer.
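A minimal sketch of the problem, adapted from the `@load_after_write_from_buffer_cast` test added in this PR (the function name and the use of `linalg.ceil` as the in-place write are illustrative):

```mlir
// Before canonicalization: the load observes the values written by linalg.ceil.
func.func @load_after_write(%i: index, %j: index, %t: tensor<?x?xf32>) -> f32 {
  %m = bufferization.to_buffer %t : tensor<?x?xf32> to memref<?x?xf32>
  linalg.ceil ins(%m : memref<?x?xf32>) outs(%m : memref<?x?xf32>)
  %v = memref.load %m[%i, %j] : memref<?x?xf32>
  return %v : f32
}

// After the old fold: the extract reads %t directly, ignoring the in-place write.
func.func @load_after_write(%i: index, %j: index, %t: tensor<?x?xf32>) -> f32 {
  %m = bufferization.to_buffer %t : tensor<?x?xf32> to memref<?x?xf32>
  linalg.ceil ins(%m : memref<?x?xf32>) outs(%m : memref<?x?xf32>)
  %v = tensor.extract %t[%i, %j] : tensor<?x?xf32>
  return %v : f32
}
```

With this change, the fold only fires when the `to_buffer` op carries the `read_only` attribute, so IR like the first function above is left untouched.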