blob: eec893e704d97df8fb0635af57497852fd819e5f [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include <memory>
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/c/tf_status.h"
namespace tensorflow {
namespace parallel_device {
namespace {
// Deleter functor that lets a std::unique_ptr own a TFE_Op.
struct OpDeleter {
  void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
};

// Owning handle to an eager op; TFE_DeleteOp runs when it goes out of scope.
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

// An owned result of running an operation on the parallel device: either a
// ParallelTensor (one component per underlying device) or a plain
// TFE_TensorHandle (e.g. one of the per-device outputs produced when
// TPUReplicatedOutput unpacks a ParallelTensor).
using MaybeParallelTensorOwned =
    absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
// A ParallelDevice on its own is not registered with a TFE_Context, and so has
// no device name (e.g. for `tf.device`). `NamedParallelDevice` associates a
// name with it, which lets us pack its `ParallelTensor`s into TFE_TensorHandles
// placed on the parallel device.
// Pairs a ParallelDevice with the device name it is registered under, so its
// ParallelTensors can be wrapped in TFE_TensorHandles placed on that device.
class NamedParallelDevice {
 public:
  NamedParallelDevice(const std::string& name,
                      std::unique_ptr<ParallelDevice> parallel_device)
      : name_(name), device_(std::move(parallel_device)) {}

  // Name of the custom device (the string used with e.g. `tf.device`).
  const std::string& name() const { return name_; }

  // The wrapped ParallelDevice; dereferences the owned pointer, so it must
  // have been constructed with a non-null device.
  const ParallelDevice& device() const { return *device_; }

 private:
  std::string name_;
  std::unique_ptr<ParallelDevice> device_;
};
// Runs `operation_name` on `parallel_device`, special-casing the operations
// that move tensors onto and off of it:
//
//   - "TPUReplicatedInput" packs one non-parallel TFE_TensorHandle per
//     underlying device into a single ParallelTensor.
//   - "TPUReplicatedOutput" unpacks one ParallelTensor into one
//     TFE_TensorHandle per underlying device.
//   - "DeviceID" returns the ParallelTensor from `ParallelDevice::DeviceIDs`.
//
// Every other operation is forwarded to `ParallelDevice::Execute` and each of
// its ParallelTensor results is returned.
//
// On failure an empty optional is returned and the error is reported through
// `status`.
absl::optional<std::vector<MaybeParallelTensorOwned>> ExecuteWithSpecialOps(
    const ParallelDevice& parallel_device,
    const std::string& parallel_device_name, TFE_Context* context,
    std::vector<MaybeParallelTensorUnowned> inputs, const char* operation_name,
    const TFE_OpAttrs* attributes, int expected_max_outputs,
    TF_Status* status) {
  absl::optional<std::vector<MaybeParallelTensorOwned>> result;
  // TODO(allenl): We should remove "TPU" from these op names at the very least,
  // or consider other ways of packing/unpacking parallel tensors.
  if (operation_name == std::string("TPUReplicatedInput")) {
    // Special-cased operation for packing per-device tensors into one parallel
    // tensor.
    if (inputs.size() != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " inputs to TPUReplicatedInput, but got ", inputs.size()));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    std::vector<TensorHandlePtr> components;
    components.reserve(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (absl::holds_alternative<ParallelTensor*>(inputs[i])) {
        std::string message(absl::StrCat(
            "Expected all inputs to TPUReplicatedInput to be non-parallel "
            "TensorHandles. The input ",
            i,
            " was a parallel tensor (already "
            "placed on the parallel device)."));
        TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
        return result;
      }
      components.emplace_back(TFE_TensorHandleCopySharingTensor(
          absl::get<TFE_TensorHandle*>(inputs[i]), status));
      // Stop at the first failed copy instead of packing a null component.
      if (TF_GetCode(status) != TF_OK) return result;
    }
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(ParallelTensor::FromTensorHandles(
        parallel_device, std::move(components), status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  } else if (operation_name == std::string("TPUReplicatedOutput")) {
    // Special-cased operation for un-packing one parallel tensor into
    // per-device tensors.
    OpPtr op(TFE_NewOp(context, operation_name, status));
    // Do not hand TFE_OpAddAttrs a null op if creation failed.
    if (TF_GetCode(status) != TF_OK) return result;
    TFE_OpAddAttrs(op.get(), attributes);
    int expected_outputs = TFE_OpGetOutputLength(op.get(), "outputs", status);
    if (TF_GetCode(status) != TF_OK) return result;
    if (expected_outputs != parallel_device.num_underlying_devices()) {
      std::string message(absl::StrCat(
          "The parallel device ", parallel_device_name, " expected ",
          parallel_device.num_underlying_devices(),
          " outputs for TPUReplicatedOutput, but got ", expected_outputs));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    // Only `inputs[0]` is consumed below; reject any other arity rather than
    // reading past the end of an empty vector.
    if (inputs.size() != 1) {
      std::string message(absl::StrCat(
          "Expected exactly one input to TPUReplicatedOutput, but got ",
          inputs.size()));
      TF_SetStatus(status, TF_INVALID_ARGUMENT, message.c_str());
      return result;
    }
    if (absl::holds_alternative<TFE_TensorHandle*>(inputs[0])) {
      TF_SetStatus(status, TF_INVALID_ARGUMENT,
                   "Expected the input to "
                   "TPUReplicatedOutput to be a parallel tensor (placed on the "
                   "parallel device).");
      return result;
    }
    ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]);
    std::vector<MaybeParallelTensorOwned> outputs;
    outputs.reserve(t->num_tensors());
    for (int i = 0; i < t->num_tensors(); ++i) {
      TensorHandlePtr this_output(
          TFE_TensorHandleCopySharingTensor(t->tensor(i), status));
      outputs.emplace_back(std::move(this_output));
      if (TF_GetCode(status) != TF_OK) return result;
    }
    result.emplace(std::move(outputs));
    return result;
  } else if (operation_name == std::string("DeviceID")) {
    std::vector<MaybeParallelTensorOwned> result_content;
    result_content.reserve(1);
    result_content.push_back(parallel_device.DeviceIDs(context, status));
    if (TF_GetCode(status) != TF_OK) return result;
    result.emplace(std::move(result_content));
    return result;
  }
  // Generic path: run the op once per underlying device and wrap each
  // ParallelTensor result.
  absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
      maybe_parallel_results(
          parallel_device.Execute(context, std::move(inputs), operation_name,
                                  attributes, expected_max_outputs, status));
  if (!maybe_parallel_results.has_value()) return result;
  std::vector<std::unique_ptr<ParallelTensor>> parallel_results(
      std::move(maybe_parallel_results.value()));
  std::vector<MaybeParallelTensorOwned> result_content;
  result_content.reserve(parallel_results.size());
  for (std::unique_ptr<ParallelTensor>& parallel_result : parallel_results) {
    result_content.push_back(
        MaybeParallelTensorOwned(std::move(parallel_result)));
  }
  result.emplace(std::move(result_content));
  return result;
}
// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how
// ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their
// reference counts drop to zero.
void ParallelTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<ParallelTensor*>(data);
}
// Wraps `t` in a TFE_TensorHandle placed on the device named
// `parallel_device_name`. The handle's opaque "device memory" is the
// ParallelTensor itself; ParallelTensorDeallocator deletes it when the handle
// is destroyed. Errors are reported through `status`.
TensorHandlePtr ParallelTensorToTensorHandle(
    const std::string& parallel_device_name, TFE_Context* context,
    std::unique_ptr<ParallelTensor> t, TF_Status* status) {
  const auto dtype = t->dtype();
  const std::vector<int64_t>& dims = t->shape();
  // Ownership moves into the handle. NOTE(review): if handle creation fails,
  // this relies on TFE_NewTensorHandleFromDeviceMemory invoking the
  // deallocator itself; confirm against the C API implementation.
  ParallelTensor* const payload = t.release();
  TFE_TensorHandle* handle = TFE_NewTensorHandleFromDeviceMemory(
      context, parallel_device_name.c_str(), dtype, dims.data(), dims.size(),
      payload, /*len=*/1, &ParallelTensorDeallocator,
      /*deallocator_arg=*/nullptr, status);
  return TensorHandlePtr(handle);
}
// For TFE_CustomDevice::copy_tensor_to_device in the parallel device
// registration.
//
// Replicates a single TFE_TensorHandle, producing a TFE_TensorHandle containing
// a ParallelTensor with one copy of `tensor` for each device in the
// ParallelDevice.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in as an untyped void*. It must always point to the
// NamedParallelDevice created by AllocateParallelDevice.
TFE_TensorHandle* CopyToParallelDevice(TFE_Context* context,
                                       TFE_TensorHandle* tensor,
                                       TF_Status* status, void* device_info) {
  auto* named_device = static_cast<NamedParallelDevice*>(device_info);
  // Replicate `tensor` onto every underlying device.
  std::unique_ptr<ParallelTensor> replicated(
      named_device->device().CopyToParallelDevice(context, tensor, status));
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }
  // Hand the caller a raw handle it now owns.
  TensorHandlePtr handle = ParallelTensorToTensorHandle(
      named_device->name(), context, std::move(replicated), status);
  return handle.release();
}
// For TFE_CustomDevice::copy_tensor_from_device in the parallel device
// registration.
//
// Currently this is an error, and un-packing ParallelTensors must be performed
// explicitly by running a TPUReplicatedOutput operation on the parallel device.
//
// TODO(allenl): There are some use-cases that are only supported by copying to
// host at the moment (e.g. debug print on a tensor, .numpy(), etc.). We either
// need to return something here or address these use-cases one by one.
TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
                                               TFE_TensorHandle* tensor,
                                               const char* target_device_name,
                                               TF_Status* status,
                                               void* device_info) {
  // Unconditionally fails: implicit copies off the parallel device are not
  // supported. All parameters other than `status` are ignored.
  static constexpr char kCopyOutError[] =
      "Trying to copy a tensor out of a parallel device.";
  TF_SetStatus(status, TF_INTERNAL, kCopyOutError);
  return nullptr;
}
// For TFE_CustomDevice::execute in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in as an untyped void*. It must always point to the
// NamedParallelDevice created by AllocateParallelDevice.
// Inputs already placed on this device are unwrapped to their ParallelTensors;
// all other inputs are passed through as plain TFE_TensorHandles. On entry
// `*num_outputs` is the capacity of `outputs`. On success `outputs` receives
// new handles owned by the caller and `*num_outputs` is set to how many were
// produced. On failure nothing is written to `outputs` and `*num_outputs` is
// left unchanged.
void ParallelDeviceExecute(TFE_Context* context, int num_inputs,
                           TFE_TensorHandle** inputs,
                           const char* operation_name,
                           const TFE_OpAttrs* attributes, int* num_outputs,
                           TFE_TensorHandle** outputs, TF_Status* status,
                           void* device_info) {
  NamedParallelDevice* named_device =
      static_cast<NamedParallelDevice*>(device_info);
  std::vector<MaybeParallelTensorUnowned> typed_inputs;
  typed_inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    const char* tensor_handle_device =
        TFE_TensorHandleDeviceName(inputs[i], status);
    if (TF_GetCode(status) != TF_OK) return;
    if (named_device->name() == tensor_handle_device) {
      // We assume that any tensors already placed on this device are
      // ParallelTensors.
      typed_inputs.emplace_back(static_cast<ParallelTensor*>(
          TFE_TensorHandleDevicePointer(inputs[i], status)));
      if (TF_GetCode(status) != TF_OK) return;
    } else {
      typed_inputs.emplace_back(inputs[i]);
    }
  }

  absl::optional<std::vector<MaybeParallelTensorOwned>> maybe_typed_outputs(
      ExecuteWithSpecialOps(named_device->device(), named_device->name(),
                            context, std::move(typed_inputs), operation_name,
                            attributes, *num_outputs, status));
  if (TF_GetCode(status) != TF_OK) return;
  if (!maybe_typed_outputs.has_value()) {
    TF_SetStatus(status, TF_INTERNAL, "OK status but no value was returned.");
    return;
  }

  std::vector<MaybeParallelTensorOwned> typed_outputs(
      std::move(maybe_typed_outputs.value()));
  if (typed_outputs.size() > static_cast<size_t>(*num_outputs)) {
    TF_SetStatus(status, TF_INTERNAL,
                 "The allocated output buffer was too small.");
    return;
  }

  // Convert every result to an owned handle before publishing any of them, so
  // a failure part-way through frees the handles already created instead of
  // leaking them into `outputs`.
  std::vector<TensorHandlePtr> handles;
  handles.reserve(typed_outputs.size());
  for (MaybeParallelTensorOwned& typed_output : typed_outputs) {
    if (absl::holds_alternative<TensorHandlePtr>(typed_output)) {
      handles.push_back(std::move(absl::get<TensorHandlePtr>(typed_output)));
    } else {
      handles.push_back(ParallelTensorToTensorHandle(
          named_device->name(), context,
          std::move(absl::get<std::unique_ptr<ParallelTensor>>(typed_output)),
          status));
      if (TF_GetCode(status) != TF_OK) return;
    }
  }

  for (size_t i = 0; i < handles.size(); ++i) {
    outputs[i] = handles[i].release();
  }
  *num_outputs = static_cast<int>(handles.size());
}
// For TFE_CustomDevice::delete_device in the parallel device registration.
//
// Since this function is used to satisfy the TFE_CustomDevice C API,
// device_info is passed in as an untyped void*. It must always point to the
// NamedParallelDevice created by AllocateParallelDevice.
void DeleteParallelDevice(void* device_info) {
delete reinterpret_cast<NamedParallelDevice*>(device_info);
}
} // namespace
void AllocateParallelDevice(const char* device_name,
const char* const* underlying_devices,
int num_underlying_devices,
TFE_CustomDevice* device, void** device_info) {
device->copy_tensor_to_device = &CopyToParallelDevice;
device->copy_tensor_from_device = &CopyTensorFromParallelDevice;
device->delete_device = &DeleteParallelDevice;
device->execute = &ParallelDeviceExecute;
std::vector<std::string> underlying_devices_vector;
underlying_devices_vector.reserve(num_underlying_devices);
for (int device_index = 0; device_index < num_underlying_devices;
++device_index) {
underlying_devices_vector.push_back(underlying_devices[device_index]);
}
std::unique_ptr<ParallelDevice> parallel_device(
new ParallelDevice(underlying_devices_vector));
*device_info =
new NamedParallelDevice{device_name, std::move(parallel_device)};
}
} // namespace parallel_device
} // namespace tensorflow