Add canonicalization pattern for HLO_ScalarsToDimensionTensor.

PiperOrigin-RevId: 297848928
Change-Id: I4b2413c6105bf34b56d12aba24b65f0a5f972211
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
index b011b60..f44bb9d 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc
@@ -32,11 +32,13 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 #include "mlir/IR/Attributes.h"  // TF:llvm-project
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Dialect.h"  // TF:llvm-project
 #include "mlir/IR/Location.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
+#include "mlir/IR/Matchers.h"  // TF:llvm-project
 #include "mlir/IR/OpDefinition.h"  // TF:llvm-project
 #include "mlir/IR/OpImplementation.h"  // TF:llvm-project
 #include "mlir/IR/Operation.h"  // TF:llvm-project
@@ -46,6 +48,7 @@
 #include "mlir/IR/TypeUtilities.h"  // TF:llvm-project
 #include "mlir/IR/Types.h"  // TF:llvm-project
 #include "mlir/IR/Value.h"  // TF:llvm-project
+#include "mlir/Support/LLVM.h"  // TF:llvm-project
 #include "mlir/Support/LogicalResult.h"  // TF:llvm-project
 #include "mlir/Transforms/InliningUtils.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
@@ -480,6 +483,48 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ScalarsToDimensionTensorOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Canonicalizes the pattern of the form
+//
+// %2 = "xla_hlo.scalars_to_dimension_tensor"(%0, %1)
+//          : (i32, i32) -> tensor<2xi32>
+// %3 = extract_element %2[%c0] : tensor<2xi32>
+//
+// to just %0.
+struct ExtractElementFromScalarsToDimensionTensor
+    : public OpRewritePattern<ExtractElementOp> {
+  using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(ExtractElementOp extract,
+                                     PatternRewriter& rewriter) const override {
+    if (extract.indices().size() != 1) return matchFailure();
+
+    if (auto scalars_to_tensor = dyn_cast_or_null<ScalarsToDimensionTensorOp>(
+            extract.aggregate().getDefiningOp())) {
+      APInt index;
+      if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) {
+        return matchFailure();
+      }
+      rewriter.replaceOp(extract,
+                         scalars_to_tensor.getOperand(index.getZExtValue()));
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+}  // namespace
+
+void ScalarsToDimensionTensorOp::getCanonicalizationPatterns(
+    OwningRewritePatternList& results, MLIRContext* context) {
+  results.insert<ExtractElementFromScalarsToDimensionTensor>(context);
+}
+
+//===----------------------------------------------------------------------===//
 // DynamicBroadcastInDimOp
 //===----------------------------------------------------------------------===//
 
diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
index 42b42d9..8fe7bb9 100644
--- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -784,11 +784,12 @@
     compute shape arguments to dynamic operations.
   }];
 
-  let arguments = (ins Variadic<AnySignlessInteger>);
+  let arguments = (ins Variadic<AnySignlessInteger>:$scalars);
   let results = (outs HLO_DimensionTensor);
 
   // Cannot be exported to legacy formats.
   let hasCustomHLOConverter = 1;
+  let hasCanonicalizer = 1;
 }
 
 def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
index 2232063..b73cfcf 100644
--- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir
@@ -64,3 +64,13 @@
   %0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
   return %0 : tensor<2x2xf32>
 }
+
+// CHECK-LABEL: @extract_scalars_to_tensor
+// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func @extract_scalars_to_tensor(%arg0: i32, %arg1: i32) -> i32 {
+  %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
+  %1 = constant 0 : index
+  %2 = extract_element %0[%1] : tensor<2xi32>
+  // CHECK: return %[[ARG0]]
+  return %2 : i32
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir
index 7e2845d..ce70c78 100644
--- a/tensorflow/compiler/mlir/xla/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -453,6 +453,14 @@
 
 // -----
 
+// CHECK-LABEL: @scalars_to_dimension_tensor
+func @scalars_to_dimension_tensor(%arg0: i32, %arg1: i32) -> tensor<2xi32> {
+  %0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @select
 func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
   %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>