Add bufferization pass that transforms hlo and some standard ops.

This is good enough to do a tanh operation.

PiperOrigin-RevId: 324989254
Change-Id: Ief17856bd17dc9d21feba4ed909d7499a54bdc9d
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
index 0d346da..66a378d 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
@@ -21,6 +21,20 @@
 )
 
 cc_library(
+    name = "bufferize",
+    srcs = ["bufferize.cc"],
+    hdrs = ["rewriters.h"],
+    deps = [
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+    ],
+)
+
+cc_library(
     name = "embed_tf_framework",
     srcs = ["embed_tf_framework.cc"],
     hdrs = ["rewriters.h"],
@@ -36,7 +50,7 @@
 )
 
 gentbl(
-    name = "tf_framework_passes_inc_gen",
+    name = "kernel_gen_passes_inc_gen",
     tbl_outs = [("-gen-pass-decls -name KernelGen", "kernel_gen_passes.h.inc")],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "passes.td",
@@ -46,15 +60,20 @@
 cc_library(
     name = "passes",
     srcs = [
+        "bufferize_pass.cc",
         "embed_tf_framework_pass.cc",
         "shape_to_descriptors_pass.cc",
         "tf_framework_legalize_to_llvm_pass.cc",
     ],
     hdrs = ["passes.h"],
     deps = [
+        ":bufferize",
         ":embed_tf_framework",
+        ":kernel_gen_passes_inc_gen",
         ":tf_framework_legalize_to_llvm",
-        ":tf_framework_passes_inc_gen",
+        "//tensorflow/compiler/mlir/hlo",
+        "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
+        "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc
new file mode 100644
index 0000000..3d5c820
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements logic for translating mixed IR to buffer form.
+
+#include <cstddef>
+#include <memory>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+
+namespace mlir {
+namespace kernel_gen {
+namespace transforms {
+
+namespace {
+
+class TensorFromElementsOpConverter
+    : public BufferAssignmentOpConversionPattern<TensorFromElementsOp> {
+ public:
+  using BufferAssignmentOpConversionPattern<
+      TensorFromElementsOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      TensorFromElementsOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ShapedType result_type = op.getType().cast<ShapedType>();
+    int number_of_elements = op.elements().size();
+    MemRefType memref_type =
+        MemRefType::get({number_of_elements}, result_type.getElementType());
+    Value result = rewriter.create<AllocaOp>(loc, memref_type);
+    for (auto operand : llvm::enumerate(operands)) {
+      Value index = rewriter.create<ConstantIndexOp>(loc, operand.index());
+      rewriter.create<StoreOp>(loc, operand.value(), result, index);
+    }
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+
+class TensorLoadOpConversion
+    : public BufferAssignmentOpConversionPattern<TensorLoadOp> {
+ public:
+  using BufferAssignmentOpConversionPattern<
+      TensorLoadOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      TensorLoadOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    TensorLoadOpAdaptor adaptor(operands);
+    rewriter.replaceOp(op, {adaptor.memref()});
+    return success();
+  }
+};
+
+class ExtractElementOpConversion
+    : public BufferAssignmentOpConversionPattern<ExtractElementOp> {
+ public:
+  using BufferAssignmentOpConversionPattern<
+      ExtractElementOp>::BufferAssignmentOpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      ExtractElementOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    ExtractElementOpAdaptor adaptor(operands);
+
+    if (!adaptor.aggregate().getType().isa<MemRefType>()) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
+                                        adaptor.indices());
+    return success();
+  }
+};
+
+}  // namespace
+
+void populateStandardBufferizePattern(MLIRContext *context,
+                                      BufferAssignmentPlacer *bufferAssignment,
+                                      TypeConverter *converter,
+                                      OwningRewritePatternList *patterns) {
+  patterns->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
+                   TensorLoadOpConversion>(context, bufferAssignment,
+                                           converter);
+}
+
+}  // namespace transforms
+}  // namespace kernel_gen
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
new file mode 100644
index 0000000..ebbc92f
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -0,0 +1,107 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements logic for translating mixed IR to buffer form.
+// Currently it supports MHLO and some operations from the Standard dialect.
+
+#include <memory>
+
+#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/StandardTypes.h"  // from @llvm-project
+#include "mlir/IR/Visitors.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
+
+namespace mlir {
+namespace kernel_gen {
+namespace transforms {
+namespace {
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
+struct BufferizePass : public BufferizePassBase<BufferizePass> {
+ public:
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    auto& context = getContext();
+    ConversionTarget target(context);
+    target.addLegalDialect<lmhlo::LmhloDialect>();
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<scf::SCFDialect>();
+    target.addLegalOp<ModuleOp>();
+    target.addLegalOp<ModuleTerminatorOp>();
+    target.addIllegalDialect<mhlo::MhloDialect>();
+    target.addIllegalOp<TensorFromElementsOp>();
+    target.addIllegalOp<TensorLoadOp>();
+    target.addIllegalOp<ExtractElementOp>();
+
+    BufferAssignmentTypeConverter converter;
+    auto typesAreLegal = [&converter](Operation* op) {
+      return converter.isLegal(op->getOperandTypes()) &&
+             converter.isLegal(op->getResultTypes());
+    };
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      auto inputs = op.getType().getInputs();
+      auto results = op.getType().getResults();
+      return converter.isLegal(inputs) && converter.isLegal(results) &&
+             converter.isLegal(&op.getBody());
+    });
+    target.addDynamicallyLegalOp<CallOp>(typesAreLegal);
+    target.addDynamicallyLegalOp<ReturnOp>(typesAreLegal);
+
+    auto module = getOperation();
+    WalkResult result = module.walk([&](FuncOp func) -> WalkResult {
+      BufferAssignmentPlacer bufferAssignment(func);
+      OwningRewritePatternList patterns;
+      mhlo::populateHLOToLHLOConversionPattern(
+          func.getContext(), &bufferAssignment, &converter, &patterns);
+      populateWithBufferAssignmentOpConversionPatterns<
+          ReturnOp, ReturnOp, lmhlo::CopyOp,
+          /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
+                                               &converter, &patterns);
+      populateStandardBufferizePattern(func.getContext(), &bufferAssignment,
+                                       &converter, &patterns);
+
+      return applyFullConversion(func, target, patterns);
+    });
+    module.dump();
+    if (result.wasInterrupted()) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass() {
+  return std::make_unique<BufferizePass>();
+}
+
+}  // namespace transforms
+}  // namespace kernel_gen
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
index 13f367c..e65d840 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -41,7 +41,11 @@
 
 // Pass to tranform shape computations in shape dialect to standard and scf
 // using memref descriptors.
-std::unique_ptr<Pass> CreateShapeToDescriptorsPass();
+std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
+
+// Pass to tranform computations on values to their corresponding parts on
+// buffers.
+std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass();
 
 }  // namespace transforms
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
index 6172067..6a0e328 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -34,4 +34,9 @@
   let constructor = "transforms::CreateShapeToDescriptorsPass()";
 }
 
+def BufferizePass : Pass<"test-bufferize", "ModuleOp"> {
+  let summary = "Pass to transform operations on values to buffer based ones";
+  let constructor = "transforms::CreateBufferizePass()";
+}
+
 #endif // TF_FRAMEWORK_PASSES
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h
index 257e84b..4efc1e9 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h
@@ -20,6 +20,7 @@
 
 namespace mlir {
 
+class BufferAssignmentPlacer;
 class LLVMTypeConverter;
 class MLIRContext;
 class OwningRewritePatternList;
@@ -37,6 +38,16 @@
     MLIRContext *context, OwningRewritePatternList *patterns);
 
 }  // namespace tf_framework
+
+namespace transforms {
+
+/// Collects a set of patterns that bufferize operations from the standard
+/// dialect.
+void populateStandardBufferizePattern(MLIRContext *context,
+                                      BufferAssignmentPlacer *bufferAssignment,
+                                      TypeConverter *converter,
+                                      OwningRewritePatternList *patterns);
+}  // namespace transforms
 }  // namespace kernel_gen
 }  // namespace mlir
 
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
index 9c1b434..28d3647 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
@@ -26,6 +26,7 @@
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
 
 namespace mlir {
 namespace kernel_gen {