Add support for input list in MLIR tracing API
PiperOrigin-RevId: 327081439
Change-Id: I354e1daae7de9c49ea9280446a3cee96ab544c25
diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc
index cb1f997..80b1f15 100644
--- a/tensorflow/c/eager/gradients_test.cc
+++ b/tensorflow/c/eager/gradients_test.cc
@@ -507,20 +507,18 @@
result_tensor = nullptr;
}
-// TODO(b/160888630): Enable this test with mlir after AddInputList is
-// supported. It is needed for IdentityN.
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
- ::testing::Combine(::testing::Values("graphdef"),
+ ::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
- ::testing::Combine(::testing::Values("graphdef"),
+ ::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
index 343cbce..c62d62a 100644
--- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
+++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc
@@ -299,7 +299,7 @@
"op_type must be specified before specifying attrs.");
Type mlir_type;
Builder builder(context_);
- TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder, &mlir_type));
+ TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type));
attrs_[attr_name] = TypeAttr::get(mlir_type);
return Status::OK();
}
@@ -367,12 +367,12 @@
"' required for output '", output_arg.name(),
"' isn't a type attribute");
for (int i = 0; i < repeats; ++i)
- state_->types.push_back(type_attr.getType());
+ state_->types.push_back(UnrankedTensorType::get(type_attr.getType()));
} else if (output_arg.type() != tensorflow::DT_INVALID) {
for (int i = 0; i < repeats; ++i) {
Type type;
TF_RETURN_IF_ERROR(
- ConvertDataTypeToTensor(output_arg.type(), builder, &type));
+ ConvertDataType(output_arg.type(), builder, &type));
state_->types.push_back(type);
}
} else {
@@ -390,7 +390,7 @@
return InvalidArgument("Attribute '", output_arg.type_attr(),
"' required for output '", output_arg.name(),
"' isn't a type attribute");
- state_->types.push_back(type_attr.getValue());
+ state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
} else if (!output_arg.type_list_attr().empty()) {
// This is pointing to an attribute which is an array of types.
Attribute attr = attrs_[output_arg.type_list_attr()];
@@ -410,13 +410,12 @@
output_arg.type_list_attr(),
"' required for output '", output_arg.name(),
"' has a non-Type element");
- state_->types.push_back(type_attr.getValue());
+ state_->types.push_back(UnrankedTensorType::get(type_attr.getValue()));
}
} else if (output_arg.type() != tensorflow::DT_INVALID) {
Type type;
Builder builder(context_);
- TF_RETURN_IF_ERROR(
- ConvertDataTypeToTensor(output_arg.type(), builder, &type));
+ TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type));
state_->types.push_back(type);
} else {
return InvalidArgument("No type fields in ",
@@ -446,12 +445,6 @@
return Status::OK();
}
-Status MlirAbstractOp::AddInputList(
- absl::Span<AbstractTensorHandle* const> inputs) {
- return tensorflow::errors::Unimplemented(
- "AddInputList has not been implemented yet.");
-}
-
Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data,
size_t length) {
return Unimplemented("SetAttrString has not been implemented yet.");
@@ -589,13 +582,65 @@
expected_type = output_type;
}
} else {
- expected_type = operands_.back().getType();
+ expected_type = cast<MlirTensor>(input)->getElementType();
}
if (!arg_def.type_attr().empty())
attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type);
return Status::OK();
}
+
+Status MlirAbstractOp::AddInputList(
+ absl::Span<AbstractTensorHandle* const> inputs) {
+ if (current_ods_input_ >= op_def_->input_arg_size())
+ return InvalidArgument(
+ absl::StrCat("More Input() (", current_ods_input_, ") calls than the ",
+ op_def_->input_arg_size(), " allowed input_args"));
+
+ for (AbstractTensorHandle* input : inputs) {
+ auto* operand = dyn_cast<MlirTensor>(input);
+ if (!operand) return InvalidArgument("Unable to cast input to MlirTensor");
+ operands_.push_back(operand->getValue());
+ }
+
+ // Get the next ArgDef and use it to infer the derived attributes associated
+ // to this input.
+ const tensorflow::OpDef::ArgDef& arg_def =
+ op_def_->input_arg(current_ods_input_++);
+ if (!arg_def.number_attr().empty()) {
+ Builder builder(context_);
+ attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size());
+ // TODO(aminim): handle ref variable.
+ if (arg_def.type() != tensorflow::DT_INVALID) {
+ // TODO(aminim): check type wrt input
+ Type arg_def_type;
+ TF_RETURN_IF_ERROR(
+ ConvertDataType(arg_def.type(), builder, &arg_def_type));
+ // Ensure each of the type in the list matches the op def type.
+ // TODO(aminim): can we improve the error message with the actual types?
+ for (AbstractTensorHandle* input : inputs)
+ if (arg_def_type != cast<MlirTensor>(input)->getElementType())
+ return InvalidArgument(
+ "Invalid input list: type mismatch the op def expectation");
+ } else if (!inputs.empty()) {
+ if (arg_def.type_attr().empty())
+ return FailedPrecondition(
+ "Invalid opdef type constraint: either type or type_attr required");
+
+ attrs_[arg_def.type_attr()] =
+ TypeAttr::get(cast<MlirTensor>(inputs.front())->getElementType());
+ }
+ } else if (!arg_def.type_list_attr().empty()) {
+ // TODO(aminim): handle ref variable.
+ SmallVector<Attribute, 8> types;
+ types.reserve(inputs.size());
+ for (AbstractTensorHandle* input : inputs)
+ types.push_back(TypeAttr::get(cast<MlirTensor>(input)->getElementType()));
+ attrs_[arg_def.type_list_attr()] = ArrayAttr::get(types, GetContext());
+ }
+ return Status::OK();
+}
+
Status MlirFunctionContext::Finalize(OutputList* outputs,
AbstractFunction** f) {
Block& body = func_.getBody().front();