Add support for specifying a custom exit code in WorkerHeartbeatRequest.
PiperOrigin-RevId: 264909269
diff --git a/tensorflow/core/util/event.proto b/tensorflow/core/util/event.proto
index ee1040d..f5dfffa 100644
--- a/tensorflow/core/util/event.proto
+++ b/tensorflow/core/util/event.proto
@@ -106,9 +106,14 @@
int64 timeout_ms = 1;
}
+message RequestedExitCode {
+ int32 exit_code = 1;
+}
+
message WorkerHeartbeatRequest {
WorkerShutdownMode shutdown_mode = 1;
WatchdogConfig watchdog_config = 2;
+ RequestedExitCode exit_code = 3;
}
message WorkerHeartbeatResponse {
diff --git a/tensorflow/python/tpu/session_support.py b/tensorflow/python/tpu/session_support.py
index 8280939..48a3e5f 100644
--- a/tensorflow/python/tpu/session_support.py
+++ b/tensorflow/python/tpu/session_support.py
@@ -145,12 +145,14 @@
# Default timeout is set to allow other shutdown triggered operations (log
# flushing etc) to finish before terminating the worker.
- def shutdown(self, wait_time_in_ms=60000):
+ def shutdown(self, wait_time_in_ms=60000, exit_code=None):
"""Shutdown all workers after `shutdown_timeout_secs`."""
logging.info('Shutting down %s.', self)
req = event_pb2.WorkerHeartbeatRequest(
watchdog_config=event_pb2.WatchdogConfig(timeout_ms=wait_time_in_ms),
- shutdown_mode=event_pb2.SHUTDOWN_AFTER_TIMEOUT)
+ shutdown_mode=event_pb2.SHUTDOWN_AFTER_TIMEOUT,
+ exit_code=event_pb2.RequestedExitCode(
+ exit_code=exit_code) if exit_code is not None else None)
self.configure(req)
# Wait for workers to shutdown.