Add an optimization that converts some Gathers to Slices.
Some Gathers can be represented as slices. This lowering transforms
these gathers into slices.
PiperOrigin-RevId: 321394868
Change-Id: I905a235e951bf1034a31cc89a86126e830e15495
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 5cbf305..bc6393f 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -610,6 +610,7 @@
"lib/Dialect/mhlo/transforms/generated_lower_complex.inc",
"lib/Dialect/mhlo/transforms/lower_complex.cc",
"lib/Dialect/mhlo/transforms/lower_general_dot.cc",
+ "lib/Dialect/mhlo/transforms/optimize_mhlo.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
@@ -681,6 +682,7 @@
"lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc",
"lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc",
"lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc",
+ "lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc",
"lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc",
"lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc",
],
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index cb9a85a..f3f4405 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -38,6 +38,9 @@
void PopulateComplexLoweringPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
+void PopulateOptimizeMHLOPatterns(MLIRContext *context,
+ OwningRewritePatternList *patterns);
+
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc
new file mode 100644
index 0000000..dfed951
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc
@@ -0,0 +1,187 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file provides optional optimization patterns for mhlo, canonocalizing
+// operations to equivalent but potentially more efficient operations.
+
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/Attributes.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/IR/TypeUtilities.h" // from @llvm-project
+#include "mlir/IR/Types.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Pass/PassRegistry.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
+
+using mlir::OwningRewritePatternList;
+
+namespace mlir {
+namespace mhlo {
+namespace {
+
+// Returns 1D 64-bit dense elements attribute with the given values.
+static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
+ Builder* builder) {
+ RankedTensorType ty = RankedTensorType::get(
+ {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
+ return DenseIntElementsAttr::get(ty, values);
+}
+
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+class GatherIsSlice : public OpRewritePattern<GatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(GatherOp gather,
+ PatternRewriter& rewriter) const override {
+ auto dimension_numbers = gather.dimension_numbers();
+
+ // Inputs need to be ranked to lower.
+ if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
+ !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
+ !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
+ !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
+ return failure();
+ }
+
+ if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
+ return failure();
+ }
+
+ // TODO(suderman): Handle start index map != {0}.
+ if (!dimension_numbers.start_index_map() ||
+ dimension_numbers.start_index_map().getType().getRank() != 1 ||
+ dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
+ dimension_numbers.start_index_map()
+ .getValue({0})
+ .cast<IntegerAttr>()
+ .getValue() != 0) {
+ return failure();
+ }
+
+ auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
+
+ // Requires a ranked output.
+ if (!result_ty) {
+ return failure();
+ }
+ if (dimension_numbers.offset_dims().getType().getNumElements() !=
+ result_ty.getRank()) {
+ return failure();
+ }
+ for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
+ if (it.index() != it.value()) {
+ return failure();
+ }
+ }
+
+ // Verify the gather slice sizes are correct.
+ if (gather.slice_sizes().getNumElements() !=
+ gather.operand().getType().cast<ShapedType>().getRank()) {
+ return failure();
+ }
+
+ // Validate the slice sizes are correct.
+ if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
+ result_ty.getShape().size() + 1) {
+ return failure();
+ }
+
+ for (auto it : llvm::enumerate(result_ty.getShape())) {
+ if (gather.slice_sizes()
+ .getValue(it.index() + 1)
+ .cast<IntegerAttr>()
+ .getValue() != it.value()) {
+ return failure();
+ }
+ }
+
+ auto gather_start_indices = gather.start_indices();
+ auto gather_start_indices_ty =
+ gather_start_indices.getType().cast<ShapedType>();
+
+ llvm::SmallVector<Value, 4> slice_start_indices;
+
+ if (gather_start_indices_ty.getRank() == 0) {
+ slice_start_indices.push_back(gather_start_indices);
+ } else if (gather_start_indices_ty.getRank() == 1) {
+ for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
+ auto start = GetI64ElementsAttr({i}, &rewriter);
+ auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
+ auto stride = GetI64ElementsAttr({1}, &rewriter);
+ auto indicesSlice = rewriter.create<SliceOp>(
+ gather.getLoc(), gather_start_indices, start, limit, stride);
+ auto reshaped = rewriter.create<ReshapeOp>(
+ gather.getLoc(),
+ RankedTensorType::get(
+ {}, indicesSlice.getType().cast<ShapedType>().getElementType()),
+ indicesSlice);
+ slice_start_indices.push_back(reshaped);
+ }
+ } else {
+ return failure();
+ }
+
+ auto sliceSizes = gather.slice_sizes();
+ auto sliceSizesTy = sliceSizes.getType();
+ if (sliceSizesTy.getRank() != 1) {
+ return failure();
+ }
+
+ // Start indices have implicit zeros when not specified. This is because
+ // Gather occurs similar to slicing where full slices are inferred. Add any
+ // missing zeros as necessary.
+ auto zero = rewriter.create<ConstOp>(
+ gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
+ {}, gather_start_indices_ty.getElementType())));
+ while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) {
+ slice_start_indices.push_back(zero);
+ }
+
+ SmallVector<int64_t, 5> sliceShape;
+ for (auto shapeValue : gather.slice_sizes().getIntValues()) {
+ sliceShape.push_back(shapeValue.getSExtValue());
+ }
+
+ auto sliceTy =
+ RankedTensorType::get(sliceShape, result_ty.getElementType());
+ auto slice = rewriter.create<DynamicSliceOp>(
+ gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
+ gather.slice_sizes());
+
+ rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
+
+ return success();
+ }
+};
+
+} // end anonymous namespace
+
+void PopulateOptimizeMHLOPatterns(MLIRContext* context,
+ OwningRewritePatternList* patterns) {
+ patterns->insert<GatherIsSlice>(context);
+}
+} // end namespace mhlo
+} // end namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc
new file mode 100644
index 0000000..3d1f29e
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc
@@ -0,0 +1,49 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "mlir/IR/Operation.h" // from @llvm-project
+#include "mlir/IR/PatternMatch.h" // from @llvm-project
+#include "mlir/Pass/Pass.h" // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+
+using mlir::FunctionPass;
+using mlir::PassRegistration;
+using mlir::PassWrapper;
+
+namespace {
+class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
+ public:
+ explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {}
+
+ /// Performs the lowering to MHLO dialect.
+ void runOnFunction() override;
+};
+} // end anonymous namespace
+
+// Lowers the complex operations that can be represented using other operations.
+void OptimizeMhlo::runOnFunction() {
+ // Add lowering patterns to the list.
+ mlir::OwningRewritePatternList patterns;
+ mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
+
+ applyPatternsAndFoldGreedily(getFunction(), patterns);
+}
+
+static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize",
+ "Run optional HLO optimizations.");
diff --git a/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir b/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir
new file mode 100644
index 0000000..c20de0b
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
+
+// CHECK-LABEL: @gather_is_slice_no_rank
+func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
+ // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
+ // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+ // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
+ %res = "mhlo.gather"(%arg0, %arg1) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1]> : tensor<2xi64>,
+ start_index_map = dense<0> : tensor<1xi64>
+ },
+ slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+ } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>
+
+ // CHECK: return [[RESHAPE]]
+ return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice
+func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
+ // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
+ // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
+ // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+ // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])
+
+ %res = "mhlo.gather"(%arg0, %arg1) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1]> : tensor<2xi64>,
+ start_index_map = dense<0> : tensor<1xi64>
+ },
+ slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+ } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
+
+ // CHECK: return [[RES]]
+ return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice_multiple_start_indices
+func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
+ // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
+ // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
+ // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+ // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
+ // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+ // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
+ %res = "mhlo.gather"(%arg0, %arg1) {
+ dimension_numbers = {
+ collapsed_slice_dims = dense<0> : tensor<1xi64>,
+ index_vector_dim = 0 : i64,
+ offset_dims = dense<[0, 1]> : tensor<2xi64>,
+ start_index_map = dense<0> : tensor<1xi64>
+ },
+ slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+ } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>
+
+ // CHECK: return [[RES]]
+ return %res : tensor<1x2xi32>
+}