Fix variable handle input shape in function deserialization.

Previously, it was using the _input_shape attribute, but this always returned a scalar shape for resource handles, causing shape inference errors like:

```
ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [?,2,2]
 and params shape: [] for '{{node lookup/lookup_call}} =
ResourceGatherNd[Tindices=DT_INT32, _output_shapes=[[?,2]], dtype=DT_FLOAT]
(lookup_lookup_call_resource:0, inputs:0)' with input shapes: [], [?,2,2].
```
PiperOrigin-RevId: 424144483
Change-Id: I283c109a3502a0ecf3984ff57fa3f74d4508e612
diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD
index 09343c3..f185da8 100644
--- a/tensorflow/python/framework/BUILD
+++ b/tensorflow/python/framework/BUILD
@@ -413,10 +413,14 @@
         ":graph_to_function_def",
         ":op_def_library",
         ":ops",
+        ":tensor_shape",
+        ":tensor_spec",
         ":test_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python/eager:def_function",
+        "//tensorflow/python/eager:function",
     ],
 )
 
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index d92e94f..f746bb2 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -66,7 +66,18 @@
   if input_shapes is None:
     input_shapes_attr = fdef.attr.get("_input_shapes", None)
     if input_shapes_attr is not None:
-      input_shapes = input_shapes_attr.list.shape
+      raw_input_shapes = input_shapes_attr.list.shape
+
+      # Replace resource handle shapes, since they are always stored as a scalar
+      # shape in the _input_shapes attribute.
+      input_shapes = []
+      for input_shape, arg_def in zip(raw_input_shapes,
+                                      fdef.signature.input_arg):
+        if arg_def.type == types_pb2.DT_RESOURCE and arg_def.handle_data:
+          input_shapes.append(arg_def.handle_data[0].shape)
+        else:
+          input_shapes.append(input_shape)
+
   graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
       fdef, input_shapes)
 
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
index 7d1bec4..765dfec 100644
--- a/tensorflow/python/framework/function_def_to_graph_test.py
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Tests for tensorflow.python.framework.function_def_to_graph."""
 
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -22,6 +23,7 @@
 from tensorflow.python.framework import op_def_library
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
@@ -95,6 +97,26 @@
       g = function_def_to_graph.function_def_to_graph(
           fdef, input_shapes=[tensor_shape.TensorShape([5, 7])])
 
+  def testResourceHandleInputShapes(self):
+    # Test that shape inference with resource handles work as expected.
+
+    # Create a graph to generate the input and handle shape attributes in the
+    # FunctionDef.
+    with ops.Graph().as_default() as g:
+      v = variables.Variable(array_ops.ones((2, 3), dtype=dtypes.float32))
+
+      @def_function.function(
+          input_signature=[tensor_spec.TensorSpec((None, 2, 2), dtypes.int32)])
+      def lookup(inp):
+        return array_ops.gather_nd(v, inp)
+
+      lookup.get_concrete_function().add_to_graph()
+      fdef = g.as_graph_def(add_shapes=True).library.function[0]
+
+    fg = function_def_to_graph.function_def_to_graph(fdef)
+    self.assertSequenceEqual(fg.inputs[0].shape.as_list(), [None, 2, 2])
+    self.assertSequenceEqual(fg.inputs[1].shape.dims, [2, 3])
+
 
 class FunctionDefToGraphDefTest(test.TestCase):