PSv2: Enable parameter_server_training_test in multiple gpu tests.

PiperOrigin-RevId: 364451340
Change-Id: I05173ee3e9c2136f565bb8ebecdf2a315d0b130a
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index d1a17a0..751864c 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -856,12 +856,13 @@
     ],
 )
 
-tf_py_test(
+distribute_py_test(
     name = "parameter_server_training_test",
     srcs = ["parameter_server_training_test.py"],
     python_version = "PY3",
     shard_count = 1,
     tags = [
+        "multi_and_single_gpu",
         "no_tfrt",  # TODO(b/180537361): Reenable TFRT after the issue is resolved.
     ],
     deps = [
diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py
index 71fc520..a1572fd 100644
--- a/tensorflow/python/keras/distribute/parameter_server_training_test.py
+++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py
@@ -237,6 +237,22 @@
         multi_worker_testing_utils.make_parameter_server_cluster(3, 2),
         variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
 
+  def assert_list_all_equal(self, list1, list2):
+    """Used in lieu of `assertAllEqual`.
+
+    This is used to replace standard `assertAllEqual` for the cases where
+    `list1` and `list2` contain `AggregatingVariable`. Lists with
+    `AggregatingVariable` are not convertible to numpy array via `np.array`
+    calls as numpy would raise `ValueError: setting an array element with a
+    sequence.`
+
+    Args:
+      list1: The first list to compare equality.
+      list2: The second list to compare equality.
+    """
+    for lhs, rhs in zip(list1, list2):
+      self.assertEqual(lhs, rhs)
+
   def test_keras_layer_setattr(self):
 
     class Layer(base_layer.Layer):
@@ -255,10 +271,11 @@
     self.assertLen(layer.non_trainable_weights, 2)
     self.assertEqual(layer.non_trainable_weights[0], [2])
     self.assertEqual(layer.non_trainable_weights[1], [3])
-    self.assertAllEqual(layer.weights,
-                        layer.trainable_weights + layer.non_trainable_weights)
-    self.assertAllEqual(layer.trainable_weights, layer.trainable_variables)
-    self.assertAllEqual(layer.weights, layer.variables)
+    self.assert_list_all_equal(
+        layer.weights, layer.trainable_weights + layer.non_trainable_weights)
+    self.assert_list_all_equal(layer.trainable_weights,
+                               layer.trainable_variables)
+    self.assert_list_all_equal(layer.weights, layer.variables)
 
     checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
     self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))
@@ -287,10 +304,11 @@
     self.assertLen(layer.non_trainable_weights, 2)
     self.assertEqual(layer.non_trainable_weights[0], [2.])
     self.assertEqual(layer.non_trainable_weights[1], [3.])
-    self.assertAllEqual(layer.weights,
-                        layer.trainable_weights + layer.non_trainable_weights)
-    self.assertAllEqual(layer.trainable_weights, layer.trainable_variables)
-    self.assertAllEqual(layer.weights, layer.variables)
+    self.assert_list_all_equal(
+        layer.weights, layer.trainable_weights + layer.non_trainable_weights)
+    self.assert_list_all_equal(layer.trainable_weights,
+                               layer.trainable_variables)
+    self.assert_list_all_equal(layer.weights, layer.variables)
 
     checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
     self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))