[MHLO] Implement return type inference for GetTupleElementOp and TupleOp.

PiperOrigin-RevId: 439589720
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 6b26786..59d630e 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1055,7 +1055,8 @@
 //===----------------------------------------------------------------------===//
 // MHLO tuple op definitions.
 //===----------------------------------------------------------------------===//
-def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]> {
+def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "GetTupleElement operator";
   let description = [{
     Returns a member of a tuple specified by an index.
@@ -1071,12 +1072,10 @@
 
   let hasFolder = 1;
   let hasVerifier = 1;
-
-  let builders = [
-    OpBuilder<(ins "Value":$value, "int32_t":$index)>];
 }
 
-def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]> {
+def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "XLA's tuple op";
   let description = [{
      Groups a set of tensor inputs into a single tuple object.
@@ -1086,9 +1085,6 @@
   let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
   let results = (outs HLO_Tuple);
 
-  let builders = [
-    OpBuilder<(ins "ValueRange":$values)>];
-
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
 }
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 3e1afd8..cceb5f4 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -5930,32 +5930,31 @@
 // GetTupleElementOp
 //===----------------------------------------------------------------------===//
 
-void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
-                              Value tuple, int32_t index) {
-  if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
-    auto element_type = tuple_type.getType(index);
-    build(builder, result, element_type, tuple,
-          builder.getI32IntegerAttr(index));
-    return;
-  }
+LogicalResult GetTupleElementOp::inferReturnTypes(
+    MLIRContext*, Optional<Location>, ValueRange operands,
+    DictionaryAttr attributes, RegionRange,
+    SmallVectorImpl<Type>& inferredReturnTypes) {
+  auto tuple_type = operands[0].getType().dyn_cast<TupleType>();
+  if (!tuple_type) return failure();
 
-  build(builder, result, tuple.getType(), tuple,
-        builder.getI32IntegerAttr(index));
+  auto index_attr = attributes.get("index").cast<IntegerAttr>();
+  auto index = index_attr.getInt();
+  if (index < 0 || index >= tuple_type.size()) return failure();
+
+  inferredReturnTypes.push_back(tuple_type.getType(index));
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // TupleOp
 //===----------------------------------------------------------------------===//
 
-void TupleOp::build(OpBuilder& builder, OperationState& result,
-                    ValueRange values) {
-  SmallVector<Type, 4> types;
-  types.reserve(values.size());
-  for (auto val : values) {
-    types.push_back(val.getType());
-  }
-
-  build(builder, result, builder.getTupleType(types), values);
+LogicalResult TupleOp::inferReturnTypes(
+    MLIRContext* context, Optional<Location>, ValueRange operands,
+    DictionaryAttr attributes, RegionRange,
+    SmallVectorImpl<Type>& inferredReturnTypes) {
+  inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands)));
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 9322c8d..6821fe5 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -46,7 +46,7 @@
 _version = 63
 
 # Version number for MLIR:Python components.
-mlir_api_version = 5
+mlir_api_version = 6
 
 xla_platform_names = {
     'cpu': 'Host',