PSv2: Remove ParameterServerFailureError now that the runtime reports parameter server failures as UnavailableError.
PiperOrigin-RevId: 337933395
Change-Id: Ib6277de3b5d457e73d6efb099724f8d76e001045
diff --git a/tensorflow/python/distribute/client/client.py b/tensorflow/python/distribute/client/client.py
index 6eabbfa..be7157c 100644
--- a/tensorflow/python/distribute/client/client.py
+++ b/tensorflow/python/distribute/client/client.py
@@ -844,11 +844,6 @@
return self._closure_queue.done()
-class ParameterServerFailureError(Exception):
- """An error representing at least one parameter server is interrupted."""
- pass
-
-
class Client(object):
"""An object to schedule and orchestrate remote function execution.
@@ -942,7 +937,7 @@
"""
# Slot variables are usually created during function tracing time; thus
# `schedule` needs to be called within the `strategy.scope()`.
- with self.strategy.scope(), _translate_parameter_server_failure():
+ with self.strategy.scope():
return self.cluster.schedule(fn, args=args, kwargs=kwargs)
def join(self):
@@ -964,8 +959,7 @@
scheduled function since the last time an error was thrown or since
the beginning of the program.
"""
- with _translate_parameter_server_failure():
- self.cluster.join()
+ self.cluster.join()
def done(self):
"""Returns whether all the scheduled functions have finished execution.
@@ -1066,23 +1060,10 @@
# pylint: disable=missing-function-docstring
@contextlib.contextmanager
-def _translate_parameter_server_failure():
- try:
- yield
- except Exception as e: # pylint: disable=broad-except
- if _is_ps_failure(e):
- raise ParameterServerFailureError(e)
- else:
- raise
-
-
-# pylint: disable=missing-function-docstring
-@contextlib.contextmanager
def handle_parameter_server_failure():
try:
- with _translate_parameter_server_failure():
- yield
- except ParameterServerFailureError as e: # pylint: disable=broad-except
+ yield
+ except errors.UnavailableError as e: # pylint: disable=broad-except
restart_exit_code = os.environ.get("TF_CLIENT_NON_FATAL_RESTART_EXIT_CODE",
None)
if restart_exit_code is not None:
diff --git a/tensorflow/python/distribute/client/client_mpr_test.py b/tensorflow/python/distribute/client/client_mpr_test.py
index 802b23e..7f66562 100644
--- a/tensorflow/python/distribute/client/client_mpr_test.py
+++ b/tensorflow/python/distribute/client/client_mpr_test.py
@@ -31,6 +31,7 @@
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
@@ -72,8 +73,7 @@
# Now the main process can terminate.
functions_scheduled_event.set()
- # Verified that join and schedule indeed raise
- # ParameterServerFailureError.
+ # Verify that join and schedule indeed raise UnavailableError.
try:
if test_join:
ps_client.join()
@@ -81,7 +81,7 @@
while ps_client.cluster._closure_queue._error is None:
time.sleep(1)
ps_client.schedule(worker_fn)
- except client_lib.ParameterServerFailureError:
+ except errors.UnavailableError:
# The following verifies that after PS fails, continue executing
# functions on workers should fail and indicate it's PS failure.
for worker_id in range(3):
@@ -101,7 +101,7 @@
raise RuntimeError("Executing a function after PS fails, should "
"result in a PS failure.")
- raise RuntimeError("ParameterServerFailureError supposed to be raised.")
+ raise RuntimeError("UnavailableError supposed to be raised.")
manager = multi_process_runner.manager()
functions_scheduled_event = manager.Event()