Fix failures with forwardprop_test_xla_gpu
PiperOrigin-RevId: 402593920
Change-Id: I159b8d615cbb4425bc360f203de20e35cb460900
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index ebb15ca..bef025b 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -23,6 +23,7 @@
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
from tensorflow.python.eager import forwardprop_util
@@ -64,13 +65,17 @@
"""Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
jac_flat = []
flat_primals = nest.flatten(primals)
- tangent_mask = [array_ops.zeros_like(primal) for primal in flat_primals]
+ tangent_mask = [
+ array_ops.zeros_like(primal, dtype=primal.dtype)
+ for primal in flat_primals
+ ]
for primal_index, primal in enumerate(flat_primals):
primal_vector = array_ops.reshape(primal, [-1])
primal_vector_length = array_ops.size(primal_vector)
jac_columns = []
for element_index in math_ops.range(primal_vector_length):
- mask = array_ops.one_hot(element_index, primal_vector_length)
+ mask = array_ops.one_hot(
+ element_index, primal_vector_length, dtype=primal.dtype)
tangent_mask[primal_index] = array_ops.reshape(mask,
array_ops.shape(primal))
jac_columns.append(
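For context, a standalone sketch of why the mask dtype matters here, written against the public tf API rather than the test's internal array_ops/math_ops imports: tf.one_hot defaults to float32, so for a float64 primal the tangent mask must be built explicitly in the primal's dtype before it is handed to the forward-mode accumulator.

import tensorflow as tf

primal = tf.constant([[2.0, 3.0], [1.0, 4.0]], dtype=tf.float64)
# tf.one_hot defaults to float32; pass the primal's dtype explicitly so the
# tangent matches a float64 primal.
mask = tf.one_hot(0, tf.size(primal), dtype=primal.dtype)
tangent = tf.reshape(mask, tf.shape(primal))
with tf.autodiff.ForwardAccumulator(primals=primal, tangents=tangent) as acc:
  y = tf.reduce_sum(tf.sin(primal) * tf.tan(primal))
print(acc.jvp(y))  # float64 directional derivative along `tangent`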
@@ -197,7 +202,9 @@
order,
delta=1e-3,
rtol=1e-2,
- atol=1e-6):
+ atol=1e-6,
+ srtol=1e-6,
+ satol=1e-6):
"""Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
if order < 1:
raise ValueError(
@@ -210,16 +217,19 @@
order=order - 1,
delta=delta,
rtol=rtol,
- atol=atol)
+ atol=atol,
+ srtol=srtol,
+ satol=satol)
sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
f, primals, delta=delta)
testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
sym_jac_fwd = _jacfwd(f, primals)
testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
# And the symbolic computations should be much closer.
- testcase.assertAllClose(sym_jac_back, sym_jac_fwd)
+ testcase.assertAllClose(sym_jac_back, sym_jac_fwd, rtol=srtol, atol=satol)
+@test_util.with_eager_op_as_function
class ForwardpropTest(test.TestCase, parameterized.TestCase):
def testJVPFunction(self):
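A small usage sketch of the new knobs (the wrapper name check_case is illustrative, not part of the change): rtol/atol still govern the numeric-vs-symbolic comparisons, while srtol/satol only bound the backward-vs-forward symbolic comparison, which is normally expected to be much tighter.

def check_case(testcase, f, primals):
  # Hypothetical helper: numeric checks keep the looser defaults, while the
  # symbolic backward-vs-forward check gets its own tolerances.
  _test_gradients(
      testcase,
      f,
      primals,
      order=2,
      rtol=1e-2,
      atol=1e-6,    # numeric vs. symbolic Jacobians
      srtol=1e-6,
      satol=1e-6)   # symbolic backward vs. symbolic forward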
@@ -430,8 +440,34 @@
return math_ops.reduce_prod(
pointwise + math_ops.reduce_sum(pointwise), axis=1)
+ if (context.run_eager_op_as_function_enabled() and
+ test_util.is_xla_enabled()):
+ # Autoclustering kicks in when eager_op_as_function is enabled.
+ # Under XLA the symbolic Jacobians agree less tightly than under TF, so relax the symbolic tolerances.
+ # Ref: b/202559426
+ _test_gradients(
+ self,
+ f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])],
+ order=3,
+ srtol=1e-6,
+ satol=1e-3)
+ else:
+ _test_gradients(
+ self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testNumericHigherOrderFloat64(self):
+
+ def f(x):
+ pointwise = math_ops.sin(x) * math_ops.tan(x)
+ return math_ops.reduce_prod(
+ pointwise + math_ops.reduce_sum(pointwise), axis=1)
+
_test_gradients(
- self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
+ self,
+ f,
+ [constant_op.constant([[2.0, 3.0], [1.0, 4.0]], dtype=dtypes.float64)],
+ order=3)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testCustomGradient(self):
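For reference, a standalone sketch of the forward-mode path the new float64 test exercises, written against the public tf.autodiff API rather than the test's internal modules; the single-tangent JVP below is a simplified stand-in for the full higher-order Jacobian check.

import tensorflow as tf

def f(x):
  pointwise = tf.sin(x) * tf.tan(x)
  return tf.reduce_prod(pointwise + tf.reduce_sum(pointwise), axis=1)

x = tf.constant([[2.0, 3.0], [1.0, 4.0]], dtype=tf.float64)
tangent = tf.ones_like(x)  # same dtype as the primal, per the _jacfwd fix
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tangent) as acc:
  y = f(x)
print(acc.jvp(y))  # float64 JVP of f at x along `tangent`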