Support ResourceGather in freeze graph in 2.0.

PiperOrigin-RevId: 265497869
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index e4412aa..c946533 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import variables
@@ -250,6 +251,42 @@
     self.assertLess(len(quantized_tflite), len(float_tflite))
 
   @test_util.run_v2_only
+  def testEmbeddings(self):
+    """Test model with embeddings."""
+    input_data = constant_op.constant(
+        np.array(np.random.random_sample((20)), dtype=np.int32))
+
+    class EmbeddingModel(keras.Model):
+
+      def __init__(self):
+        super(EmbeddingModel, self).__init__()
+        self.shared_weights = self.add_weight(
+            'weights',
+            shape=(2000, 300),
+            dtype=dtypes.float32,
+            initializer=init_ops.random_normal_initializer(
+                mean=0.0, stddev=300**(-0.5)))
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=(20), dtype=dtypes.int32)
+      ])
+      def func(self, x):
+        return array_ops.gather(self.shared_weights, x)
+
+    # Building the model.
+    root = EmbeddingModel()
+    concrete_func = root.func.get_concrete_function()
+
+    # Convert model.
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+
+    # Check values from converted model.
+    expected_value = root.func(input_data)
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
+    np.testing.assert_almost_equal(expected_value.numpy(), actual_value, 5)
+
+  @test_util.run_v2_only
   def testGraphDebugInfo(self):
     """Test a concrete function has debug info captured."""
     root = tracking.AutoTrackable()
diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py
index c6efc85..c40ffcd 100644
--- a/tensorflow/python/framework/convert_to_constants.py
+++ b/tensorflow/python/framework/convert_to_constants.py
@@ -18,6 +18,8 @@
 from __future__ import division
 from __future__ import print_function
 
+import numpy as np
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
@@ -397,7 +399,6 @@
   Returns:
     ConcreteFunction containing a simplified version of the original.
   """
-  # TODO(nupurgarg): Replace ResourceGather with Gather.
   # Inline the graph in order to remove functions when possible.
   graph_def = _run_inline_graph_optimization(func, lower_control_flow)
 
@@ -467,10 +468,10 @@
       # Get dtype and data for non-variable Placeholders (ex. values for 1.X
       # Const ops that are loaded as Placeholders in 2.0)
       _save_placeholder(node.name, node.attr["dtype"])
-    elif node.op == "ReadVariableOp":
-      # Get dtype and data for Placeholder ops associated with ReadVariableOp.
-      # There can be an Identity in between the ReadVariableOp and Placeholder.
-      # Store the dtype for the Identity ops.
+    elif node.op in ["ReadVariableOp", "ResourceGather"]:
+      # Get dtype and data for Placeholder ops associated with ReadVariableOp
+      # and ResourceGather ops. There can be an Identity in between the
+      # resource op and Placeholder. Store the dtype for the Identity ops.
       input_name = _get_tensor_name(node.input[0])
       while name_to_node[input_name].op == "Identity":
         resource_identities[input_name] = node.attr["dtype"]
@@ -503,6 +504,26 @@
     # Convert ReadVariableOps to Identity ops.
     elif input_node.op == "ReadVariableOp":
       _populate_identity_op(output_node, input_node)
+    # Convert ResourceGather to Gather ops with a Const axis feeding into it.
+    elif input_node.op == "ResourceGather":
+      if input_node.attr["batch_dims"].i != 0:
+        raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
+      output_axis_node = output_graph_def.node.add()
+      axis_node_name = input_node.name + "/axis"
+      axis_dtype = input_node.attr["Tindices"]
+      axis_data = np.array(input_node.attr["batch_dims"].i)
+      _populate_const_op(output_axis_node, axis_node_name, axis_dtype,
+                         axis_data, axis_data.shape)
+
+      output_node.op = "GatherV2"
+      output_node.name = input_node.name
+      output_node.input.extend(
+          [input_node.input[0], input_node.input[1], axis_node_name])
+      output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
+      output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
+      output_node.attr["Taxis"].CopyFrom(axis_dtype)
+      if "_class" in input_node.attr:
+        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
     # Update the function names and argument types for the conditional ops.
     elif input_node.op in _CONDITIONAL_OPS:
       _populate_if_op(output_node, input_node, function_data)
diff --git a/tensorflow/python/framework/convert_to_constants_test.py b/tensorflow/python/framework/convert_to_constants_test.py
index cbe8528..315fe23 100644
--- a/tensorflow/python/framework/convert_to_constants_test.py
+++ b/tensorflow/python/framework/convert_to_constants_test.py
@@ -33,6 +33,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import cond_v2
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
@@ -433,6 +434,36 @@
     root, output_func = self._freezeModel(to_save)
     self._testConvertedFunction(root, root.f, output_func, input_data)
 
+  @test_util.run_v2_only
+  def testEmbeddings(self):
+    """Test model with embeddings."""
+    input_data = {
+        "x":
+            constant_op.constant(
+                np.array(np.random.random_sample((20)), dtype=np.int32))
+    }
+
+    class EmbeddingModel(keras.Model):
+
+      def __init__(self):
+        super(EmbeddingModel, self).__init__()
+        self.shared_weights = self.add_weight(
+            "weights",
+            shape=(2000, 300),
+            dtype=dtypes.float32,
+            initializer=init_ops.random_normal_initializer(
+                mean=0.0, stddev=300**(-0.5)))
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=(20), dtype=dtypes.int32)
+      ])
+      def func(self, x):
+        return array_ops.gather(self.shared_weights, x)
+
+    model = EmbeddingModel()
+    root, output_func = self._freezeModel(model.func)
+    self._testConvertedFunction(root, root.f, output_func, input_data)
+
 
 if __name__ == "__main__":
   test.main()