Fix variable handle input shape in function deserialization.
Previously, it was using the _input_shape attribute, but this always returned a scalar shape for resource handles, causing shape inference errors like:
```
ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [?,2,2]
and params shape: [] for '{{node lookup/lookup_call}} =
ResourceGatherNd[Tindices=DT_INT32, _output_shapes=[[?,2]], dtype=DT_FLOAT]
(lookup_lookup_call_resource:0, inputs:0)' with input shapes: [], [?,2,2].
```
PiperOrigin-RevId: 424144483
Change-Id: I283c109a3502a0ecf3984ff57fa3f74d4508e612
diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD
index 09343c3..f185da8 100644
--- a/tensorflow/python/framework/BUILD
+++ b/tensorflow/python/framework/BUILD
@@ -413,10 +413,14 @@
":graph_to_function_def",
":op_def_library",
":ops",
+ ":tensor_shape",
+ ":tensor_spec",
":test_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/eager:def_function",
+ "//tensorflow/python/eager:function",
],
)
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index d92e94f..f746bb2 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -66,7 +66,18 @@
if input_shapes is None:
input_shapes_attr = fdef.attr.get("_input_shapes", None)
if input_shapes_attr is not None:
- input_shapes = input_shapes_attr.list.shape
+ raw_input_shapes = input_shapes_attr.list.shape
+
+ # Replace resource handle shapes, since they are always stored as a scalar
+ # shape in the _input_shapes attribute.
+ input_shapes = []
+ for input_shape, arg_def in zip(raw_input_shapes,
+ fdef.signature.input_arg):
+ if arg_def.type == types_pb2.DT_RESOURCE and arg_def.handle_data:
+ input_shapes.append(arg_def.handle_data[0].shape)
+ else:
+ input_shapes.append(input_shape)
+
graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
fdef, input_shapes)
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
index 7d1bec4..765dfec 100644
--- a/tensorflow/python/framework/function_def_to_graph_test.py
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for tensorflow.python.framework.function_def_to_graph."""
+from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -22,6 +23,7 @@
from tensorflow.python.framework import op_def_library
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -95,6 +97,26 @@
g = function_def_to_graph.function_def_to_graph(
fdef, input_shapes=[tensor_shape.TensorShape([5, 7])])
+ def testResourceHandleInputShapes(self):
+ # Test that shape inference with resource handles work as expected.
+
+ # Create a graph to generate the input and handle shape attributes in the
+ # FunctionDef.
+ with ops.Graph().as_default() as g:
+ v = variables.Variable(array_ops.ones((2, 3), dtype=dtypes.float32))
+
+ @def_function.function(
+ input_signature=[tensor_spec.TensorSpec((None, 2, 2), dtypes.int32)])
+ def lookup(inp):
+ return array_ops.gather_nd(v, inp)
+
+ lookup.get_concrete_function().add_to_graph()
+ fdef = g.as_graph_def(add_shapes=True).library.function[0]
+
+ fg = function_def_to_graph.function_def_to_graph(fdef)
+ self.assertSequenceEqual(fg.inputs[0].shape.as_list(), [None, 2, 2])
+ self.assertSequenceEqual(fg.inputs[1].shape.dims, [2, 3])
+
class FunctionDefToGraphDefTest(test.TestCase):