Allow alignment to be specified for FinalBufferizePass
Update the jitrt pipeline to pass the alignment to FinalBufferizePass instead of running a separate ConstantBufferize pass
PiperOrigin-RevId: 428708577
Change-Id: I5e59253578064982a3bfb2bf99d0cb820a4a8dbb
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
index cf09089..503819a 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
@@ -196,10 +196,8 @@
// bufferizing anything.
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createCanonicalizerPass());
- // Turn tensor constants into global memrefs.
- // TODO(kramerb): Expose the patterns and add them to the bufferize passes.
- pm.addPass(mlir::arith::createConstantBufferizePass(/*alignment=*/64));
- pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
+ pm.addPass(
+ mlir::kernel_gen::transforms::CreateFinalBufferizePass(/*alignment=*/64));
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createCanonicalizerPass());
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
index 46bb029..0760545 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
@@ -1,8 +1,9 @@
-// RUN: kernel-gen-opt %s --computeop-and-func-bufferize --final-bufferize \
-// RUN: --split-input-file | FileCheck %s --check-prefixes=CHECK,ALLOC
-// RUN: kernel-gen-opt %s --computeop-and-func-bufferize --final-bufferize \
-// RUN: --promote-buffers-to-stack --split-input-file |\
-// RUN: FileCheck %s --check-prefixes=CHECK,ALLOCA
+// RUN: kernel-gen-opt %s --computeop-and-func-bufferize \
+// RUN: --final-bufferize=alignment=128 --split-input-file | FileCheck %s \
+// RUN: --check-prefixes=CHECK,ALLOC
+// RUN: kernel-gen-opt %s --computeop-and-func-bufferize \
+// RUN: --final-bufferize=alignment=128 --promote-buffers-to-stack \
+// RUN: --split-input-file | FileCheck %s --check-prefixes=CHECK,ALLOCA
// CHECK-LABEL: @tensor.extract
// CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32
@@ -80,6 +81,7 @@
// -----
// CHECK: memref.global "private" constant @[[BUFFER:.*]] : memref<3xf32> = dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]>
+// CHECK-SAME: alignment = 128
// CHECK: @const
// CHECK-SAME: -> memref<3xf32>
func @const() -> tensor<3xf32> {
@@ -92,6 +94,7 @@
// -----
// CHECK: memref.global "private" constant @[[BUFFER:.*]] : memref<3xf32> = dense<4.000000e+00>
+// CHECK-SAME: alignment = 128
// CHECK: @const_splat
// CHECK-SAME: -> memref<3xf32>
func @const_splat() -> tensor<3xf32> {
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index 45cc77f..5ce31b1 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -16,6 +16,7 @@
// This file implements logic for translating mixed IR to buffer form.
// Currently it supports MHLO and some operations from the Standard dialect.
+#include <cstdint>
#include <memory>
#include <utility>
@@ -249,6 +250,10 @@
tensor::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
}
+ // The default value of alignment_ is specified in passes.td.
+ FinalBufferizePass() = default;
+
+ explicit FinalBufferizePass(uint64_t alignment) { alignment_ = alignment; }
void runOnOperation() override {
// Bufferize ops using BufferizableOpInterface. This could be switched to
@@ -256,6 +261,7 @@
RewritePatternSet patterns(&getContext());
bufferization::BufferizationOptions options =
bufferization::getPartialBufferizationOptions();
+ options.bufferAlignment = alignment_;
// TODO(springerm): Add dialects to this filter as more and more dialects
// will be migrated to BufferizableOpInterface-based bufferization.
options.addToDialectFilter<arith::ArithmeticDialect, StandardOpsDialect,
@@ -329,6 +335,11 @@
return std::make_unique<FinalBufferizePass>();
}
+std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass(
+ uint64_t alignment) {
+ return std::make_unique<FinalBufferizePass>(alignment);
+}
+
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
index 6c52c79..057263e 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -126,6 +126,9 @@
// Pass to remove copies which are consumed by a GenericOp.
std::unique_ptr<OperationPass<FuncOp>> CreateCopyCleanupPass();
+std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass(
+ uint64_t alignment);
+
} // namespace transforms
#define GEN_PASS_REGISTRATION
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
index 13342f2..ad64a5a 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -87,6 +87,10 @@
let summary = "Pass to transform late operations on values to buffer based "
"ones.";
let constructor = "transforms::CreateFinalBufferizePass()";
+ let options = [
+ Option<"alignment_", "alignment", "uint64_t",
+ /*default=*/"64", "Memory alignment">,
+ ];
}
def ConvertToSignlessPass : Pass<"convert-to-signless", "ModuleOp"> {