Fix failures in forwardprop_test_xla_gpu
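
The one-hot tangent masks built by _jacfwd defaulted to float32, which
breaks non-float32 primals; the masks now take the primal's dtype, and a
new testNumericHigherOrderFloat64 case exercises the float64 path.
_test_gradients also gains srtol/satol knobs for the symbolic
backward-vs-forward Jacobian comparison, and ForwardpropTest is run with
test_util.with_eager_op_as_function. With eager_op_as_function enabled
under XLA, autoclustering makes the two symbolic Jacobians agree less
closely, so testNumericHigherOrder loosens satol to 1e-3 in that
configuration (b/202559426).

Minimal public-API sketch of the dtype issue the mask fix addresses (an
illustration only, not part of this change); one_hot defaults to float32
unless a dtype is passed:

  import tensorflow as tf

  primal = tf.constant([[2.0, 3.0], [1.0, 4.0]], dtype=tf.float64)
  flat = tf.reshape(primal, [-1])
  # Without dtype=primal.dtype this mask would be float32 and could not
  # serve as a tangent for the float64 primal.
  mask = tf.one_hot(0, tf.size(flat), dtype=primal.dtype)
  tangent = tf.reshape(mask, tf.shape(primal))
  assert tangent.dtype == primal.dtype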

PiperOrigin-RevId: 402593920
Change-Id: I159b8d615cbb4425bc360f203de20e35cb460900
diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index ebb15ca..bef025b 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -23,6 +23,7 @@
 from tensorflow.python import pywrap_tfe
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import forwardprop
 from tensorflow.python.eager import forwardprop_util
@@ -64,13 +65,17 @@
   """Compute the jacobian of `f` at `primals` using forward-mode autodiff."""
   jac_flat = []
   flat_primals = nest.flatten(primals)
-  tangent_mask = [array_ops.zeros_like(primal) for primal in flat_primals]
+  tangent_mask = [
+      array_ops.zeros_like(primal, dtype=primal.dtype)
+      for primal in flat_primals
+  ]
   for primal_index, primal in enumerate(flat_primals):
     primal_vector = array_ops.reshape(primal, [-1])
     primal_vector_length = array_ops.size(primal_vector)
     jac_columns = []
     for element_index in math_ops.range(primal_vector_length):
-      mask = array_ops.one_hot(element_index, primal_vector_length)
+      mask = array_ops.one_hot(
+          element_index, primal_vector_length, dtype=primal.dtype)
       tangent_mask[primal_index] = array_ops.reshape(mask,
                                                      array_ops.shape(primal))
       jac_columns.append(
@@ -197,7 +202,9 @@
                     order,
                     delta=1e-3,
                     rtol=1e-2,
-                    atol=1e-6):
+                    atol=1e-6,
+                    srtol=1e-6,
+                    satol=1e-6):
   """Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
   if order < 1:
     raise ValueError(
@@ -210,16 +217,19 @@
         order=order - 1,
         delta=delta,
         rtol=rtol,
-        atol=atol)
+        atol=atol,
+        srtol=srtol,
+        satol=satol)
   sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
       f, primals, delta=delta)
   testcase.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)
   sym_jac_fwd = _jacfwd(f, primals)
   testcase.assertAllClose(num_jac, sym_jac_fwd, rtol=rtol, atol=atol)
   # And the symbolic computations should be much closer.
-  testcase.assertAllClose(sym_jac_back, sym_jac_fwd)
+  testcase.assertAllClose(sym_jac_back, sym_jac_fwd, rtol=srtol, atol=satol)
 
 
+@test_util.with_eager_op_as_function
 class ForwardpropTest(test.TestCase, parameterized.TestCase):
 
   def testJVPFunction(self):
@@ -430,8 +440,34 @@
       return math_ops.reduce_prod(
           pointwise + math_ops.reduce_sum(pointwise), axis=1)
 
+    if (context.run_eager_op_as_function_enabled() and
+        test_util.is_xla_enabled()):
+      # Autoclustering kicks in when eager_op_as_function is enabled.
+      # Under XLA the symbolic Jacobians agree less closely, so loosen satol.
+      # Ref: b/202559426
+      _test_gradients(
+          self,
+          f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])],
+          order=3,
+          srtol=1e-6,
+          satol=1e-3)
+    else:
+      _test_gradients(
+          self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
+
+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testNumericHigherOrderFloat64(self):
+
+    def f(x):
+      pointwise = math_ops.sin(x) * math_ops.tan(x)
+      return math_ops.reduce_prod(
+          pointwise + math_ops.reduce_sum(pointwise), axis=1)
+
     _test_gradients(
-        self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)
+        self,
+        f,
+        [constant_op.constant([[2.0, 3.0], [1.0, 4.0]], dtype=dtypes.float64)],
+        order=3)
 
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testCustomGradient(self):