Fix type of gradient of DT_RESOURCE in while_v2. It used to default to float32.
PiperOrigin-RevId: 272885600
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index 24b4aa1..c6e3ede 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -137,9 +137,10 @@
r"but has type <dtype: 'float16'> after 1 iteration."):
BuildWhile()
- def testGradientTapeResourceVariable(self):
+ @parameterized.parameters(dtypes.float32, dtypes.float64)
+ def testGradientTapeResourceVariable(self, dtype):
with context.eager_mode():
- v = variables.Variable(1.)
+ v = variables.Variable(1., dtype=dtype)
@def_function.function
def fnWithLoop(): # pylint: disable=invalid-name
@@ -147,7 +148,7 @@
_, x = while_loop_v2(
lambda i, _: i < 2,
lambda i, x: (i + 1, x * v),
- [0, 2.])
+ [0, constant_op.constant(2., dtype=dtype)])
return tape.gradient(x, v)
self.assertAllEqual(fnWithLoop(), 4.0)
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 0171fbc..140c463 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import custom_gradient
+from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
@@ -508,7 +509,8 @@
"""Like array_ops.zeros_like() but also accepts resource var handles."""
if op_output.dtype == dtypes.resource:
return array_ops.zeros(
- gen_resource_variable_ops.variable_shape(op_output))
+ gen_resource_variable_ops.variable_shape(op_output),
+ dtype=default_gradient.get_zeros_dtype(op_output))
return array_ops.zeros_like(op_output)