Prepare identity, stop_gradient, map_fn, vectorized_map, and py_function for making ResourceVariables as CompositeTensors.
ResourceVariables are not CompositeTensors at this moment. Making the changes necessary for that to happen.
PiperOrigin-RevId: 456812746
diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
index de88325..88e41ed 100644
--- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
@@ -1595,6 +1595,13 @@
e = array_ops.identity(d)
_test(d, e, "gpu")
+ def testIdentityVariable(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+ self.evaluate(v.initializer)
+ result = array_ops.identity(v)
+ self.assertIsInstance(result, ops.Tensor)
+ self.assertAllEqual(result, v)
+
class PadTest(test_util.TensorFlowTestCase):
@@ -2393,7 +2400,7 @@
self.assertAllEqual(t, tiled_tensor_1)
-class StopGradientTest(test_util.TensorFlowTestCase):
+class StopGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def testStopGradient(self):
x = array_ops.zeros(3)
@@ -2419,6 +2426,24 @@
self.assertIsNone(tape.gradient(y, x))
+ @parameterized.named_parameters([
+ ("TFFunction", def_function.function),
+ ("PythonFunction", lambda f: f),
+ ])
+ def test_stop_gradient_resource_variable(self, decorator):
+ x = resource_variable_ops.ResourceVariable([1.0])
+ self.evaluate(x.initializer)
+
+ @decorator
+ def stop_gradient_f(x):
+ return array_ops.stop_gradient(x)
+
+ with backprop.GradientTape() as tape:
+ y = stop_gradient_f(x)
+ self.assertIsNone(tape.gradient(y, x))
+ # stop_gradient converts ResourceVariable to Tensor
+ self.assertIsInstance(y, ops.Tensor)
+ self.assertAllEqual(y, x)
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/kernel_tests/control_flow/map_fn_test.py b/tensorflow/python/kernel_tests/control_flow/map_fn_test.py
index dc026b3..d4781a1 100644
--- a/tensorflow/python/kernel_tests/control_flow/map_fn_test.py
+++ b/tensorflow/python/kernel_tests/control_flow/map_fn_test.py
@@ -29,6 +29,7 @@
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
@@ -91,6 +92,18 @@
self.assertEqual([2, None], result.shape.as_list())
@test_util.run_in_graph_and_eager_modes
+ def testMapVariable(self):
+ v = resource_variable_ops.ResourceVariable([1, 2])
+ self.evaluate(v.initializer)
+
+ def loop_fn(x):
+ return x + 1
+
+ result = map_fn.map_fn(loop_fn, v)
+ expected_result = [2, 3]
+ self.assertAllEqual(result, expected_result)
+
+ @test_util.run_in_graph_and_eager_modes
def testMapOverScalarErrors(self):
with self.assertRaisesRegex(ValueError, "must be .* Tensor.* not scalar"):
map_fn.map_fn(lambda x: x, [1, 2])
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index a64a4da..13a4ff0 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -39,6 +39,7 @@
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.types import core
+from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import lazy_loader
@@ -282,7 +283,9 @@
Returns:
A `Tensor` or CompositeTensor. Has the same type and contents as `input`.
"""
- if isinstance(input, composite_tensor.CompositeTensor):
+ # Don't expand ResourceVariables, so identity(variable) will return a Tensor.
+ if (isinstance(input, composite_tensor.CompositeTensor) and
+ not _pywrap_utils.IsResourceVariable(input)):
return nest.map_structure(identity, input, expand_composites=True)
if context.executing_eagerly() and not hasattr(input, "graph"):
# Make sure we get an input with handle data attached from resource
@@ -6992,7 +6995,10 @@
Returns:
A `Tensor`. Has the same dtype as `input`.
"""
- if isinstance(input, composite_tensor.CompositeTensor):
+ # Don't expand ResourceVariables, so stop_gradient(variable) will return a
+ # Tensor.
+ if (isinstance(input, composite_tensor.CompositeTensor) and
+ not _pywrap_utils.IsResourceVariable(input)):
return nest.map_structure(stop_gradient, input, expand_composites=True)
# The StopGradient op has a gradient function registered which returns None
# (meaning statically known to be zero). For correctness, that's all we
diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py
index 23f0959..513f566 100644
--- a/tensorflow/python/ops/map_fn.py
+++ b/tensorflow/python/ops/map_fn.py
@@ -35,6 +35,7 @@
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
+from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export
@@ -370,6 +371,8 @@
"parallel.", 1)
parallel_iterations = 1
+ # Explicitly read values of ResourceVariables.
+ elems = variable_utils.convert_variables_to_tensors(elems)
# Flatten the input tensors, and get the TypeSpec for each one.
elems_flat = nest.flatten(elems)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index b1b9864..88c7ea2 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -38,6 +38,7 @@
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export
@@ -534,6 +535,7 @@
Raises:
ValueError: If vectorization fails and fallback_to_while_loop is False.
"""
+ elems = variable_utils.convert_variables_to_tensors(elems)
elems = nest.map_structure(ops.convert_to_tensor,
elems,
expand_composites=True)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index fce00aa..b718ada 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -2828,6 +2828,18 @@
self._test_loop_fn(loop_fn, 2)
+ @test_util.run_all_in_graph_and_eager_modes
+ def test_variable_input(self):
+ v = resource_variable_ops.ResourceVariable([1, 2])
+ self.evaluate(v.initializer)
+
+ def loop_fn(x):
+ return x + 1
+
+ result = pfor_control_flow_ops.vectorized_map(loop_fn, v)
+ expected_result = [2, 3]
+ self.assertAllEqual(result, expected_result)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index b7e4880..f6103ed 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -44,6 +44,7 @@
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
+from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export
autograph = lazy_loader.LazyLoader(
@@ -312,7 +313,7 @@
original_func = func
func = autograph.do_not_convert(func)
- inp = list(inp)
+ inp = variable_utils.convert_variables_to_tensors(list(inp))
# Normalize Tout.
is_list_or_tuple = isinstance(Tout, (list, tuple))
diff --git a/tensorflow/python/ops/script_ops_test.py b/tensorflow/python/ops/script_ops_test.py
index e916874..45689b9 100644
--- a/tensorflow/python/ops/script_ops_test.py
+++ b/tensorflow/python/ops/script_ops_test.py
@@ -15,9 +15,10 @@
"""Tests for script operations."""
from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
-from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops.script_ops import numpy_function
from tensorflow.python.platform import test
@@ -87,5 +88,21 @@
2) # as stateful, func is guaranteed to execute twice
+class PyFunctionTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_variable_arguments(self):
+
+ def plus(a, b):
+ return a + b
+
+ v1 = resource_variable_ops.ResourceVariable(1)
+ self.evaluate(v1.initializer)
+
+ actual_result = script_ops.eager_py_func(plus, [v1, 2], dtypes.int32)
+ expect_result = constant_op.constant(3, dtypes.int32)
+ self.assertAllEqual(actual_result, expect_result)
+
+
if __name__ == "__main__":
test.main()