[tf][tfg] Fix: Importer checks for null control values
Missing control values during import generic function weren't being caught.
PiperOrigin-RevId: 427065845
Change-Id: I3aa808445f39fae636bf45d5bcc24f148fb713bd
diff --git a/tensorflow/core/ir/importexport/functiondef_import.cc b/tensorflow/core/ir/importexport/functiondef_import.cc
index 2fd178e..527f8da 100644
--- a/tensorflow/core/ir/importexport/functiondef_import.cc
+++ b/tensorflow/core/ir/importexport/functiondef_import.cc
@@ -446,8 +446,8 @@
// We pre-allocate the array of operands and populate it using the
// `output_name_to_position` and `control_output_to_position` populated
// previously.
- SmallVector<Value> ret_vals(func.ret_size(), Value());
- SmallVector<Value> ret_ctls(func.control_ret_size(), Value());
+ SmallVector<Value> ret_vals(func.ret_size() + func.control_ret_size(),
+ Value());
for (const auto& ret_val : func.ret()) {
auto position = output_name_to_position.find(ret_val.first);
if (position == output_name_to_position.end())
@@ -470,7 +470,7 @@
if (!result.getType().isa<ControlType>())
return InvalidArgument("failed to map returned value ", ret_val.second,
", isn't a control output");
- ret_ctls[position->second] = result;
+ ret_vals[func.ret_size() + position->second] = result;
}
// Check that all the of the return operands have been populated.
for (auto& indexed_val : llvm::enumerate(ret_vals)) {
@@ -479,13 +479,15 @@
"Failed to import function, missing output for position ",
indexed_val.index());
}
- ReturnOp ret_op =
- body_builder.create<ReturnOp>(unknown_loc, ret_vals, ret_ctls);
+ MutableArrayRef<Value> operands = ret_vals;
+ ReturnOp ret_op = body_builder.create<ReturnOp>(
+ unknown_loc, operands.slice(0, func.ret_size()),
+ operands.slice(func.ret_size()));
// Now that we have all the types, set the function signature as the "type"
// attribute.
{
- llvm::SmallVector<Type> arg_types_with_ctl;
+ SmallVector<Type> arg_types_with_ctl;
for (Type type : arg_types) {
arg_types_with_ctl.push_back(type);
arg_types_with_ctl.push_back(control_ty);