PSv2: Enable parameter_server_training_test in multiple gpu tests.
PiperOrigin-RevId: 364451340
Change-Id: I05173ee3e9c2136f565bb8ebecdf2a315d0b130a
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index d1a17a0..751864c 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -856,12 +856,13 @@
],
)
-tf_py_test(
+distribute_py_test(
name = "parameter_server_training_test",
srcs = ["parameter_server_training_test.py"],
python_version = "PY3",
shard_count = 1,
tags = [
+ "multi_and_single_gpu",
"no_tfrt", # TODO(b/180537361): Reenable TFRT after the issue is resolved.
],
deps = [
diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py
index 71fc520..a1572fd 100644
--- a/tensorflow/python/keras/distribute/parameter_server_training_test.py
+++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py
@@ -237,6 +237,22 @@
multi_worker_testing_utils.make_parameter_server_cluster(3, 2),
variable_partitioner=sharded_variable.FixedShardsPartitioner(2))
+ def assert_list_all_equal(self, list1, list2):
+ """Used in lieu of `assertAllEqual`.
+
+ This is used to replace standard `assertAllEqual` for the cases where
+ `list1` and `list2` contain `AggregatingVariable`. Lists with
+ `AggregatingVariable` are not convertible to numpy array via `np.array`
+ calls as numpy would raise `ValueError: setting an array element with a
+ sequence.`
+
+ Args:
+ list1: The first list to compare equality.
+ list2: The second list to compare equality.
+ """
+ for lhs, rhs in zip(list1, list2):
+ self.assertEqual(lhs, rhs)
+
def test_keras_layer_setattr(self):
class Layer(base_layer.Layer):
@@ -255,10 +271,11 @@
self.assertLen(layer.non_trainable_weights, 2)
self.assertEqual(layer.non_trainable_weights[0], [2])
self.assertEqual(layer.non_trainable_weights[1], [3])
- self.assertAllEqual(layer.weights,
- layer.trainable_weights + layer.non_trainable_weights)
- self.assertAllEqual(layer.trainable_weights, layer.trainable_variables)
- self.assertAllEqual(layer.weights, layer.variables)
+ self.assert_list_all_equal(
+ layer.weights, layer.trainable_weights + layer.non_trainable_weights)
+ self.assert_list_all_equal(layer.trainable_weights,
+ layer.trainable_variables)
+ self.assert_list_all_equal(layer.weights, layer.variables)
checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))
@@ -287,10 +304,11 @@
self.assertLen(layer.non_trainable_weights, 2)
self.assertEqual(layer.non_trainable_weights[0], [2.])
self.assertEqual(layer.non_trainable_weights[1], [3.])
- self.assertAllEqual(layer.weights,
- layer.trainable_weights + layer.non_trainable_weights)
- self.assertAllEqual(layer.trainable_weights, layer.trainable_variables)
- self.assertAllEqual(layer.weights, layer.variables)
+ self.assert_list_all_equal(
+ layer.weights, layer.trainable_weights + layer.non_trainable_weights)
+ self.assert_list_all_equal(layer.trainable_weights,
+ layer.trainable_variables)
+ self.assert_list_all_equal(layer.weights, layer.variables)
checkpoint_deps = set(dep.ref for dep in layer._checkpoint_dependencies)
self.assertEqual(checkpoint_deps, set([layer.w, layer.b]))