Allow empty value for "_input_shapes" attr of functions in graphdef library
PiperOrigin-RevId: 360279759
Change-Id: I14edacb9aad6d4b995ef231520acad2b752b029d
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt
new file mode 100644
index 0000000..61ccf82
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-input-shapes.pbtxt
@@ -0,0 +1,39 @@
+# RUN: tf-mlir-translate -graphdef-to-mlir %s
+
+node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+
+node {
+ name: "func0"
+ op: "func_name"
+ input: "input"
+}
+
+library {
+ function {
+ signature {
+ name: "func_name"
+ input_arg {
+ name: "arg0"
+ type: DT_BOOL
+ }
+ }
+ ret {
+ key: "retval0"
+ value: "arg0"
+ }
+ attr: {
+ key: "_input_shapes"
+ value: {
+ }
+ }
+ }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index d41dca5..ffe3e1e 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -1292,17 +1292,23 @@
if (name_and_value.first == "_input_shapes") {
auto& list = name_and_value.second.list();
auto& signature = func_def->signature();
- if (list.shape_size() != signature.input_arg_size()) {
+ // Some models have "_input_shapes" attribute, but with its value empty
+ if (list.shape_size() > 0 &&
+ list.shape_size() != signature.input_arg_size()) {
return errors::FailedPrecondition(
"Number of input arguments must be equal to the length of "
"_input_shapes attribute in function '",
StringRefToView(func_name), "'.");
}
- for (int i = 0; i < list.shape_size(); i++) {
+ for (int i = 0; i < signature.input_arg_size(); i++) {
auto& input_arg = signature.input_arg(i);
auto& array_info = specs.inputs[input_arg.name()];
array_info.imported_dtype = input_arg.type();
- array_info.shape = list.shape(i);
+ // set to unranked for empty "_input_shapes" attribute
+ if (list.shape_size() > 0)
+ array_info.shape = list.shape(i);
+ else
+ array_info.shape.set_unknown_rank(true);
}
}
}