[MHLO] Implement return type inference for GetTupleElementOp and TupleOp.
PiperOrigin-RevId: 439589720
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 6b26786..59d630e 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -1055,7 +1055,8 @@
//===----------------------------------------------------------------------===//
// MHLO tuple op definitions.
//===----------------------------------------------------------------------===//
-def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]> {
+def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "GetTupleElement operator";
let description = [{
Returns a member of a tuple specified by an index.
@@ -1071,12 +1072,10 @@
let hasFolder = 1;
let hasVerifier = 1;
-
- let builders = [
- OpBuilder<(ins "Value":$value, "int32_t":$index)>];
}
-def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]> {
+def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "XLA's tuple op";
let description = [{
Groups a set of tensor inputs into a single tuple object.
@@ -1086,9 +1085,6 @@
let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
let results = (outs HLO_Tuple);
- let builders = [
- OpBuilder<(ins "ValueRange":$values)>];
-
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 3e1afd8..cceb5f4 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -5930,32 +5930,31 @@
// GetTupleElementOp
//===----------------------------------------------------------------------===//
-void GetTupleElementOp::build(OpBuilder& builder, OperationState& result,
- Value tuple, int32_t index) {
- if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
- auto element_type = tuple_type.getType(index);
- build(builder, result, element_type, tuple,
- builder.getI32IntegerAttr(index));
- return;
- }
+LogicalResult GetTupleElementOp::inferReturnTypes(
+ MLIRContext*, Optional<Location>, ValueRange operands,
+ DictionaryAttr attributes, RegionRange,
+ SmallVectorImpl<Type>& inferredReturnTypes) {
+ auto tuple_type = operands[0].getType().dyn_cast<TupleType>();
+ if (!tuple_type) return failure();
- build(builder, result, tuple.getType(), tuple,
- builder.getI32IntegerAttr(index));
+ auto index_attr = attributes.get("index").cast<IntegerAttr>();
+ auto index = index_attr.getInt();
+ if (index < 0 || index >= tuple_type.size()) return failure();
+
+ inferredReturnTypes.push_back(tuple_type.getType(index));
+ return success();
}
//===----------------------------------------------------------------------===//
// TupleOp
//===----------------------------------------------------------------------===//
-void TupleOp::build(OpBuilder& builder, OperationState& result,
- ValueRange values) {
- SmallVector<Type, 4> types;
- types.reserve(values.size());
- for (auto val : values) {
- types.push_back(val.getType());
- }
-
- build(builder, result, builder.getTupleType(types), values);
+LogicalResult TupleOp::inferReturnTypes(
+ MLIRContext* context, Optional<Location>, ValueRange operands,
+ DictionaryAttr attributes, RegionRange,
+ SmallVectorImpl<Type>& inferredReturnTypes) {
+ inferredReturnTypes.push_back(TupleType::get(context, TypeRange(operands)));
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 9322c8d..6821fe5 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -46,7 +46,7 @@
_version = 63
# Version number for MLIR:Python components.
-mlir_api_version = 5
+mlir_api_version = 6
xla_platform_names = {
'cpu': 'Host',