PSv2: Remove ParameterServerFailureError now that parameter server failures are coming back from runtime as UnavailableError.

PiperOrigin-RevId: 337933395
Change-Id: Ib6277de3b5d457e73d6efb099724f8d76e001045
diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py
index 6eabbfa..be7157c 100644
--- a/tensorflow/python/distribute/client/client.py
+++ b/tensorflow/python/distribute/client/client.py
@@ -844,11 +844,6 @@
     return self._closure_queue.done()
 
 
-class ParameterServerFailureError(Exception):
-  """An error representing at least one parameter server is interrupted."""
-  pass
-
-
 class Client(object):
   """An object to schedule and orchestrate remote function execution.
 
@@ -942,7 +937,7 @@
     """
     # Slot variables are usually created during function tracing time; thus
     # `schedule` needs to be called within the `strategy.scope()`.
-    with self.strategy.scope(), _translate_parameter_server_failure():
+    with self.strategy.scope():
       return self.cluster.schedule(fn, args=args, kwargs=kwargs)
 
   def join(self):
@@ -964,8 +959,7 @@
         scheduled function since the last time an error was thrown or since
         the beginning of the program.
     """
-    with _translate_parameter_server_failure():
-      self.cluster.join()
+    self.cluster.join()
 
   def done(self):
     """Returns whether all the scheduled functions have finished execution.
@@ -1066,23 +1060,10 @@
 
 # pylint: disable=missing-function-docstring
 @contextlib.contextmanager
-def _translate_parameter_server_failure():
-  try:
-    yield
-  except Exception as e:  # pylint: disable=broad-except
-    if _is_ps_failure(e):
-      raise ParameterServerFailureError(e)
-    else:
-      raise
-
-
-# pylint: disable=missing-function-docstring
-@contextlib.contextmanager
 def handle_parameter_server_failure():
   try:
-    with _translate_parameter_server_failure():
-      yield
-  except ParameterServerFailureError as e:  # pylint: disable=broad-except
+    yield
+  except errors.UnavailableError as e:  # pylint: disable=broad-except
     restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
                                        None)
     if restart_exit_code is not None:
diff --git a/tensorflow/python/distribute/client/client_mpr_test.py b/tensorflow/python/distribute/client/client_mpr_test.py
index 802b23e..7f66562 100644
--- a/tensorflow/python/distribute/client/client_mpr_test.py
+++ b/tensorflow/python/distribute/client/client_mpr_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
@@ -72,8 +73,7 @@
       # Now the main process can terminate.
       functions_scheduled_event.set()
 
-      # Verified that join and schedule indeed raise
-      # ParameterServerFailureError.
+      # Verified that join and schedule indeed raise UnavailableError.
       try:
         if test_join:
           ps_client.join()
@@ -81,7 +81,7 @@
           while ps_client.cluster._closure_queue._error is None:
             time.sleep(1)
           ps_client.schedule(worker_fn)
-      except client_lib.ParameterServerFailureError:
+      except errors.UnavailableError:
         # The following verifies that after PS fails, continue executing
         # functions on workers should fail and indicate it's PS failure.
         for worker_id in range(3):
@@ -101,7 +101,7 @@
             raise RuntimeError("Executing a function after PS fails, should "
                                "result in a PS failure.")
 
-      raise RuntimeError("ParameterServerFailureError supposed to be raised.")
+      raise RuntimeError("UnavailableError supposed to be raised.")
 
     manager = multi_process_runner.manager()
     functions_scheduled_event = manager.Event()