Fix major Adamax GPU bug.

PiperOrigin-RevId: 268469299
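
In the old GPU kernel, the one-element lr and epsilon tensors were
combined with bcast-shaped operands without first being reshaped and
broadcast; this change broadcasts both, as is already done for beta1,
beta2, and beta1_power. optimizer_v2_test is switched from py_test to
cuda_py_test, and the affected Adamax and optimizer tests now run with
use_gpu=True so the GPU path is actually exercised.

For reference, a minimal NumPy sketch of the dense Adamax update the
kernel computes (the function and variable names here are illustrative,
not the kernel's API; beta1_power corresponds to beta1**t):

    import numpy as np

    def adamax_update(var, m, u, grad, lr, beta1, beta2, beta1_power,
                      epsilon):
      # First-moment estimate: m += (1 - beta1) * (grad - m).
      m = m + (1.0 - beta1) * (grad - m)
      # Infinity-norm estimate: u = max(beta2 * u, |grad|), taken
      # elementwise (cwiseMax in the kernel).
      u = np.maximum(beta2 * u, np.abs(grad))
      # Bias-corrected step; epsilon sits inside the denominator with u,
      # as in the kernel.
      var = var - lr / (1.0 - beta1_power) * (m / (u + epsilon))
      return var, m, u
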
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 03c24ff..f8c3d7e 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -306,15 +306,16 @@
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     const auto one = static_cast<T>(1.0);
-    m.device(d) =
-        m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
-                (grad - m);
+    m.device(d) +=
+        (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
+        (grad - m);
     v.device(d) =
         (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
-    var.device(d) -=
-        lr / (beta1_power.constant(one) -
-                 beta1_power).reshape(single).broadcast(bcast) *
-                     (m / (v + epsilon));
+    var.device(d) -= lr.reshape(single).broadcast(bcast) /
+                     (beta1_power.constant(one) - beta1_power)
+                         .reshape(single)
+                         .broadcast(bcast) *
+                     (m / (v + epsilon.reshape(single).broadcast(bcast)));
   }
 };
 
diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index c1f7244..715c793 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -201,21 +201,13 @@
     xla_enable_strict_auto_jit = True,
 )
 
-py_test(
+cuda_py_test(
     name = "optimizer_v2_test",
     size = "medium",
     srcs = ["optimizer_v2_test.py"],
-    python_version = "PY2",
-    shard_count = 8,
-    tags = [
-        "no_gpu",  # b/127001953
-        "no_windows",
-        # TODO(b/127092862): Re-enable this test in Kokoro.
-        "no_oss",
-        "notap",  # b/140242244
-    ],
-    deps = [
+    additional_deps = [
         ":optimizer_v2",
+        "@absl_py//absl/testing:parameterized",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:clip_ops",
@@ -227,8 +219,12 @@
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/keras",
-        "@absl_py//absl/testing:parameterized",
     ],
+    shard_count = 8,
+    tags = [
+        "no_windows",
+    ],
+    xla_enable_strict_auto_jit = True,
 )
 
 cuda_py_test(
diff --git a/tensorflow/python/keras/optimizer_v2/adamax_test.py b/tensorflow/python/keras/optimizer_v2/adamax_test.py
index b246a1d..ea9e8a5 100644
--- a/tensorflow/python/keras/optimizer_v2/adamax_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adamax_test.py
@@ -80,7 +80,7 @@
 
   def doTestSparse(self, use_resource=False):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype)  # pylint: disable=cell-var-from-loop
         m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
@@ -176,9 +176,12 @@
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testBasic(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.session(graph=ops.Graph()):
+      with self.session(graph=ops.Graph(), use_gpu=True):
         # Initialize variables for numpy implementation.
-        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        m0 = np.array([0.0, 0.0])
+        v0 = np.array([0.0, 0.0])
+        m1 = np.array([0.0, 0.0])
+        v1 = np.array([0.0, 0.0])
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
         grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
         var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
@@ -224,7 +227,7 @@
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testBasicWithLearningRateDecay(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.session(graph=ops.Graph()):
+      with self.session(graph=ops.Graph(), use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -278,7 +281,7 @@
   @test_util.run_deprecated_v1
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -315,7 +318,7 @@
   @test_util.run_deprecated_v1
   def testSharing(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
index 51a4350..3868a6f 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -65,7 +65,7 @@
   @test_util.run_in_graph_and_eager_modes
   def testBasic(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -129,7 +129,7 @@
   @test_util.run_in_graph_and_eager_modes
   def testPrecomputedGradient(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -153,7 +153,7 @@
   @test_util.run_in_graph_and_eager_modes
   def testNoGradients(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
@@ -165,7 +165,7 @@
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_Minimize(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: constant_op.constant(5.0)
@@ -178,7 +178,7 @@
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_ApplyGradients(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         sgd_op = gradient_descent.SGD(3.0)
@@ -189,7 +189,7 @@
   @test_util.run_in_graph_and_eager_modes
   def testGradientsAsVariables(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
@@ -227,7 +227,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testComputeGradientsWithTensors(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       x = ops.convert_to_tensor(1.0)
 
       def f():
@@ -247,7 +247,7 @@
   def testConstraint(self):
     constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
     constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var0 = variables.Variable([1.0, 2.0],
                                 constraint=constraint_01)
       var1 = variables.Variable([3.0, 4.0],
@@ -269,14 +269,14 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testIterationWithoutMinimize(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       sgd = gradient_descent.SGD(3.0)
       self.evaluate(sgd.iterations.initializer)
       self.assertEqual(0, self.evaluate(sgd.iterations))
 
   @test_util.run_in_graph_and_eager_modes
   def testConfig(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       opt = gradient_descent.SGD(learning_rate=1.0)
       config = opt.get_config()
       opt2 = gradient_descent.SGD.from_config(config)
@@ -296,7 +296,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testConfigWithLearningRateDecay(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
       for decay_schedule in [
           learning_rate_schedule.InverseTimeDecay(
@@ -327,7 +327,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testGradClipValue(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var = resource_variable_ops.ResourceVariable([1.0, 2.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
@@ -338,7 +338,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testGradClipNorm(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var = resource_variable_ops.ResourceVariable([1.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
@@ -359,7 +359,7 @@
 
   @test_util.run_in_graph_and_eager_modes
   def testWeights(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       opt1 = adam.Adam(learning_rate=1.0)
       var1 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                     dtype=dtypes.float32)
@@ -626,7 +626,7 @@
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2
@@ -714,7 +714,7 @@
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2
@@ -775,7 +775,7 @@
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2