[MLIR] Update lhlo.const to linalg lowering to use affine.store instead of std.store
The xla_lhlo.const lowering uses std.store to store a constant to
0-d memrefs. Update it to affine.store since such an access is trivially
affine (no indices). An affine.store can always be lowered to std.store.
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 6c29bd6..c7bda88 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -359,6 +359,7 @@
":map_lmhlo_to_scalar_op",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 7286914..4191987 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@@ -692,7 +693,8 @@
if (valueAttr.getType().getRank() != 0) return failure();
auto stdConstOp =
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
- rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand());
+ rewriter.create<mlir::AffineStoreOp>(loc, stdConstOp, constOp.getOperand(),
+ ValueRange());
rewriter.eraseOp(constOp);
return success();
}
@@ -827,7 +829,8 @@
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
- target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
+ target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
+ AffineDialect>();
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
index 6981466..dd88e5c 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
@@ -329,7 +329,7 @@
return
}
// CHECK: %[[CONSTANT:.*]] = constant 10 : i32
-// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref<i32>
+// CHECK: affine.store %[[CONSTANT]], %{{.*}}[] : memref<i32>
// -----