Factor out C++ types from the parallel device

Moves the ParallelDevice and ParallelTensor C++ types into a standalone
parallel_device_lib target, leaving the TFE custom-device registration glue
(and the special-cased TPUReplicatedInput/TPUReplicatedOutput handling) in
parallel_device. Since a bare ParallelDevice is not registered with a
TFE_Context and has no device name, a new NamedParallelDevice pairs the two
for the custom-device integration.
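
As a rough sketch (not part of this change; the device names and the
surrounding context/status setup are assumed), the factored-out types can now
be used directly, without registering a custom device:

    #include <memory>

    #include "tensorflow/c/eager/c_api.h"
    #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
    #include "tensorflow/c/tf_status.h"

    namespace tpd = tensorflow::parallel_device;

    // Builds a two-component parallel device and materializes its device IDs.
    void DeviceIDsSketch(TFE_Context* context, TF_Status* status) {
      tpd::ParallelDevice device(
          {"/job:localhost/replica:0/task:0/device:CPU:0",
           "/job:localhost/replica:0/task:0/device:CPU:1"});
      // One scalar int64 per component, numbering the components 0 and 1.
      std::unique_ptr<tpd::ParallelTensor> ids =
          device.DeviceIDs(context, status);
      if (TF_GetCode(status) != TF_OK) return;
      // Each component handle is placed on the corresponding underlying
      // device.
      TFE_TensorHandle* first = ids->tensor(0);
      (void)first;
    }
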
PiperOrigin-RevId: 314807016
Change-Id: I4e41ac3e8a08ea0f1db93826652142a083f17fd1
diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD
index 6fce918..0d0e5ff 100644
--- a/tensorflow/c/eager/parallel_device/BUILD
+++ b/tensorflow/c/eager/parallel_device/BUILD
@@ -12,28 +12,69 @@
# need a second rule that omits .cc files, in
# tensorflow/python:_pywrap_parallel_device.
filegroup(
- name = "headers",
+ name = "lib_headers",
+ srcs = ["parallel_device_lib.h"],
+)
+
+filegroup(
+ name = "lib_sources",
+ srcs = ["parallel_device_lib.cc"],
+)
+
+filegroup(
+ name = "device_headers",
srcs = ["parallel_device.h"],
+)
+
+filegroup(
+ name = "device_sources",
+ srcs = ["parallel_device.cc"],
+)
+
+filegroup(
+ name = "headers",
+ srcs = [
+ ":device_headers",
+ ":lib_headers",
+ ],
visibility = ["//tensorflow/python:__pkg__"],
)
filegroup(
name = "sources",
- srcs = ["parallel_device.cc"],
+ srcs = [
+ ":device_sources",
+ ":lib_sources",
+ ],
visibility = ["//tensorflow/python:__pkg__"],
)
cc_library(
name = "parallel_device",
- srcs = [":sources"],
- hdrs = [":headers"],
+ srcs = [":device_sources"],
+ hdrs = [":device_headers"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":parallel_device_lib",
+ "//tensorflow/c:c_api",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_experimental",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:variant",
+ ],
+)
+
+cc_library(
+ name = "parallel_device_lib",
+ srcs = [":lib_sources"],
+ hdrs = [":lib_headers"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:lib",
- "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:variant",
],
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc
index 06669cf..eec893e 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.cc
+++ b/tensorflow/c/eager/parallel_device/parallel_device.cc
@@ -23,25 +23,13 @@
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
-namespace eager {
+namespace parallel_device {
namespace {
-// Functor for making unique_ptrs slightly more ergonomic. Using
-// decltype(delete_fn) in the unique_ptr's second template argument requires
-// passing a function pointer to delete_fn when constructing the unique_ptr.
-class TensorHandleDeleter {
- public:
- void operator()(TFE_TensorHandle* to_delete) const {
- TFE_DeleteTensorHandle(to_delete);
- }
-};
-
-using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
-
class OpDeleter {
public:
void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
@@ -49,224 +37,43 @@
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
-class ExecutorDeleter {
- public:
- void operator()(TFE_Executor* to_delete) const {
- TFE_DeleteExecutor(to_delete);
- }
-};
-
-using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
-
-class ParallelTensor;
-
using MaybeParallelTensorOwned =
absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
-using MaybeParallelTensorUnowned =
- absl::variant<ParallelTensor*, TFE_TensorHandle*>;
-// Creates a vector of `count` new executors (threads).
-std::vector<ExecutorPtr> MakeExecutors(size_t count) {
- std::vector<ExecutorPtr> executors;
- executors.reserve(count);
- for (int i = 0; i < count; ++i) {
- executors.emplace_back(TFE_NewExecutor(true /* is_async */));
- }
- return executors;
-}
-
-// A representation of the custom device passed in and out of the TFE custom
-// device APIs, providing context about the parallel device to
-// ParallelDeviceExecute.
-class ParallelDevice {
+// A ParallelDevice on its own is not registered with a TFE_Context, and so has
+// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
+// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
+// placed on the parallel device.
+class NamedParallelDevice {
public:
- ParallelDevice(const std::string& name,
- const std::vector<std::string>& devices);
-
- // Helper to copy a tensor handle from another device once for each component
- // of the ParallelDevice.
- //
- // Sets a bad status and returns a nullptr if `tensor` is already on the
- // ParallelDevice, or if the individual copies fail.
- std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
- TFE_TensorHandle* tensor,
- TF_Status* status) const;
-
- // A parallel tensor with scalar integers numbering component devices.
- std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
- TF_Status* status) const;
-
- // Takes a description of a single operation being executed on the
- // ParallelDevice, and in turn runs one operation per component device with
- // its corresponding inputs from the input ParallelTensors (or
- // implicitly-mirrored tensors on other devices). Wraps the resulting
- // per-device and per-output TFE_TensorHandles into one ParallelTensor per
- // output of the original operation.
- //
- // `inputs` are either ParallelTensors, i.e. already on the ParallelDevice, or
- // un-replicated TFE_TensorHandles on other devices. TPUReplicatedInput
- // requires non-parallel tensors, and TPUReplicatedOutput requires a parallel
- // tensor, but other operations will implicitly broadcast non-parallel input
- // tensors across the ParallelDevice's component devices.
- //
- // Two special-cased operations, TPUReplicatedInput and TPUReplicatedOutput,
- // pack and un-pack parallel tensors respectively. Only TPUReplicatedOutput
- // causes `Execute` to return non-parallel tensors.
- //
- // Attributes are forwarded to executed operations unmodified.
- //
- // The returned optional has a value if and only if `status` evaluates to
- // TF_OK.
- absl::optional<std::vector<MaybeParallelTensorOwned>> Execute(
- TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
- const char* operation_name, const TFE_OpAttrs* attributes,
- int expected_max_outputs, TF_Status* status) const;
-
- // Implements the parallel case for `Execute`, where all of the outputs of the
- // operation are ParallelTensors, and all inputs are either ParallelTensors or
- // should be implicitly broadcast. This means the operation is not
- // TPUReplicatedInput or TPUReplicatedOutput.
- //
- // The returned optional has a value if and only if `status` evaluates to
- // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
- // if sanity checks on dtypes/metadata fail.
- absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
- ExecuteParallelOperation(TFE_Context* context,
- std::vector<MaybeParallelTensorUnowned> inputs,
- const char* operation_name,
- const TFE_OpAttrs* attributes,
- int expected_max_outputs, TF_Status* status) const;
-
- const std::string& device_name() const { return device_name_; }
+ NamedParallelDevice(const std::string& name,
+ std::unique_ptr<ParallelDevice> parallel_device)
+ : device_name_(name), parallel_device_(std::move(parallel_device)) {}
+ const std::string& name() const { return device_name_; }
+ const ParallelDevice& device() const { return *parallel_device_; }
private:
- // The name of the parallel device
- // (e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0")
- const std::string device_name_;
- // A sequence of device names, indicating which devices replicated operations
- // are forwarded to.
- const std::vector<std::string> underlying_devices_;
- // A sequence of TFE_Executors, one per device, for executing operations in
- // parallel.
- const std::vector<ExecutorPtr> executors_;
+ std::string device_name_;
+ std::unique_ptr<ParallelDevice> parallel_device_;
};
-// The internal representation of a TFE_TensorHandle placed on a
-// ParallelDevice. Contains a tuple of tensors, one on each of the
-// `underlying_devices_` of the ParallelDevice.
-class ParallelTensor {
- public:
- // Construct a ParallelTensor from TensorHandles placed on the component
- // devices of a ParallelDevice.
- static std::unique_ptr<ParallelTensor> FromTensorHandles(
- const ParallelDevice& parallel_device,
- std::vector<TensorHandlePtr> components, TF_Status* status);
-
- // Helper to wrap a ParallelTensor into a TFE_TensorHandle which contains it.
- static TensorHandlePtr AsTensorHandle(TFE_Context* context,
- std::unique_ptr<ParallelTensor> t,
- TF_Status* status);
-
- size_t num_tensors() const { return tensors_.size(); }
- TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
-
- private:
- ParallelTensor(const ParallelDevice& device,
- std::vector<TensorHandlePtr> tensors,
- std::vector<int64_t> shape, const TF_DataType dtype)
- : device_(device),
- tensors_(std::move(tensors)),
- shape_(std::move(shape)),
- dtype_(dtype) {}
-
- const ParallelDevice& device_;
- const std::vector<TensorHandlePtr> tensors_;
- const std::vector<int64_t> shape_;
- const TF_DataType dtype_;
-};
-
-ParallelDevice::ParallelDevice(const std::string& name,
- const std::vector<std::string>& devices)
- : device_name_(name),
- underlying_devices_(devices),
- executors_(MakeExecutors(underlying_devices_.size())) {}
-
-std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
- TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
- const char* current_device = TFE_TensorHandleDeviceName(tensor, status);
- if (device_name_ == current_device) {
- std::string message(absl::StrCat(
- "Tried to copy a TensorHandle to its existing device: ", device_name_));
- TF_SetStatus(status, TF_INTERNAL, message.c_str());
- return nullptr;
- }
- std::vector<TensorHandlePtr> components;
- components.reserve(underlying_devices_.size());
- for (const std::string& underlying_device_name : underlying_devices_) {
- TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
- tensor, context, underlying_device_name.c_str(), status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- components.emplace_back(t);
- }
- return ParallelTensor::FromTensorHandles(*this, std::move(components),
- status);
-}
-
-std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
- TFE_Context* context, TF_Status* status) const {
- // TODO(allenl): We could cache DeviceIDs (keyed by context).
- std::vector<TensorHandlePtr> components;
- components.reserve(underlying_devices_.size());
- for (int device_index = 0; device_index < underlying_devices_.size();
- ++device_index) {
- int64_t* device_id = new int64_t;
- *device_id = device_index;
- std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
- TF_NewTensor(
- TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
- sizeof(int64_t),
- [](void* data, size_t, void* arg) {
- delete reinterpret_cast<int64_t*>(data);
- },
- nullptr),
- TF_DeleteTensor);
- // TODO(allenl): Here and when executing regular operations, we could hold
- // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
- // device names repeatedly.
- OpPtr const_op(TFE_NewOp(context, "Const", status));
- if (TF_GetCode(status) != TF_OK) return nullptr;
- TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
- status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
- TFE_TensorHandle* device_handle;
- int num_outputs = 1;
- TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- components.emplace_back(device_handle);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- }
- return ParallelTensor::FromTensorHandles(*this, std::move(components),
- status);
-}
-
-absl::optional<std::vector<MaybeParallelTensorOwned>> ParallelDevice::Execute(
- TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
- const char* operation_name, const TFE_OpAttrs* attributes,
- int expected_max_outputs, TF_Status* status) const {
+absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
+ const ParallelDevice& parallel_device,
+ const std::string& parallel_device_name, TFE_Context* context,
+ std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
+ const TFE_OpAttrs* attributes, int expected_max_outputs,
+ TF_Status* status) {
absl::optional<std::vector<MaybeParallelTensorOwned>> result;
// TODO(allenl): We should remove "TPU" from these op names at the very least,
// or consider other ways of packing/unpacking parallel tensors.
if (operation_name == std::string("TPUReplicatedInput")) {
// Special-cased operation for packing per-device tensors into one parallel
// tensor.
- if (inputs.size() != underlying_devices_.size()) {
+ if (inputs.size() != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
- "The parallel device ", device_name_, " expected ",
- underlying_devices_.size(), " inputs to TPUReplicatedInput, but got ",
- inputs.size()));
+ "The parallel device ", parallel_device_name, " expected ",
+ parallel_device.num_underlying_devices(),
+ " inputs to TPUReplicatedInput, but got ", inputs.size()));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
}
@@ -289,7 +96,7 @@
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
result_content.push_back(ParallelTensor::FromTensorHandles(
- *this, std::move(components), status));
+ parallel_device, std::move(components), status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
@@ -300,10 +107,10 @@
TFE_OpAddAttrs(op.get(), attributes);
int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
if (TF_GetCode(status) != TF_OK) return result;
- if (expected_outputs != underlying_devices_.size()) {
+ if (expected_outputs != parallel_device.num_underlying_devices()) {
std::string message(absl::StrCat(
- "The parallel device ", device_name_, " expected ",
- underlying_devices_.size(),
+ "The parallel device ", parallel_device_name, " expected ",
+ parallel_device.num_underlying_devices(),
" outputs for TPUReplicatedOutput, but got ", expected_outputs));
TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
return result;
@@ -329,15 +136,15 @@
} else if (operation_name == std::string("DeviceID")) {
std::vector<MaybeParallelTensorOwned> result_content;
result_content.reserve(1);
- result_content.push_back(DeviceIDs(context, status));
+ result_content.push_back(parallel_device.DeviceIDs(context, status));
if (TF_GetCode(status) != TF_OK) return result;
result.emplace(std::move(result_content));
return result;
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
maybe_parallel_results(
- ExecuteParallelOperation(context, std::move(inputs), operation_name,
- attributes, expected_max_outputs, status));
+ parallel_device.Execute(context, std::move(inputs), operation_name,
+ attributes, expected_max_outputs, status));
if (!maybe_parallel_results.has_value()) return result;
std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
std::move(maybe_parallel_results.value()));
@@ -351,153 +158,6 @@
return result;
}
-absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
-ParallelDevice::ExecuteParallelOperation(
- TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
- const char* operation_name, const TFE_OpAttrs* attributes,
- int expected_max_outputs, TF_Status* status) const {
- absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
- // Compute per-device per-output tensors
- std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
- per_device_output_tensors.reserve(underlying_devices_.size());
- // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
- // setting the thread-local executor like this.
- TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
- auto reset_executor = gtl::MakeCleanup([context, previous_executor]() {
- TFE_ContextSetExecutorForThread(context, previous_executor);
- TFE_DeleteExecutor(previous_executor);
- });
- int first_op_output_count;
- for (int device_index = 0; device_index < underlying_devices_.size();
- ++device_index) {
- TFE_Executor* executor = executors_[device_index].get();
- // Note that the `reset_executor` cleanup sets the thread's executor back to
- // the value before this function ran.
- TFE_ContextSetExecutorForThread(context, executor);
- OpPtr op(TFE_NewOp(context, operation_name, status));
- if (TF_GetCode(status) != TF_OK) return result;
- TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
- status);
- TFE_OpAddAttrs(op.get(), attributes);
- for (int input_index = 0; input_index < inputs.size(); ++input_index) {
- if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
- // Non-parallel tensors are implicitly broadcast, i.e. set as the input
- // to each parallel operation.
- //
- // TODO(allenl): There may be smarter ways to do this copy in some
- // cases, i.e. with a collective broadcast. We'll need to be careful
- // about things that are taken as inputs on the host or on their
- // existing device (for multi-device functions).
- TFE_OpAddInput(op.get(),
- absl::get<TFE_TensorHandle*>(inputs[input_index]),
- status);
- if (TF_GetCode(status) != TF_OK) return result;
- } else {
- // Parallel tensors are divided between operations by device.
- TFE_OpAddInput(op.get(),
- absl::get<ParallelTensor*>(inputs[input_index])
- ->tensor(device_index),
- status);
- if (TF_GetCode(status) != TF_OK) return result;
- }
- }
- std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
- int real_num_outputs = expected_max_outputs;
- // For nested devices, the inner device sees the async executor we've
- // set. Inner parallel devices will just overwrite this with their own and
- // then set it back to ours before returning. This means parallel devices
- // which consist of several aliased parallel devices would hypothetically
- // deadlock if the outer parallel device ran one collective with a group
- // size equal to the total number of aliased physical devices. Currently
- // physical devices cannot participate in a single collective reduction
- // multiple times, so this would fail earlier.
- //
- // TODO(allenl): Keep a map from outer executor to list of inner executors
- // rather than a single list of executors so aliased nested parallel devices
- // don't re-use an executor.
- TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
- if (device_index == 0) {
- first_op_output_count = real_num_outputs;
- } else {
- if (real_num_outputs != first_op_output_count) {
- TF_SetStatus(status, TF_INTERNAL,
- "Parallel ops produced different numbers of tensors.");
- return result;
- }
- }
- if (TF_GetCode(status) != TF_OK) return result;
- std::vector<TensorHandlePtr> this_outputs;
- this_outputs.reserve(real_num_outputs);
- for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
- this_outputs.emplace_back(op_outputs[output_num]);
- }
- per_device_output_tensors.push_back(std::move(this_outputs));
- }
- for (int device_index = 0; device_index < underlying_devices_.size();
- ++device_index) {
- TFE_Executor* executor = executors_[device_index].get();
- // TODO(b/157523095): Syncing the executor here shouldn't be
- // necessary. Currently async+remote is missing cross-executor
- // coordination.
- TFE_ExecutorWaitForAllPendingNodes(executor, status);
- if (TF_GetCode(status) != TF_OK) return result;
- }
- // For each output of the original operation, pack the per-device
- // TensorHandles we've computed into a single parallel TensorHandle.
- std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
- per_device_outputs.reserve(first_op_output_count);
- for (int i = 0; i < first_op_output_count; ++i) {
- std::vector<TensorHandlePtr> components;
- components.reserve(underlying_devices_.size());
- for (int j = 0; j < underlying_devices_.size(); ++j) {
- components.push_back(std::move(per_device_output_tensors[j][i]));
- }
- per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
- *this, std::move(components), status));
- if (TF_GetCode(status) != TF_OK) return result;
- }
- result.emplace(std::move(per_device_outputs));
- return result;
-}
-
-std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
- const ParallelDevice& parallel_device,
- std::vector<TensorHandlePtr> components, TF_Status* status) {
- TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
- std::vector<int64_t> shape(
- TFE_TensorHandleNumDims(components[0].get(), status));
- if (TF_GetCode(status) != TF_OK) return nullptr;
- for (int i = 0; i < shape.size(); ++i) {
- shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- }
-
- // Verify that the TensorHandle's shape and dtype match all of the component
- // shapes and dtypes.
- for (TensorHandlePtr& component : components) {
- for (int i = 0; i < shape.size(); ++i) {
- int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
- if (TF_GetCode(status) != TF_OK) return nullptr;
- if (tensor_dim != shape[i]) {
- // TODO(allenl): Allow shapes to differ.
- TF_SetStatus(status, TF_UNIMPLEMENTED,
- "Components of a ParallelTensor must currently all have "
- "the same shape");
- return nullptr;
- }
- if (TFE_TensorHandleDataType(component.get()) != dtype) {
- TF_SetStatus(status, TF_INTERNAL,
- "Components of a ParallelTensor must all have "
- "the same dtype");
- return nullptr;
- }
- }
- }
-
- return std::unique_ptr<ParallelTensor>(new ParallelTensor(
- parallel_device, std::move(components), std::move(shape), dtype));
-}
-
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
@@ -505,17 +165,18 @@
delete reinterpret_cast<ParallelTensor*>(data);
}
-TensorHandlePtr ParallelTensor::AsTensorHandle(
- TFE_Context* context, std::unique_ptr<ParallelTensor> t,
- TF_Status* status) {
+TensorHandlePtr ParallelTensorToTensorHandle(
+ const std::string& parallel_device_name, TFE_Context* context,
+ std::unique_ptr<ParallelTensor> t, TF_Status* status) {
// The resulting TensorHandle owns an opaque pointer to "device memory", which
// for a ParallelDevice is really a ParallelTensor. When the TensorHandle is
// deleted, it will call ParallelTensorDeallocator to free the struct.
ParallelTensor* t_released = t.release();
+ const std::vector<int64_t>& shape(t_released->shape());
return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory(
- context, t_released->device_.device_name().c_str(), t_released->dtype_,
- t_released->shape_.data(), t_released->shape_.size(), t_released, 1,
- &ParallelTensorDeallocator, nullptr, status));
+ context, parallel_device_name.c_str(), t_released->dtype(), shape.data(),
+ shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr,
+ status));
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
@@ -531,12 +192,14 @@
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
- ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ NamedParallelDevice* named_device =
+ reinterpret_cast<NamedParallelDevice*>(device_info);
+ const ParallelDevice& dev = named_device->device();
std::unique_ptr<ParallelTensor> parallel_tensor(
- dev->CopyToParallelDevice(context, tensor, status));
+ dev.CopyToParallelDevice(context, tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
- return ParallelTensor::AsTensorHandle(context, std::move(parallel_tensor),
- status)
+ return ParallelTensorToTensorHandle(named_device->name(), context,
+ std::move(parallel_tensor), status)
.release();
}
@@ -570,14 +233,15 @@
const TFE_OpAttrs* attributes, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* status,
void* device_info) {
- ParallelDevice* dev = reinterpret_cast<ParallelDevice*>(device_info);
+ NamedParallelDevice* named_device =
+ reinterpret_cast<NamedParallelDevice*>(device_info);
std::vector<MaybeParallelTensorUnowned> typed_inputs;
typed_inputs.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
const char* tensor_handle_device =
TFE_TensorHandleDeviceName(inputs[i], status);
if (TF_GetCode(status) != TF_OK) return;
- if (dev->device_name() == tensor_handle_device) {
+ if (named_device->name() == tensor_handle_device) {
// We assume that any tensors already placed on this device are
// ParallelTensors.
typed_inputs.emplace_back(reinterpret_cast<ParallelTensor*>(
@@ -589,8 +253,9 @@
}
absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
- dev->Execute(context, std::move(typed_inputs), operation_name, attributes,
- *num_outputs, status));
+ ExecuteWithSpecialOps(named_device->device(), named_device->name(),
+ context, std::move(typed_inputs), operation_name,
+ attributes, *num_outputs, status));
if (TF_GetCode(status) != TF_OK) return;
if (!maybe_typed_outputs.has_value()) {
TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
@@ -611,8 +276,8 @@
if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
outputs[i] = absl::get<TensorHandlePtr>(typed_output).release();
} else {
- outputs[i] = ParallelTensor::AsTensorHandle(
- context,
+ outputs[i] = ParallelTensorToTensorHandle(
+ named_device->name(), context,
std::move(absl::get<std::unique_ptr<ParallelTensor>>(
typed_output)),
status)
@@ -629,7 +294,7 @@
// device_info is passed in using a C-style generic. It must always be a
-// ParallelDevice.
+// NamedParallelDevice.
void DeleteParallelDevice(void* device_info) {
- delete reinterpret_cast<ParallelDevice*>(device_info);
+ delete reinterpret_cast<NamedParallelDevice*>(device_info);
}
} // namespace
@@ -648,8 +313,10 @@
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
- *device_info = new ParallelDevice(device_name, underlying_devices_vector);
+ std::unique_ptr<ParallelDevice> parallel_device(
+ new ParallelDevice(underlying_devices_vector));
+ *device_info =
+ new NamedParallelDevice{device_name, std::move(parallel_device)};
}
-
-} // namespace eager
+} // namespace parallel_device
} // namespace tensorflow
diff --git a/tensorflow/c/eager/parallel_device/parallel_device.h b/tensorflow/c/eager/parallel_device/parallel_device.h
index f448a4c..b8e571b 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device.h
+++ b/tensorflow/c/eager/parallel_device/parallel_device.h
@@ -21,7 +21,7 @@
#include "tensorflow/c/eager/c_api_experimental.h"
namespace tensorflow {
-namespace eager {
+namespace parallel_device {
// Allocate a parallel device named `device_name` which forwards operations to
// `underlying_devices`, maintaining "parallel tensors" with components placed
@@ -59,7 +59,7 @@
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info);
-} // namespace eager
+} // namespace parallel_device
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
new file mode 100644
index 0000000..f56b8d8
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc
@@ -0,0 +1,251 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
+
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace parallel_device {
+namespace {
+
+class OpDeleter {
+ public:
+ void operator()(TFE_Op* to_delete) const { TFE_DeleteOp(to_delete); }
+};
+
+using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
+
+// Creates a vector of `count` new executors (threads).
+std::vector<ExecutorPtr> MakeExecutors(size_t count) {
+ std::vector<ExecutorPtr> executors;
+ executors.reserve(count);
+ for (int i = 0; i < count; ++i) {
+ executors.emplace_back(TFE_NewExecutor(true /* is_async */));
+ }
+ return executors;
+}
+
+} // namespace
+
+ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
+ : underlying_devices_(devices),
+ executors_(MakeExecutors(underlying_devices_.size())) {}
+
+std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
+ TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
+ std::vector<TensorHandlePtr> components;
+ components.reserve(underlying_devices_.size());
+ for (const std::string& underlying_device_name : underlying_devices_) {
+ TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
+ tensor, context, underlying_device_name.c_str(), status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ components.emplace_back(t);
+ }
+ return ParallelTensor::FromTensorHandles(*this, std::move(components),
+ status);
+}
+
+std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs(
+ TFE_Context* context, TF_Status* status) const {
+ // TODO(allenl): We could cache DeviceIDs (keyed by context).
+ std::vector<TensorHandlePtr> components;
+ components.reserve(underlying_devices_.size());
+ for (int device_index = 0; device_index < underlying_devices_.size();
+ ++device_index) {
+ int64_t* device_id = new int64_t;
+ *device_id = device_index;
+ std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+ TF_NewTensor(
+ TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id,
+ sizeof(int64_t),
+ [](void* data, size_t, void* arg) {
+ delete reinterpret_cast<int64_t*>(data);
+ },
+ nullptr),
+ TF_DeleteTensor);
+ // TODO(allenl): Here and when executing regular operations, we could hold
+ // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
+ // device names repeatedly.
+ OpPtr const_op(TFE_NewOp(context, "Const", status));
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
+ status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64);
+ TFE_TensorHandle* device_handle;
+ int num_outputs = 1;
+ TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ components.emplace_back(device_handle);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ }
+ return ParallelTensor::FromTensorHandles(*this, std::move(components),
+ status);
+}
+
+absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
+ParallelDevice::Execute(TFE_Context* context,
+ std::vector<MaybeParallelTensorUnowned> inputs,
+ const char* operation_name,
+ const TFE_OpAttrs* attributes, int expected_max_outputs,
+ TF_Status* status) const {
+ absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
+ // Compute per-device per-output tensors
+ std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
+ per_device_output_tensors.reserve(underlying_devices_.size());
+ // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
+ // setting the thread-local executor like this.
+ TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
+ auto reset_executor =
+ tensorflow::gtl::MakeCleanup([context, previous_executor]() {
+ TFE_ContextSetExecutorForThread(context, previous_executor);
+ TFE_DeleteExecutor(previous_executor);
+ });
+ int first_op_output_count;
+ for (int device_index = 0; device_index < underlying_devices_.size();
+ ++device_index) {
+ TFE_Executor* executor = executors_[device_index].get();
+ // Note that the `reset_executor` cleanup sets the thread's executor back to
+ // the value before this function ran.
+ TFE_ContextSetExecutorForThread(context, executor);
+ OpPtr op(TFE_NewOp(context, operation_name, status));
+ if (TF_GetCode(status) != TF_OK) return result;
+ TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
+ status);
+ TFE_OpAddAttrs(op.get(), attributes);
+ for (int input_index = 0; input_index < inputs.size(); ++input_index) {
+ if (absl::holds_alternative<TFE_TensorHandle*>(inputs[input_index])) {
+ // Non-parallel tensors are implicitly broadcast, i.e. set as the input
+ // to each parallel operation.
+ //
+ // TODO(allenl): There may be smarter ways to do this copy in some
+ // cases, i.e. with a collective broadcast. We'll need to be careful
+ // about things that are taken as inputs on the host or on their
+ // existing device (for multi-device functions).
+ TFE_OpAddInput(op.get(),
+ absl::get<TFE_TensorHandle*>(inputs[input_index]),
+ status);
+ if (TF_GetCode(status) != TF_OK) return result;
+ } else {
+ // Parallel tensors are divided between operations by device.
+ TFE_OpAddInput(op.get(),
+ absl::get<ParallelTensor*>(inputs[input_index])
+ ->tensor(device_index),
+ status);
+ if (TF_GetCode(status) != TF_OK) return result;
+ }
+ }
+ std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
+ int real_num_outputs = expected_max_outputs;
+ // For nested devices, the inner device sees the async executor we've
+ // set. Inner parallel devices will just overwrite this with their own and
+ // then set it back to ours before returning. This means parallel devices
+ // which consist of several aliased parallel devices would hypothetically
+ // deadlock if the outer parallel device ran one collective with a group
+ // size equal to the total number of aliased physical devices. Currently
+ // physical devices cannot participate in a single collective reduction
+ // multiple times, so this would fail earlier.
+ //
+ // TODO(allenl): Keep a map from outer executor to list of inner executors
+ // rather than a single list of executors so aliased nested parallel devices
+ // don't re-use an executor.
+ TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
+ if (device_index == 0) {
+ first_op_output_count = real_num_outputs;
+ } else {
+ if (real_num_outputs != first_op_output_count) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Parallel ops produced different numbers of tensors.");
+ return result;
+ }
+ }
+ if (TF_GetCode(status) != TF_OK) return result;
+ std::vector<TensorHandlePtr> this_outputs;
+ this_outputs.reserve(real_num_outputs);
+ for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
+ this_outputs.emplace_back(op_outputs[output_num]);
+ }
+ per_device_output_tensors.push_back(std::move(this_outputs));
+ }
+ for (int device_index = 0; device_index < underlying_devices_.size();
+ ++device_index) {
+ TFE_Executor* executor = executors_[device_index].get();
+ // TODO(b/157523095): Syncing the executor here shouldn't be
+ // necessary. Currently async+remote is missing cross-executor
+ // coordination.
+ TFE_ExecutorWaitForAllPendingNodes(executor, status);
+ if (TF_GetCode(status) != TF_OK) return result;
+ }
+ // For each output of the original operation, pack the per-device
+ // TensorHandles we've computed into a single parallel TensorHandle.
+ std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs;
+ per_device_outputs.reserve(first_op_output_count);
+ for (int i = 0; i < first_op_output_count; ++i) {
+ std::vector<TensorHandlePtr> components;
+ components.reserve(underlying_devices_.size());
+ for (int j = 0; j < underlying_devices_.size(); ++j) {
+ components.push_back(std::move(per_device_output_tensors[j][i]));
+ }
+ per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
+ *this, std::move(components), status));
+ if (TF_GetCode(status) != TF_OK) return result;
+ }
+ result.emplace(std::move(per_device_outputs));
+ return result;
+}
+
+std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
+ const ParallelDevice& parallel_device,
+ std::vector<TensorHandlePtr> components, TF_Status* status) {
+ TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
+ std::vector<int64_t> shape(
+ TFE_TensorHandleNumDims(components[0].get(), status));
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ for (int i = 0; i < shape.size(); ++i) {
+ shape[i] = TFE_TensorHandleDim(components[0].get(), i, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ }
+
+ // Verify that the TensorHandle's shape and dtype match all of the component
+ // shapes and dtypes.
+ for (TensorHandlePtr& component : components) {
+ for (int i = 0; i < shape.size(); ++i) {
+ int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
+ if (TF_GetCode(status) != TF_OK) return nullptr;
+ if (tensor_dim != shape[i]) {
+ // TODO(allenl): Allow shapes to differ.
+ TF_SetStatus(status, TF_UNIMPLEMENTED,
+ "Components of a ParallelTensor must currently all have "
+ "the same shape");
+ return nullptr;
+ }
+ if (TFE_TensorHandleDataType(component.get()) != dtype) {
+ TF_SetStatus(status, TF_INTERNAL,
+ "Components of a ParallelTensor must all have "
+ "the same dtype");
+ return nullptr;
+ }
+ }
+ }
+
+ return std::unique_ptr<ParallelTensor>(new ParallelTensor(
+ parallel_device, std::move(components), std::move(shape), dtype));
+}
+
+} // namespace parallel_device
+} // namespace tensorflow
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
new file mode 100644
index 0000000..377377b
--- /dev/null
+++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h
@@ -0,0 +1,141 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
+#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "absl/types/variant.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+
+namespace tensorflow {
+namespace parallel_device {
+
+// Functor for making unique_ptrs slightly more ergonomic. Using
+// decltype(delete_fn) in the unique_ptr's second template argument requires
+// passing a function pointer to delete_fn when constructing the unique_ptr.
+class TensorHandleDeleter {
+ public:
+ void operator()(TFE_TensorHandle* to_delete) const {
+ TFE_DeleteTensorHandle(to_delete);
+ }
+};
+
+using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
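+
+// For example, `TensorHandlePtr handle(raw_handle);` takes ownership of
+// `raw_handle` (a hypothetical TFE_TensorHandle*) and calls
+// TFE_DeleteTensorHandle on it when `handle` is destroyed, with no deleter
+// argument needed at construction.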
+
+class ExecutorDeleter {
+ public:
+ void operator()(TFE_Executor* to_delete) const {
+ TFE_DeleteExecutor(to_delete);
+ }
+};
+
+using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
+
+class ParallelTensor;
+
+using MaybeParallelTensorUnowned =
+ absl::variant<ParallelTensor*, TFE_TensorHandle*>;
+
+// Forwards operations to `devices`, maintaining ParallelTensors with
+// components placed on each underlying device.
+class ParallelDevice {
+ public:
+ explicit ParallelDevice(const std::vector<std::string>& devices);
+
+ // Helper to copy a tensor handle from another device once for each component
+ // of the ParallelDevice.
+ //
+ // Sets a bad status and returns a nullptr if `tensor` is already on the
+ // ParallelDevice, or if the individual copies fail.
+ std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
+ TFE_TensorHandle* tensor,
+ TF_Status* status) const;
+
+ // A parallel tensor with scalar integers numbering component devices.
+ std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
+ TF_Status* status) const;
+
+ // The number of devices operations run on.
+ size_t num_underlying_devices() const { return underlying_devices_.size(); }
+
+ // Takes a description of a single operation being executed on the
+ // ParallelDevice, and in turn runs one operation per component device with
+ // its corresponding inputs from the input ParallelTensors (or
+ // implicitly-mirrored tensors on other devices). Wraps the resulting
+ // per-device and per-output TFE_TensorHandles into one ParallelTensor per
+ // output of the original operation.
+ //
+ // Attributes are forwarded to executed operations unmodified.
+ //
+ // The returned optional has a value if and only if `status` evaluates to
+ // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or
+ // if sanity checks on dtypes/metadata fail.
+ absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
+ TFE_Context* context, std::vector<MaybeParallelTensorUnowned> inputs,
+ const char* operation_name, const TFE_OpAttrs* attributes,
+ int expected_max_outputs, TF_Status* status) const;
+
+ private:
+ // A sequence of device names, indicating which devices replicated operations
+ // are forwarded to.
+ const std::vector<std::string> underlying_devices_;
+ // A sequence of TFE_Executors, one per device, for executing operations in
+ // parallel.
+ const std::vector<ExecutorPtr> executors_;
+};
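+
+// A hypothetical use of `Execute`, multiplying a ParallelTensor `values` by a
+// non-parallel handle `scalar` (implicitly broadcast to every component);
+// `device`, `context`, `status`, and the forwarded `attributes` are assumed:
+//
+//   std::vector<MaybeParallelTensorUnowned> inputs{values.get(), scalar};
+//   absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> outputs =
+//       device.Execute(context, std::move(inputs), "Mul", attributes,
+//                      /*expected_max_outputs=*/1, status);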
+
+// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
+// ParallelDevice.
+class ParallelTensor {
+ public:
+ // Construct a ParallelTensor from TensorHandles placed on the component
+ // devices of a ParallelDevice.
+ static std::unique_ptr<ParallelTensor> FromTensorHandles(
+ const ParallelDevice& parallel_device,
+ std::vector<TensorHandlePtr> components, TF_Status* status);
+
+ size_t num_tensors() const { return tensors_.size(); }
+ TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
+
+ // The shape shared by the underlying tensors; components must currently all
+ // have the same shape.
+ const std::vector<int64_t>& shape() const { return shape_; }
+ TF_DataType dtype() const { return dtype_; }
+
+ private:
+ ParallelTensor(const ParallelDevice& device,
+ std::vector<TensorHandlePtr> tensors,
+ std::vector<int64_t> shape, const TF_DataType dtype)
+ : device_(device),
+ tensors_(std::move(tensors)),
+ shape_(std::move(shape)),
+ dtype_(dtype) {}
+
+ const ParallelDevice& device_;
+ const std::vector<TensorHandlePtr> tensors_;
+ const std::vector<int64_t> shape_;
+ const TF_DataType dtype_;
+};
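+
+// A sketch of packing two component handles, assumed to already be placed on
+// the device's two underlying devices, into one ParallelTensor:
+//
+//   std::vector<TensorHandlePtr> components;
+//   components.push_back(std::move(first_component));
+//   components.push_back(std::move(second_component));
+//   std::unique_ptr<ParallelTensor> packed =
+//       ParallelTensor::FromTensorHandles(device, std::move(components),
+//                                         status);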
+
+} // namespace parallel_device
+} // namespace tensorflow
+
+#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h
index fdd2108..3f91722 100644
--- a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h
+++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h
@@ -165,7 +165,7 @@
TF_Status* status) {
TFE_CustomDevice device;
void* device_info;
- tensorflow::eager::AllocateParallelDevice(
+ tensorflow::parallel_device::AllocateParallelDevice(
device_name, underlying_devices.data(), underlying_devices.size(),
&device, &device_info);
TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
diff --git a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
index 62488cb..dd97d90 100644
--- a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
+++ b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc
@@ -52,7 +52,7 @@
tensorflow::Safe_PyObjectPtr device_capsule(
PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
void* device_info;
- tensorflow::eager::AllocateParallelDevice(
+ tensorflow::parallel_device::AllocateParallelDevice(
name, underlying_devices_c.data(), underlying_devices_c.size(),
device, &device_info);
if (PyErr_Occurred()) throw py::error_already_set();