Support ResourceGather in freeze graph in 2.0.
PiperOrigin-RevId: 265497869
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index e4412aa..c946533 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -31,6 +31,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
@@ -250,6 +251,42 @@
self.assertLess(len(quantized_tflite), len(float_tflite))
@test_util.run_v2_only
+ def testEmbeddings(self):
+ """Test model with embeddings."""
+ input_data = constant_op.constant(
+ np.array(np.random.random_sample((20)), dtype=np.int32))
+
+ class EmbeddingModel(keras.Model):
+
+ def __init__(self):
+ super(EmbeddingModel, self).__init__()
+ self.shared_weights = self.add_weight(
+ 'weights',
+ shape=(2000, 300),
+ dtype=dtypes.float32,
+ initializer=init_ops.random_normal_initializer(
+ mean=0.0, stddev=300**(-0.5)))
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=(20), dtype=dtypes.int32)
+ ])
+ def func(self, x):
+ return array_ops.gather(self.shared_weights, x)
+
+ # Building the model.
+ root = EmbeddingModel()
+ concrete_func = root.func.get_concrete_function()
+
+ # Convert model.
+ converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+ tflite_model = converter.convert()
+
+ # Check values from converted model.
+ expected_value = root.func(input_data)
+ actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
+ np.testing.assert_almost_equal(expected_value.numpy(), actual_value, 5)
+
+ @test_util.run_v2_only
def testGraphDebugInfo(self):
"""Test a concrete function has debug info captured."""
root = tracking.AutoTrackable()
diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py
index c6efc85..c40ffcd 100644
--- a/tensorflow/python/framework/convert_to_constants.py
+++ b/tensorflow/python/framework/convert_to_constants.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
@@ -397,7 +399,6 @@
Returns:
ConcreteFunction containing a simplified version of the original.
"""
- # TODO(nupurgarg): Replace ResourceGather with Gather.
# Inline the graph in order to remove functions when possible.
graph_def = _run_inline_graph_optimization(func, lower_control_flow)
@@ -467,10 +468,10 @@
# Get dtype and data for non-variable Placeholders (ex. values for 1.X
# Const ops that are loaded as Placeholders in 2.0)
_save_placeholder(node.name, node.attr["dtype"])
- elif node.op == "ReadVariableOp":
- # Get dtype and data for Placeholder ops associated with ReadVariableOp.
- # There can be an Identity in between the ReadVariableOp and Placeholder.
- # Store the dtype for the Identity ops.
+ elif node.op in ["ReadVariableOp", "ResourceGather"]:
+ # Get dtype and data for Placeholder ops associated with ReadVariableOp
+ # and ResourceGather ops. There can be an Identity in between the
+ # resource op and Placeholder. Store the dtype for the Identity ops.
input_name = _get_tensor_name(node.input[0])
while name_to_node[input_name].op == "Identity":
resource_identities[input_name] = node.attr["dtype"]
@@ -503,6 +504,26 @@
# Convert ReadVariableOps to Identity ops.
elif input_node.op == "ReadVariableOp":
_populate_identity_op(output_node, input_node)
+ # Convert ResourceGather to Gather ops with a Const axis feeding into it.
+ elif input_node.op == "ResourceGather":
+ if input_node.attr["batch_dims"].i != 0:
+ raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
+ output_axis_node = output_graph_def.node.add()
+ axis_node_name = input_node.name + "/axis"
+ axis_dtype = input_node.attr["Tindices"]
+ axis_data = np.array(input_node.attr["batch_dims"].i)
+ _populate_const_op(output_axis_node, axis_node_name, axis_dtype,
+ axis_data, axis_data.shape)
+
+ output_node.op = "GatherV2"
+ output_node.name = input_node.name
+ output_node.input.extend(
+ [input_node.input[0], input_node.input[1], axis_node_name])
+ output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
+ output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
+ output_node.attr["Taxis"].CopyFrom(axis_dtype)
+ if "_class" in input_node.attr:
+ output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
# Update the function names and argument types for the conditional ops.
elif input_node.op in _CONDITIONAL_OPS:
_populate_if_op(output_node, input_node, function_data)
diff --git a/tensorflow/python/framework/convert_to_constants_test.py b/tensorflow/python/framework/convert_to_constants_test.py
index cbe8528..315fe23 100644
--- a/tensorflow/python/framework/convert_to_constants_test.py
+++ b/tensorflow/python/framework/convert_to_constants_test.py
@@ -33,6 +33,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
@@ -433,6 +434,36 @@
root, output_func = self._freezeModel(to_save)
self._testConvertedFunction(root, root.f, output_func, input_data)
+ @test_util.run_v2_only
+ def testEmbeddings(self):
+ """Test model with embeddings."""
+ input_data = {
+ "x":
+ constant_op.constant(
+ np.array(np.random.random_sample((20)), dtype=np.int32))
+ }
+
+ class EmbeddingModel(keras.Model):
+
+ def __init__(self):
+ super(EmbeddingModel, self).__init__()
+ self.shared_weights = self.add_weight(
+ "weights",
+ shape=(2000, 300),
+ dtype=dtypes.float32,
+ initializer=init_ops.random_normal_initializer(
+ mean=0.0, stddev=300**(-0.5)))
+
+ @def_function.function(input_signature=[
+ tensor_spec.TensorSpec(shape=(20), dtype=dtypes.int32)
+ ])
+ def func(self, x):
+ return array_ops.gather(self.shared_weights, x)
+
+ model = EmbeddingModel()
+ root, output_func = self._freezeModel(model.func)
+ self._testConvertedFunction(root, root.f, output_func, input_data)
+
if __name__ == "__main__":
test.main()