Allow alignment to be specified for FinalBufferizePass

Update jitrt pipeline to use this instead of ConstantBufferization

PiperOrigin-RevId: 428708577
Change-Id: I5e59253578064982a3bfb2bf99d0cb820a4a8dbb
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
index cf09089..503819a 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.cc
@@ -196,10 +196,8 @@
   // bufferizing anything.
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createCanonicalizerPass());
-  // Turn tensor constants into global memrefs.
-  // TODO(kramerb): Expose the patterns and add them to the bufferize passes.
-  pm.addPass(mlir::arith::createConstantBufferizePass(/*alignment=*/64));
-  pm.addPass(mlir::kernel_gen::transforms::CreateFinalBufferizePass());
+  pm.addPass(
+      mlir::kernel_gen::transforms::CreateFinalBufferizePass(/*alignment=*/64));
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createCanonicalizerPass());
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
index 46bb029..0760545 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/bufferize.mlir
@@ -1,8 +1,9 @@
-// RUN: kernel-gen-opt %s --computeop-and-func-bufferize --final-bufferize \
-// RUN:   --split-input-file | FileCheck %s --check-prefixes=CHECK,ALLOC
-// RUN: kernel-gen-opt %s --computeop-and-func-bufferize --final-bufferize \
-// RUN:  --promote-buffers-to-stack --split-input-file |\
-// RUN:  FileCheck %s  --check-prefixes=CHECK,ALLOCA
+// RUN: kernel-gen-opt %s --computeop-and-func-bufferize \
+// RUN:    --final-bufferize=alignment=128 --split-input-file | FileCheck %s \
+// RUN:    --check-prefixes=CHECK,ALLOC
+// RUN: kernel-gen-opt %s --computeop-and-func-bufferize \
+// RUN:    --final-bufferize=alignment=128 --promote-buffers-to-stack \
+// RUN:    --split-input-file | FileCheck %s  --check-prefixes=CHECK,ALLOCA
 
 // CHECK-LABEL: @tensor.extract
 // CHECK-SAME: (%[[ARG:.*]]: memref<?xf32>) -> f32
@@ -80,6 +81,7 @@
 // -----
 
 // CHECK: memref.global "private" constant @[[BUFFER:.*]] : memref<3xf32> = dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]>
+// CHECK-SAME: alignment = 128
 // CHECK: @const
 // CHECK-SAME: -> memref<3xf32>
 func @const() -> tensor<3xf32> {
@@ -92,6 +94,7 @@
 // -----
 
 // CHECK: memref.global "private" constant @[[BUFFER:.*]] : memref<3xf32> = dense<4.000000e+00>
+// CHECK-SAME: alignment = 128
 // CHECK: @const_splat
 // CHECK-SAME: -> memref<3xf32>
 func @const_splat() -> tensor<3xf32> {
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
index 45cc77f..5ce31b1 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -16,6 +16,7 @@
 // This file implements logic for translating mixed IR to buffer form.
 // Currently it supports MHLO and some operations from the Standard dialect.
 
+#include <cstdint>
 #include <memory>
 #include <utility>
 
@@ -249,6 +250,10 @@
     tensor::registerBufferizableOpInterfaceExternalModels(registry);
     arith::registerBufferizableOpInterfaceExternalModels(registry);
   }
+  // Default alignment_ specified in passes.td
+  FinalBufferizePass() = default;
+
+  explicit FinalBufferizePass(uint64_t alignment) { alignment_ = alignment; }
 
   void runOnOperation() override {
     // Bufferize ops using BufferizableOpInterface. This could be switched to
@@ -256,6 +261,7 @@
     RewritePatternSet patterns(&getContext());
     bufferization::BufferizationOptions options =
         bufferization::getPartialBufferizationOptions();
+    options.bufferAlignment = alignment_;
     // TODO(springerm): Add dialects to this filter as more and more dialects
     // will be migrated to BufferizableOpInterface-based bufferization.
     options.addToDialectFilter<arith::ArithmeticDialect, StandardOpsDialect,
@@ -329,6 +335,11 @@
   return std::make_unique<FinalBufferizePass>();
 }
 
+std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass(
+    uint64_t alignment) {
+  return std::make_unique<FinalBufferizePass>(alignment);
+}
+
 }  // namespace transforms
 }  // namespace kernel_gen
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
index 6c52c79..057263e 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -126,6 +126,9 @@
 // Pass to remove copies which are consumed by a GenericOp.
 std::unique_ptr<OperationPass<FuncOp>> CreateCopyCleanupPass();
 
+std::unique_ptr<OperationPass<ModuleOp>> CreateFinalBufferizePass(
+    uint64_t alignment);
+
 }  // namespace transforms
 
 #define GEN_PASS_REGISTRATION
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
index 13342f2..ad64a5a 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -87,6 +87,10 @@
   let summary = "Pass to transform late operations on values to buffer based "
                 "ones.";
   let constructor = "transforms::CreateFinalBufferizePass()";
+  let options = [
+      Option<"alignment_", "alignment", "uint64_t",
+             /*default=*/"64", "Memory alignment">,
+  ];
 }
 
 def ConvertToSignlessPass : Pass<"convert-to-signless", "ModuleOp"> {