Add an optimization that converts some Gathers to Slices.

A gather whose indices select a single contiguous slice of the operand is
equivalent to a (dynamic) slice. This lowering rewrites such gathers as an
mhlo.dynamic-slice followed by an mhlo.reshape.
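
For example (illustrative; this mirrors the new tests), the gather

  %res = "mhlo.gather"(%arg0, %arg1) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 0 : i64,
      offset_dims = dense<[0, 1]> : tensor<2xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
  } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>

becomes

  %zero = mhlo.constant dense<0> : tensor<i64>
  %slice = "mhlo.dynamic-slice"(%arg0, %arg1, %zero, %zero)
      {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
      : (tensor<2x1x2xi32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x2xi32>
  %res = "mhlo.reshape"(%slice) : (tensor<1x1x2xi32>) -> tensor<1x2xi32>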

PiperOrigin-RevId: 321394868
Change-Id: I905a235e951bf1034a31cc89a86126e830e15495
diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD
index 5cbf305..bc6393f 100644
--- a/tensorflow/compiler/mlir/hlo/BUILD
+++ b/tensorflow/compiler/mlir/hlo/BUILD
@@ -610,6 +610,7 @@
         "lib/Dialect/mhlo/transforms/generated_lower_complex.inc",
         "lib/Dialect/mhlo/transforms/lower_complex.cc",
         "lib/Dialect/mhlo/transforms/lower_general_dot.cc",
+        "lib/Dialect/mhlo/transforms/optimize_mhlo.cc",
     ],
     hdrs = [
         "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
@@ -681,6 +682,7 @@
         "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc",
         "lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc",
         "lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc",
+        "lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc",
         "lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc",
         "lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc",
     ],
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index cb9a85a..f3f4405 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -38,6 +38,9 @@
 void PopulateComplexLoweringPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns);
 
+void PopulateOptimizeMHLOPatterns(MLIRContext *context,
+                                  OwningRewritePatternList *patterns);
+
 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                                MLIRContext *ctx);
 
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc
new file mode 100644
index 0000000..dfed951
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc
@@ -0,0 +1,197 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file provides optional optimization patterns for mhlo, canonicalizing
+// operations into equivalent but potentially more efficient forms.
+
+#include <cstddef>
+#include <cstdint>
+#include <iterator>
+#include <numeric>
+
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
+
+using mlir::OwningRewritePatternList;
+
+namespace mlir {
+namespace mhlo {
+namespace {
+
+// Returns a 1D 64-bit dense elements attribute with the given values.
+static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
+                                               Builder* builder) {
+  RankedTensorType ty = RankedTensorType::get(
+      {static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
+  return DenseIntElementsAttr::get(ty, values);
+}
+
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
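+// Rewrites a GatherOp that is equivalent to a slice into a DynamicSliceOp
+// followed by a ReshapeOp. This applies when the operand and start indices
+// are statically shaped, the offset dims are exactly the result dims, the
+// start index map is {0}, and the slice sizes match the result shape.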
+class GatherIsSlice : public OpRewritePattern<GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp gather,
+                                PatternRewriter& rewriter) const override {
+    auto dimension_numbers = gather.dimension_numbers();
+
+    // Inputs need to be ranked and statically shaped to lower.
+    if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
+        !gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
+        !gather.start_indices().getType().cast<ShapedType>().hasRank() ||
+        !gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
+      return failure();
+    }
+
+    if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
+      return failure();
+    }
+
+    // TODO(suderman): Handle start index map != {0}.
+    if (!dimension_numbers.start_index_map() ||
+        dimension_numbers.start_index_map().getType().getRank() != 1 ||
+        dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
+        dimension_numbers.start_index_map()
+                .getValue({0})
+                .cast<IntegerAttr>()
+                .getValue() != 0) {
+      return failure();
+    }
+
+    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
+
+    // Requires a ranked output.
+    if (!result_ty) {
+      return failure();
+    }
+    if (dimension_numbers.offset_dims().getType().getNumElements() !=
+        result_ty.getRank()) {
+      return failure();
+    }
+    for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
+      if (it.index() != it.value()) {
+        return failure();
+      }
+    }
+
+    // There must be exactly one slice size per operand dimension.
+    if (gather.slice_sizes().getNumElements() !=
+        gather.operand().getType().cast<ShapedType>().getRank()) {
+      return failure();
+    }
+
+    // Need at least one slice size per result dim plus the collapsed dim.
+    if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
+        result_ty.getShape().size() + 1) {
+      return failure();
+    }
+
+    for (auto it : llvm::enumerate(result_ty.getShape())) {
+      if (gather.slice_sizes()
+              .getValue(it.index() + 1)
+              .cast<IntegerAttr>()
+              .getValue() != it.value()) {
+        return failure();
+      }
+    }
+
+    auto gather_start_indices = gather.start_indices();
+    auto gather_start_indices_ty =
+        gather_start_indices.getType().cast<ShapedType>();
+
+    llvm::SmallVector<Value, 4> slice_start_indices;
+
+    if (gather_start_indices_ty.getRank() == 0) {
+      slice_start_indices.push_back(gather_start_indices);
+    } else if (gather_start_indices_ty.getRank() == 1) {
+      for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
+        auto start = GetI64ElementsAttr({i}, &rewriter);
+        auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
+        auto stride = GetI64ElementsAttr({1}, &rewriter);
+        auto indices_slice = rewriter.create<SliceOp>(
+            gather.getLoc(), gather_start_indices, start, limit, stride);
+        auto element_type =
+            indices_slice.getType().cast<ShapedType>().getElementType();
+        auto reshaped = rewriter.create<ReshapeOp>(
+            gather.getLoc(), RankedTensorType::get({}, element_type),
+            indices_slice);
+        slice_start_indices.push_back(reshaped);
+      }
+    } else {
+      return failure();
+    }
+
+    auto slice_sizes = gather.slice_sizes();
+    auto slice_sizes_ty = slice_sizes.getType();
+    if (slice_sizes_ty.getRank() != 1) {
+      return failure();
+    }
+
+    // Unspecified start indices are implicitly zero: like slicing, gather
+    // infers full slices for the remaining dimensions. Pad the start
+    // indices with zeros as needed.
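+    // E.g. for a rank-3 operand with a single provided start index %i, the
+    // dynamic-slice start indices become [%i, 0, 0].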
+    auto zero = rewriter.create<ConstOp>(
+        gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
+                             {}, gather_start_indices_ty.getElementType())));
+    while (slice_start_indices.size() < slice_sizes_ty.getDimSize(0)) {
+      slice_start_indices.push_back(zero);
+    }
+
+    SmallVector<int64_t, 5> slice_shape;
+    for (auto shape_value : slice_sizes.getIntValues()) {
+      slice_shape.push_back(shape_value.getSExtValue());
+    }
+
+    auto slice_ty =
+        RankedTensorType::get(slice_shape, result_ty.getElementType());
+    auto slice = rewriter.create<DynamicSliceOp>(
+        gather.getLoc(), slice_ty, gather.operand(), slice_start_indices,
+        slice_sizes);
+
+    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
+
+    return success();
+  }
+};
+
+}  // end anonymous namespace
+
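+// Example usage (illustrative; this mirrors optimize_mhlo_pass.cc):
+//   OwningRewritePatternList patterns;
+//   mhlo::PopulateOptimizeMHLOPatterns(context, &patterns);
+//   applyPatternsAndFoldGreedily(func, patterns);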
+void PopulateOptimizeMHLOPatterns(MLIRContext* context,
+                                  OwningRewritePatternList* patterns) {
+  patterns->insert<GatherIsSlice>(context);
+}
+}  // end namespace mhlo
+}  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc
new file mode 100644
index 0000000..3d1f29e
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc
@@ -0,0 +1,51 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
+
+using mlir::FunctionPass;
+using mlir::PassRegistration;
+using mlir::PassWrapper;
+
+namespace {
+class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
+ public:
+  explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {}
+
+  /// Applies the optimization patterns to the function.
+  void runOnFunction() override;
+};
+}  // end anonymous namespace
+
+// Rewrites operations into equivalent but potentially more efficient forms.
+void OptimizeMhlo::runOnFunction() {
+  // Add the optimization patterns to the list.
+  mlir::OwningRewritePatternList patterns;
+  mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
+
+  applyPatternsAndFoldGreedily(getFunction(), patterns);
+}
+
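+// Registered as a test pass; it can be exercised with mlir-hlo-opt, e.g.:
+//   mlir-hlo-opt -pass-pipeline='func(mhlo-test-optimize)' input.mlir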
+static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize",
+                                           "Run optional HLO optimizations.");
diff --git a/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir b/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir
new file mode 100644
index 0000000..c20de0b
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
+
+// CHECK-LABEL: @gather_is_slice_no_rank
+func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
+  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
+  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
+   %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RESHAPE]]
+  return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice
+func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
+  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
+  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
+  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])
+
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RES]]
+  return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice_multiple_start_indices
+func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
+  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
+  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
+  // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
+   %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RES]]
+  return %res : tensor<1x2xi32>
+}
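+
+// Not part of the original change: an illustrative negative test. The
+// pattern only handles start_index_map == {0} (see the TODO in
+// optimize_mhlo.cc), so this gather is left untouched.
+// CHECK-LABEL: @gather_is_not_slice_start_index_map
+func @gather_is_not_slice_start_index_map(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
+  // CHECK: "mhlo.gather"
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<1> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
+  return %res : tensor<1x2xi32>
+}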