Support async execution in ParallelDevice when remote eager is not in use.
PiperOrigin-RevId: 337335617
Change-Id: I4290de7b7af55cd1ac129073255cd24f5f08dda7
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
index e270bfc..095f33f 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
@@ -58,7 +58,7 @@
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
- explicit DeviceThread(const std::string& device)
+ explicit DeviceThread(const std::string& device, const bool is_async)
: status_(TF_NewStatus()),
device_(device),
// If the context's default exector is set to async, re-using that in
@@ -67,7 +67,7 @@
//
// TODO(allenl): We should have an async API that works with the
// parallel device.
- executor_(TFE_NewExecutor(/*is_async=*/false)),
+ executor_(TFE_NewExecutor(is_async)),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
@@ -236,12 +236,13 @@
}
}
-ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
+ParallelDevice::ParallelDevice(const std::vector<std::string>& devices,
+ const bool is_async)
: underlying_devices_(devices) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back(
- new DeviceThread(devices[device_index].c_str()));
+ new DeviceThread(devices[device_index].c_str(), is_async));
}
}
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
index b3dc47a..1bb9ce0 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
@@ -49,7 +49,10 @@
// placed on each underlying device.
class ParallelDevice {
public:
- explicit ParallelDevice(const std::vector<std::string>& devices);
+ // Eager async execution is only supported when remote eager is not in use
+ // (b/157523095).
+ explicit ParallelDevice(const std::vector<std::string>& devices,
+ const bool is_async = false);
~ParallelDevice();