Add bufferization pass that transforms hlo and some standard ops.
This is good enough to do a tanh operation.
PiperOrigin-RevId: 324989254
Change-Id: Ief17856bd17dc9d21feba4ed909d7499a54bdc9d
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
index 0d346da..66a378d 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD
@@ -21,6 +21,20 @@
)
cc_library(
+ name = "bufferize",
+ srcs = ["bufferize.cc"],
+ hdrs = ["rewriters.h"],
+ deps = [
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Pass",
+ "@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:Transforms",
+ ],
+)
+
+cc_library(
name = "embed_tf_framework",
srcs = ["embed_tf_framework.cc"],
hdrs = ["rewriters.h"],
@@ -36,7 +50,7 @@
)
gentbl(
- name = "tf_framework_passes_inc_gen",
+ name = "kernel_gen_passes_inc_gen",
tbl_outs = [("-gen-pass-decls -name KernelGen", "kernel_gen_passes.h.inc")],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "passes.td",
@@ -46,15 +60,20 @@
cc_library(
name = "passes",
srcs = [
+ "bufferize_pass.cc",
"embed_tf_framework_pass.cc",
"shape_to_descriptors_pass.cc",
"tf_framework_legalize_to_llvm_pass.cc",
],
hdrs = ["passes.h"],
deps = [
+ ":bufferize",
":embed_tf_framework",
+ ":kernel_gen_passes_inc_gen",
":tf_framework_legalize_to_llvm",
- ":tf_framework_passes_inc_gen",
+ "//tensorflow/compiler/mlir/hlo",
+ "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
+ "//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc
new file mode 100644
index 0000000..3d5c820
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc
@@ -0,0 +1,110 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements logic for translating mixed IR to buffer form.
+
+#include <cstddef>
+#include <memory>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/StandardTypes.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+
+namespace mlir {
+namespace kernel_gen {
+namespace transforms {
+
+namespace {
+
+class TensorFromElementsOpConverter
+ : public BufferAssignmentOpConversionPattern<TensorFromElementsOp> {
+ public:
+ using BufferAssignmentOpConversionPattern<
+ TensorFromElementsOp>::BufferAssignmentOpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ TensorFromElementsOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ Location loc = op.getLoc();
+ ShapedType result_type = op.getType().cast<ShapedType>();
+ int number_of_elements = op.elements().size();
+ MemRefType memref_type =
+ MemRefType::get({number_of_elements}, result_type.getElementType());
+ Value result = rewriter.create<AllocaOp>(loc, memref_type);
+ for (auto operand : llvm::enumerate(operands)) {
+ Value index = rewriter.create<ConstantIndexOp>(loc, operand.index());
+ rewriter.create<StoreOp>(loc, operand.value(), result, index);
+ }
+ rewriter.replaceOp(op, {result});
+ return success();
+ }
+};
+
+class TensorLoadOpConversion
+ : public BufferAssignmentOpConversionPattern<TensorLoadOp> {
+ public:
+ using BufferAssignmentOpConversionPattern<
+ TensorLoadOp>::BufferAssignmentOpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ TensorLoadOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ TensorLoadOpAdaptor adaptor(operands);
+ rewriter.replaceOp(op, {adaptor.memref()});
+ return success();
+ }
+};
+
+class ExtractElementOpConversion
+ : public BufferAssignmentOpConversionPattern<ExtractElementOp> {
+ public:
+ using BufferAssignmentOpConversionPattern<
+ ExtractElementOp>::BufferAssignmentOpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ ExtractElementOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ ExtractElementOpAdaptor adaptor(operands);
+
+ if (!adaptor.aggregate().getType().isa<MemRefType>()) {
+ return failure();
+ }
+
+ rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
+ adaptor.indices());
+ return success();
+ }
+};
+
+} // namespace
+
+void populateStandardBufferizePattern(MLIRContext *context,
+ BufferAssignmentPlacer *bufferAssignment,
+ TypeConverter *converter,
+ OwningRewritePatternList *patterns) {
+ patterns->insert<ExtractElementOpConversion, TensorFromElementsOpConverter,
+ TensorLoadOpConversion>(context, bufferAssignment,
+ converter);
+}
+
+} // namespace transforms
+} // namespace kernel_gen
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
new file mode 100644
index 0000000..ebbc92f
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc
@@ -0,0 +1,107 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file implements logic for translating mixed IR to buffer form.
+// Currently it supports MHLO and some operations from the Standard dialect.
+
+#include <memory>
+
+#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/Function.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/IR/StandardTypes.h" // from @llvm-project
+#include "mlir/IR/Visitors.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
+
+namespace mlir {
+namespace kernel_gen {
+namespace transforms {
+namespace {
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
+struct BufferizePass : public BufferizePassBase<BufferizePass> {
+ public:
+ void runOnOperation() override {
+ OwningRewritePatternList patterns;
+ auto& context = getContext();
+ ConversionTarget target(context);
+ target.addLegalDialect<lmhlo::LmhloDialect>();
+ target.addLegalDialect<StandardOpsDialect>();
+ target.addLegalDialect<scf::SCFDialect>();
+ target.addLegalOp<ModuleOp>();
+ target.addLegalOp<ModuleTerminatorOp>();
+ target.addIllegalDialect<mhlo::MhloDialect>();
+ target.addIllegalOp<TensorFromElementsOp>();
+ target.addIllegalOp<TensorLoadOp>();
+ target.addIllegalOp<ExtractElementOp>();
+
+ BufferAssignmentTypeConverter converter;
+ auto typesAreLegal = [&converter](Operation* op) {
+ return converter.isLegal(op->getOperandTypes()) &&
+ converter.isLegal(op->getResultTypes());
+ };
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ auto inputs = op.getType().getInputs();
+ auto results = op.getType().getResults();
+ return converter.isLegal(inputs) && converter.isLegal(results) &&
+ converter.isLegal(&op.getBody());
+ });
+ target.addDynamicallyLegalOp<CallOp>(typesAreLegal);
+ target.addDynamicallyLegalOp<ReturnOp>(typesAreLegal);
+
+ auto module = getOperation();
+ WalkResult result = module.walk([&](FuncOp func) -> WalkResult {
+ BufferAssignmentPlacer bufferAssignment(func);
+ OwningRewritePatternList patterns;
+ mhlo::populateHLOToLHLOConversionPattern(
+ func.getContext(), &bufferAssignment, &converter, &patterns);
+ populateWithBufferAssignmentOpConversionPatterns<
+ ReturnOp, ReturnOp, lmhlo::CopyOp,
+ /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
+ &converter, &patterns);
+ populateStandardBufferizePattern(func.getContext(), &bufferAssignment,
+ &converter, &patterns);
+
+ return applyFullConversion(func, target, patterns);
+ });
+ module.dump();
+ if (result.wasInterrupted()) {
+ signalPassFailure();
+ }
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass() {
+ return std::make_unique<BufferizePass>();
+}
+
+} // namespace transforms
+} // namespace kernel_gen
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
index 13f367c..e65d840 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -41,7 +41,11 @@
// Pass to tranform shape computations in shape dialect to standard and scf
// using memref descriptors.
-std::unique_ptr<Pass> CreateShapeToDescriptorsPass();
+std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
+
+// Pass to tranform computations on values to their corresponding parts on
+// buffers.
+std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass();
} // namespace transforms
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
index 6172067..6a0e328 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -34,4 +34,9 @@
let constructor = "transforms::CreateShapeToDescriptorsPass()";
}
+def BufferizePass : Pass<"test-bufferize", "ModuleOp"> {
+ let summary = "Pass to transform operations on values to buffer based ones";
+ let constructor = "transforms::CreateBufferizePass()";
+}
+
#endif // TF_FRAMEWORK_PASSES
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h
index 257e84b..4efc1e9 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h
@@ -20,6 +20,7 @@
namespace mlir {
+class BufferAssignmentPlacer;
class LLVMTypeConverter;
class MLIRContext;
class OwningRewritePatternList;
@@ -37,6 +38,16 @@
MLIRContext *context, OwningRewritePatternList *patterns);
} // namespace tf_framework
+
+namespace transforms {
+
+/// Collects a set of patterns that bufferize operations from the standard
+/// dialect.
+void populateStandardBufferizePattern(MLIRContext *context,
+ BufferAssignmentPlacer *bufferAssignment,
+ TypeConverter *converter,
+ OwningRewritePatternList *patterns);
+} // namespace transforms
} // namespace kernel_gen
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
index 9c1b434..28d3647 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc
@@ -26,6 +26,7 @@
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
namespace mlir {
namespace kernel_gen {