PSv2: Dedup the legacy ParameterServerStrategy class (as the estimator usage of it uses ParameterServerStrategyV1).

PiperOrigin-RevId: 338310081
Change-Id: Icff445e322b22ee4ac7f3e69327c7969444eeb93
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index 0582ba1..312a3a4 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -47,9 +47,8 @@
 _LOCAL_CPU = "/device:CPU:0"
 
 
-# TODO(yuefengz): maybe cache variables on local CPU.
-# TODO(b/171250971): Remove this and change all symbol usage of this to V1.
-class ParameterServerStrategy(distribute_lib.Strategy):
+@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
+class ParameterServerStrategyV1(distribute_lib.StrategyV1):
   """An asynchronous multi-worker parameter server tf.distribute strategy.
 
   This strategy requires two roles: workers and parameter servers. Variables and
@@ -112,11 +111,11 @@
     """
     if cluster_resolver is None:
       cluster_resolver = TFConfigClusterResolver()
-    if not cluster_resolver.cluster_spec():
-      raise ValueError("Cluster spec must be non-empty in `cluster_resolver`.")
-    extended = ParameterServerStrategyExtended(
-        self, cluster_resolver=cluster_resolver)
-    super(ParameterServerStrategy, self).__init__(extended)
+    super(ParameterServerStrategyV1, self).__init__(
+        ParameterServerStrategyExtended(
+            self, cluster_resolver=cluster_resolver))
+    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
+        "ParameterServerStrategy")
 
   def experimental_distribute_dataset(self, dataset, options=None):
     if (options and options.experimental_replication_mode ==
@@ -127,7 +126,7 @@
           "`experimental_distribute_datasets_from_function`."
       )
     self._raise_pss_error_if_eager()
-    super(ParameterServerStrategy,
+    super(ParameterServerStrategyV1,
           self).experimental_distribute_dataset(dataset=dataset,
                                                 options=options)
 
@@ -140,17 +139,17 @@
           "`experimental_distribute_datasets_from_function` "
           "of tf.distribute.MirroredStrategy")
     self._raise_pss_error_if_eager()
-    super(ParameterServerStrategy, self).distribute_datasets_from_function(
+    super(ParameterServerStrategyV1, self).distribute_datasets_from_function(
         dataset_fn=dataset_fn, options=options)
 
   def run(self, fn, args=(), kwargs=None, options=None):
     self._raise_pss_error_if_eager()
-    super(ParameterServerStrategy, self).run(
+    super(ParameterServerStrategyV1, self).run(
         fn, args=args, kwargs=kwargs, options=options)
 
   def scope(self):
     self._raise_pss_error_if_eager()
-    return super(ParameterServerStrategy, self).scope()
+    return super(ParameterServerStrategyV1, self).scope()
 
   def _raise_pss_error_if_eager(self):
     if context.executing_eagerly():
@@ -159,22 +158,6 @@
           "currently only works with the tf.Estimator API")
 
 
-@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
-class ParameterServerStrategyV1(distribute_lib.StrategyV1):
-
-  __doc__ = ParameterServerStrategy.__doc__
-
-  def __init__(self, cluster_resolver=None):
-    """Initializes this strategy."""
-    super(ParameterServerStrategyV1, self).__init__(
-        ParameterServerStrategyExtended(
-            self, cluster_resolver=cluster_resolver))
-    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
-        "ParameterServerStrategy")
-
-  __init__.__doc__ = ParameterServerStrategy.__init__.__doc__
-
-
 # TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
 class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
   """Implementation of ParameterServerStrategy and CentralStorageStrategy."""
diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py
index 1b4cd21..c196fb4 100644
--- a/tensorflow/python/distribute/parameter_server_strategy_test.py
+++ b/tensorflow/python/distribute/parameter_server_strategy_test.py
@@ -84,7 +84,7 @@
         task_type=task_type,
         task_id=task_id,
         num_accelerators={'GPU': num_gpus})
-    distribution = parameter_server_strategy.ParameterServerStrategy(
+    distribution = parameter_server_strategy.ParameterServerStrategyV1(
         cluster_resolver)
     target = 'grpc://' + cluster_spec[WORKER][task_id]
   else:
@@ -748,7 +748,7 @@
         task_type='worker',
         task_id=1,
         num_accelerators={'GPU': 0})
-    strategy = parameter_server_strategy.ParameterServerStrategy(
+    strategy = parameter_server_strategy.ParameterServerStrategyV1(
         cluster_resolver)
     dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])
 
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index c88282f..2094d73 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -2449,12 +2449,11 @@
         task_type='worker',
         task_id=1,
         num_accelerators={'GPU': 0})
-    distribution = parameter_server_strategy.ParameterServerStrategy(
+    distribution = parameter_server_strategy.ParameterServerStrategyV1(
         cluster_resolver)
 
     self.assertIsInstance(distribution,
-                          (parameter_server_strategy.ParameterServerStrategyV1,
-                           parameter_server_strategy.ParameterServerStrategy))
+                          parameter_server_strategy.ParameterServerStrategyV1)
 
     with self.assertRaisesRegex(NotImplementedError,
                                 'ParameterServerStrategy*'):
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index f21127f..dee211a 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -360,8 +360,7 @@
               distribution_strategy_context.get_strategy())
 
     if isinstance(self._distribution_strategy,
-                  (parameter_server_strategy.ParameterServerStrategyV1,
-                   parameter_server_strategy.ParameterServerStrategy)):
+                  parameter_server_strategy.ParameterServerStrategyV1):
       raise NotImplementedError('`tf.compat.v1.distribute.experimental.Paramet'
                                 'erServerStrategy` currently only works '
                                 'with the tf.Estimator API')