blob: 6cec9ea849bd87655a31b48dd3826711ab3e8048 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h"
#include <memory>
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/python/semaphore.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace xla {
std::string TpuDevice::DebugString() const {
return absl::StrCat("TPU_", id());
}
static std::shared_ptr<Device> MakeDevice(const std::string& platform_name,
int id, int local_device_ordinal) {
CHECK_EQ(platform_name, "tpu");
CHECK_EQ(id, local_device_ordinal); // Every device must be local for now.
return std::make_shared<TpuDevice>(id, local_device_ordinal);
}
StatusOr<std::shared_ptr<PyTpuClient>> PyTpuClient::Get(
const std::string& worker) {
tpu_driver::TpuDriverConfig driver_config;
driver_config.set_worker(worker);
auto client_status = tpu_driver::TpuDriverRegistry::Open(driver_config);
if (!client_status.ok()) {
return client_status.status();
}
auto client = client_status.ConsumeValueOrDie();
tpu_driver::SystemInfo system_info;
client->QuerySystemInfo(&system_info);
int num_cores =
system_info.tpu_chip_size() * system_info.tpu_chip(0).core_size();
std::vector<std::shared_ptr<Device>> devices;
CHECK_GE(num_cores, 1);
LOG(INFO) << "Creating " << num_cores << " TPU device(s).";
devices.reserve(num_cores);
for (int i = 0; i < num_cores; ++i) {
devices.push_back(MakeDevice("tpu", i, i));
}
return std::make_shared<PyTpuClient>("tpu", std::move(client),
std::move(devices), /*host_id=*/0);
}
PyTpuClient::PyTpuClient(std::string platform_name,
std::unique_ptr<tpu_driver::TpuDriver> driver,
std::vector<std::shared_ptr<Device>> devices,
int host_id)
: platform_name_(std::move(platform_name)),
driver_(std::move(driver)),
devices_(std::move(devices)),
host_id_(host_id) {
local_devices_.resize(devices_.size());
for (const std::shared_ptr<Device>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device}).second)
<< "Duplicate device id: " << device->id();
if (device->local_device_ordinal() != -1) {
int idx = device->local_device_ordinal();
CHECK(local_devices_[idx] == nullptr) << idx;
CHECK_LT(idx, local_devices_.size());
local_devices_[idx] = device;
}
}
for (int idx = 0; idx < local_devices_.size(); ++idx) {
CHECK(local_devices_[idx] != nullptr) << idx;
}
// TODO(frankchn): Check if thread pool size needs to be adjusted (perhaps
// something like min(cores, devices_.size()) might be appropriate depending
// on the number of devices.
pool_ = std::make_unique<tensorflow::thread::ThreadPool>(
tensorflow::Env::Default(), "PyTpuClient", devices_.size());
}
Status PyTpuClient::TransferToInfeed(const LiteralSlice& literal,
int device_ordinal) {
return Unimplemented("Infeed not implemented.");
}
StatusOr<Literal> PyTpuClient::TransferFromOutfeed(const Shape& shape,
int device_ordinal) {
return Unimplemented("Outfeed not implemented.");
}
StatusOr<DeviceAssignment> PyTpuClient::GetDefaultDeviceAssignment(
int num_replicas) const {
// Copied from xla::ComputationPlace::AssignDevices assuming computation_count
// = 1. Assign devices for each computation. Replicas are assigned to each
// device in order.
DeviceAssignment assignment(num_replicas, 1);
for (int replica = 0; replica < num_replicas; ++replica) {
assignment(replica, 0) = replica;
}
return std::move(assignment);
}
Status PyTpuClient::CheckDeviceOrdinal(int device_ordinal,
absl::string_view caller_name) {
if (device_ordinal < 0 || device_ordinal >= local_device_count()) {
return InvalidArgument(
"%s got bad device_ordinal: %d (num_local_devices=%d)", caller_name,
device_ordinal, local_device_count());
}
return Status::OK();
}
static Status CheckDataType(xla::PrimitiveType dtype) {
if (dtype == xla::PrimitiveType::F64 || dtype == xla::PrimitiveType::S64 ||
dtype == xla::PrimitiveType::U64) {
return InvalidArgument(
"64-bit data types are not yet supported on the TPU driver API. "
"Convert inputs to float32/int32 before using.");
}
return Status::OK();
}
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::FromLiterals(
std::vector<BorrowingLiteral> leaves, const Shape& tuple_shape,
std::shared_ptr<void> leaves_references,
std::shared_ptr<PyTpuClient> client, int device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals");
VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString()
<< " device ordinal: " << device_ordinal;
TF_RETURN_IF_ERROR(
client->CheckDeviceOrdinal(device_ordinal, "PyTpuBuffer::FromLiterals"));
tpu_driver::TpuDriver* driver = client->driver();
if (!tuple_shape.IsTuple()) {
TF_RET_CHECK(leaves.size() == 1);
return CreateBuffer(
tuple_shape,
[driver, &leaves, &tuple_shape,
leaves_references](tpu_driver::BufferHandle* handle) {
auto event =
driver->TransferToDevice(leaves[0].untyped_data(), handle, {});
event->AddCallback([leaves_references](Status) {});
return event;
},
std::move(client), device_ordinal);
}
std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
child_buffers.reserve(leaves.size());
std::vector<PyTpuBuffer*> child_buffer_ptrs;
child_buffer_ptrs.reserve(leaves.size());
auto it_leaf = leaves.begin();
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(tuple_shape)) {
TF_RET_CHECK(it_leaf != leaves.end());
auto& leaf = *it_leaf;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PyTpuBuffer> child_buffer,
CreateBuffer(
indexed_shape.shape,
[driver, &leaf, &indexed_shape](tpu_driver::BufferHandle* handle) {
return driver->TransferToDevice(leaf.untyped_data(), handle, {});
},
client, device_ordinal));
child_buffer_ptrs.push_back(child_buffer.get());
child_buffers.push_back(std::move(child_buffer));
++it_leaf;
}
TF_RET_CHECK(it_leaf == leaves.end());
// `MakeTuple` will extract and make the tuple buffer hold onto the
// `device_buffer_` contained in each `child_buffer`, so it's safe for
// `child_buffers` to get destroyed before this call returns.
return MakeTuple(std::move(child_buffer_ptrs), std::move(client),
device_ordinal);
}
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::MakeTuple(
const std::vector<PyTpuBuffer*> buffers,
std::shared_ptr<PyTpuClient> client, int device_ordinal) {
std::vector<Shape> child_shapes;
std::vector<std::shared_ptr<TpuSharedBuffer>> child_device_buffers;
std::vector<tpu_driver::BufferHandle*> child_handle_ptrs;
std::vector<std::shared_ptr<tpu_driver::Event>> child_events;
for (const auto& child_buffer : buffers) {
child_shapes.push_back(child_buffer->on_host_shape());
std::shared_ptr<TpuSharedBuffer> child_device_buffer =
child_buffer->DeviceBuffer();
// Merge all definition events from all children, so that anyone using this
// tuple must wait for all its children to finish receiving transfers.
// This works recursively up a nested tuple tree as well.
for (std::shared_ptr<tpu_driver::Event> child_event :
child_device_buffer->wait_for_use) {
child_events.push_back(std::move(child_event));
}
child_handle_ptrs.push_back(child_device_buffer->handle.get());
child_device_buffers.push_back(std::move(child_device_buffer));
}
Shape tuple_shape = ShapeUtil::MakeTupleShape(child_shapes);
std::unique_ptr<tpu_driver::BufferHandle> tuple_handle =
client->driver()->AllocateTuple(
device_ordinal, tpu_driver::MemoryRegion::HBM, child_handle_ptrs, {});
auto tuple_device_buffer = std::make_shared<TpuSharedBuffer>(
client->driver(), std::move(tuple_handle), std::move(child_events),
device_ordinal);
return absl::make_unique<PyTpuBuffer>(
tuple_shape, std::move(tuple_device_buffer),
std::move(child_device_buffers), std::move(client));
}
PyTpuBuffer::PyTpuBuffer(
Shape on_host_shape, std::shared_ptr<TpuSharedBuffer> device_buffer,
std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
std::shared_ptr<PyTpuClient> client)
: client_(std::move(client)),
on_host_shape_(std::move(on_host_shape)),
device_ordinal_(device_buffer->device_ordinal),
device_buffer_(std::move(device_buffer)),
child_buffers_(std::move(child_buffers)) {}
void PyTpuBuffer::Delete() {
absl::MutexLock lock(&mu_);
device_buffer_ = nullptr;
child_buffers_.clear();
host_value_ = nullptr;
}
Status PyTpuBuffer::CopyToHostAsync() {
std::vector<std::shared_ptr<tpu_driver::Event>> transfer_events;
std::shared_ptr<HostValue> host_value = std::make_shared<HostValue>();
{
absl::MutexLock lock(&mu_);
if (!device_buffer_) {
return InvalidArgument("CopyToHostAsync() called on invalid buffer.");
}
if (host_value_) {
// The host value has already been requested or is available.
return Status::OK();
}
host_value->value = std::make_shared<Literal>(on_host_shape_);
host_value->pending_ops = std::max(1ul, child_buffers_.size());
host_value_ = host_value;
std::vector<tpu_driver::Event*> events;
for (const auto& e : device_buffer_->wait_for_use) {
events.push_back(e.get());
}
VLOG(1) << "CopyToHostAsync:: host shape: "
<< host_value->value->shape().DebugString();
if (!on_host_shape_.IsTuple()) {
CHECK(child_buffers_.empty());
transfer_events.push_back(client_->driver()->TransferFromDevice(
device_buffer_->handle.get(), host_value->value->untyped_data(),
events));
} else {
for (int i = 0; i < child_buffers_.size(); ++i) {
auto& c = child_buffers_[i];
transfer_events.push_back(client_->driver()->TransferFromDevice(
c->handle.get(),
host_value->value->untyped_data(xla::ShapeIndex({i})), events));
}
}
}
for (auto& t : transfer_events) {
t->AddCallback([host_value](const xla::Status& status) {
VLOG(1) << "Device to host transfer finished.";
if (!status.ok()) {
host_value->status =
Status(static_cast<tensorflow::error::Code>(status.code()),
status.error_message());
}
absl::MutexLock m(&host_value->mutex);
--host_value->pending_ops;
if (host_value->pending_ops == 0) {
VLOG(1) << "Host value done: " << host_value->status;
host_value->ready.Notify();
}
});
}
return Status::OK();
}
StatusOr<std::shared_ptr<Literal>> PyTpuBuffer::ToLiteral() {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::ToLiteral");
TF_RETURN_IF_ERROR(CopyToHostAsync());
mu_.Lock();
std::shared_ptr<HostValue> host_value = host_value_;
mu_.Unlock();
VLOG(1) << "Waiting for device to host transfer " << host_value.get();
host_value->ready.WaitForNotification();
VLOG(1) << "Host copy finished, status:" << host_value->status;
TF_RETURN_IF_ERROR(host_value->status);
return host_value->value;
}
std::shared_ptr<TpuSharedBuffer> PyTpuBuffer::DeviceBuffer() const {
absl::MutexLock lock(&mu_);
return device_buffer_;
}
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
PyTpuBuffer::DestructureTuple() {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::DestructureTuple");
if (!on_host_shape_.IsTuple()) {
return InvalidArgument(
"Attemped to destructure a PyTpuBuffer that did not have a tuple "
"shape; shape: %s",
ShapeUtil::HumanString(on_host_shape_));
}
if (DeviceBuffer() == nullptr) {
return InvalidArgument("Attempted to destructure a deleted buffer");
}
absl::MutexLock lock(&mu_);
int num_children = ShapeUtil::TupleElementCount(on_host_shape_);
std::vector<std::unique_ptr<PyTpuBuffer>> results;
results.reserve(num_children);
for (int i = 0; i < num_children; ++i) {
results.push_back(absl::make_unique<PyTpuBuffer>(
on_host_shape_.tuple_shapes(i), child_buffers_.at(i),
std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_));
}
return results;
}
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CopyToDevice(
int dst_device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice");
if (on_host_shape_.IsTuple()) {
return Unimplemented("CopyToDevice for tuples is not supported.");
}
std::shared_ptr<TpuSharedBuffer> src_device_buffer = DeviceBuffer();
if (dst_device_ordinal == device_ordinal_) {
return absl::make_unique<PyTpuBuffer>(
on_host_shape_, src_device_buffer,
std::vector<std::shared_ptr<TpuSharedBuffer>>(), client_);
}
tpu_driver::TpuDriver* driver = client_->driver();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PyTpuBuffer> dst_buffer,
CreateBuffer(
on_host_shape_,
[driver, src_device_buffer](tpu_driver::BufferHandle* dst_handle) {
std::vector<tpu_driver::Event*> src_wait_for_use;
for (auto& event : src_device_buffer->wait_for_use) {
src_wait_for_use.push_back(event.get());
}
return driver->TransferFromDeviceToDevice(
src_device_buffer->handle.get(), dst_handle, src_wait_for_use);
},
client_, dst_device_ordinal));
// TODO(jiawenhao): This may be too pessimistic: it prevents future readers
// from reading `src_device_buffer` until the device-to-device copy is done.
// Should this go into a new `TpuSharedBuffer::wait_for_dealloc` field?
auto& wait_for_use = dst_buffer->DeviceBuffer()->wait_for_use;
src_device_buffer->wait_for_use.insert(src_device_buffer->wait_for_use.end(),
wait_for_use.begin(),
wait_for_use.end());
return dst_buffer;
}
Status PyTpuBuffer::BlockHostUntilReady() {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::BlockHostUntilReady");
return DeviceBuffer()->handle->OnReady()->Await();
}
/* static */
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::AllocateBuffer(
const Shape& shape, std::shared_ptr<PyTpuClient> client,
int device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer");
VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString()
<< " device ordinal: " << device_ordinal;
if (!shape.IsTuple()) {
return CreateBuffer(shape, absl::nullopt, std::move(client),
device_ordinal);
}
std::vector<std::unique_ptr<PyTpuBuffer>> child_buffers;
child_buffers.reserve(shape.tuple_shapes().size());
std::vector<PyTpuBuffer*> child_buffer_ptrs;
child_buffer_ptrs.reserve(shape.tuple_shapes().size());
for (const auto& child_shape : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<PyTpuBuffer> child_buffer,
AllocateBuffer(child_shape, client, device_ordinal));
child_buffer_ptrs.push_back(child_buffer.get());
child_buffers.push_back(std::move(child_buffer));
}
// `MakeTuple` will extract and make the tuple buffer hold onto the
// `device_buffer_` contained in each `child_buffer`, so it's safe for
// `child_buffers` to get destroyed before this call returns.
return PyTpuBuffer::MakeTuple(child_buffer_ptrs, std::move(client),
device_ordinal);
}
/*static*/
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuBuffer::CreateBuffer(
const Shape& non_tuple_shape, absl::optional<BufferInitializer> initializer,
std::shared_ptr<PyTpuClient> client, int device_ordinal) {
tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer");
VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: "
<< non_tuple_shape.DebugString()
<< " device ordinal: " << device_ordinal;
TF_RET_CHECK(!non_tuple_shape.IsTuple());
TF_RETURN_IF_ERROR(CheckDataType(non_tuple_shape.element_type()));
std::unique_ptr<tpu_driver::BufferHandle> handle =
client->driver()->Allocate(device_ordinal, tpu_driver::MemoryRegion::HBM,
non_tuple_shape.ToProto(), {});
// If this buffer needs to be initialized, anyone using this buffer must wait
// for the initialization event in `wait_for_use` to finish first.
std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
if (initializer.has_value()) {
std::shared_ptr<tpu_driver::Event> init = initializer.value()(handle.get());
wait_for_use.push_back(std::move(init));
}
auto device_buffer = std::make_shared<TpuSharedBuffer>(
client->driver(), std::move(handle), std::move(wait_for_use),
device_ordinal);
return absl::make_unique<PyTpuBuffer>(
non_tuple_shape, std::move(device_buffer),
std::vector<std::shared_ptr<TpuSharedBuffer>>(), client);
}
static std::shared_ptr<Device> LookupDevice(const PyTpuClient& client,
int device_id) {
auto it = client.id_to_device().find(device_id);
CHECK(it != client.id_to_device().end())
<< "Unknown device id: " << device_id;
return it->second;
}
PyTpuExecutable::PyTpuExecutable(
std::vector<std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables,
DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
xla::Shape result_shape)
: client_(std::move(client)),
executables_(std::move(executables)),
device_assignment_(std::move(device_assignment)),
result_shape_(std::move(result_shape)) {
const int num_replicas = device_assignment_.replica_count();
for (int replica = 0; replica < num_replicas; ++replica) {
const int device_id = device_assignment_(replica, 0);
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
if (device->host_id() != client_->host_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
}
local_replicas_.push_back(replica);
local_devices_.push_back(device);
}
CHECK_GE(local_replicas_.size(), 1);
CHECK_EQ(local_replicas_.size(), executables_.size());
}
PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper(
absl::Span<const std::vector<PyTpuBuffer*>> all_core_arguments,
absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
const RunId& run_id) {
const int device_id = device_assignment_(replica, 0);
std::shared_ptr<Device> device = LookupDevice(*client_, device_id);
CHECK_EQ(device->host_id(), client_->host_id());
int device_ordinal = device->local_device_ordinal();
tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute");
VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: " << device_ordinal;
std::unique_ptr<::xla::PyTpuBuffer> output_buffer =
::xla::PyTpuBuffer::AllocateBuffer(result_shape_, client_, device_ordinal)
.ValueOrDie();
VLOG(1) << "Created output buffer: " << result_shape_.DebugString();
std::vector<tpu_driver::BufferHandle*> inputs;
std::vector<tpu_driver::Event*> ready_to_execute;
std::shared_ptr<tpu_driver::Event> output_buffer_ready =
output_buffer->DeviceBuffer()->handle->OnReady();
ready_to_execute.push_back(output_buffer_ready.get());
for (auto* input_handle : this_core_arguments) {
inputs.push_back(input_handle->DeviceBuffer()->handle.get());
}
for (const auto& core_args : all_core_arguments) {
for (const auto* handle : core_args) {
for (auto pending_event : handle->DeviceBuffer()->wait_for_use) {
ready_to_execute.push_back(pending_event.get());
}
}
}
xla::DeviceAssignmentProto device_assignment;
CHECK(device_assignment_.Serialize(&device_assignment).ok());
std::shared_ptr<tpu_driver::Event> on_execute_finished =
client_->driver()->ExecuteProgram(
executables_[replica].get(), inputs,
{output_buffer->DeviceBuffer()->handle.get()}, device_assignment,
{ready_to_execute});
return {std::move(output_buffer), std::move(on_execute_finished)};
}
// Delay before warning about a slow execute call.
static const absl::Duration kWarnExecutionDelay = absl::Seconds(10);
// Delay before terminating a stalled execute call.
static const absl::Duration kMaxExecutionDelay = absl::Seconds(120);
Status WaitForExecuteEvent(tpu_driver::Event* event) {
absl::optional<Status> opt_status;
auto start_time = absl::Now();
while (!opt_status.has_value() &&
absl::Now() - start_time < kMaxExecutionDelay) {
opt_status = event->AwaitWithTimeout(kWarnExecutionDelay);
if (!opt_status.has_value()) {
LOG(WARNING)
<< "TPU Execute is taking a long time. This might be due to a "
"deadlock between multiple TPU cores or a very slow program.";
}
}
if (!opt_status.has_value()) {
return tensorflow::errors::DeadlineExceeded(
absl::StrFormat("TPU program took more than %d seconds to complete.",
absl::ToInt64Seconds(kMaxExecutionDelay)));
}
return opt_status.value();
}
StatusOr<std::unique_ptr<PyTpuBuffer>> PyTpuExecutable::Execute(
absl::Span<PyTpuBuffer* const> argument_handles) {
if (num_replicas() != 1) {
return InvalidArgument(
"Attempted to execute computation with %d replicas using Execute()",
num_replicas());
}
std::vector<PyTpuBuffer*> all_core_arguments(argument_handles.begin(),
argument_handles.end());
ExecuteResult result =
ExecuteHelper(absl::MakeSpan(&all_core_arguments, 1), argument_handles,
/*replica=*/0, RunId());
Status status = WaitForExecuteEvent(result.on_execute_finished.get());
if (!status.ok()) {
LOG(ERROR) << "Failed to execute program: " << status;
return status;
}
return std::move(result.buffer);
}
StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>>
PyTpuExecutable::ExecutePerReplica(
absl::Span<const std::vector<PyTpuBuffer*>> argument_handles) {
tensorflow::profiler::TraceMe traceme("PyTpuExecutable::ExecutePerReplica");
int num_local_replicas = local_replicas_.size();
const int num_local_devices = client_->local_device_count();
if (argument_handles.size() != num_local_replicas) {
return InvalidArgument(
"Attempted to execute with %d local replicas when local replica count "
"is %d (total replica count: %d)",
argument_handles.size(), num_local_replicas, num_replicas());
}
if (argument_handles.size() > num_local_devices) {
return InvalidArgument(
"Attempted to execute with %d replicas when device count is %d",
argument_handles.size(), num_local_devices);
}
VLOG(1) << "Executing replicated computation; num_replicas=" << num_replicas()
<< " num_local_replicas=" << num_local_replicas;
absl::Mutex results_lock;
std::vector<ExecuteResult> results(num_local_replicas);
auto* thread_pool = client_->GetThreadPool();
int failed = 0;
Status first_failure_status;
xla::Semaphore execute_semaphore(0);
for (int i = 0; i < num_local_replicas; ++i) {
// We are scheduling Execute on a thread pool as ExecuteHelper can take a
// long time and we want all cores to be scheduled in parallel.
thread_pool->Schedule([this, i, argument_handles, &results, &results_lock,
&execute_semaphore]() {
const int replica = local_replicas_[i];
RunId run_id;
auto result =
ExecuteHelper(argument_handles, argument_handles[i], replica, run_id);
results[i] = std::move(result);
execute_semaphore.Release(1);
});
}
execute_semaphore.Acquire(num_local_replicas);
for (int i = 0; i < num_local_replicas; ++i) {
auto s = WaitForExecuteEvent(results[i].on_execute_finished.get());
if (!s.ok()) {
if (failed == 0) {
first_failure_status = Status(
static_cast<tensorflow::error::Code>(s.code()), s.error_message());
}
++failed;
}
}
if (failed > 0) {
return first_failure_status;
}
VLOG(1) << "Replicated execution complete.";
std::vector<std::unique_ptr<PyTpuBuffer>> wrapped_results(num_local_replicas);
for (int i = 0; i < num_local_replicas; ++i) {
wrapped_results[i] = std::move(results[i].buffer);
}
return wrapped_results;
}
/*static*/ StatusOr<std::unique_ptr<PyTpuExecutable>> PyTpuExecutable::Compile(
const XlaComputation& computation,
absl::optional<std::vector<Shape>> argument_layouts,
const ExecutableBuildOptions* build_options,
std::shared_ptr<PyTpuClient> client,
absl::optional<DeviceAssignment> device_assignment) {
tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Compile");
VLOG(1) << "Compile: "
<< computation.GetProgramShape().ValueOrDie().DebugString();
// TODO(power) -- handle argument layouts
// TODO(power) -- handle build options
ExecutableBuildOptions options;
if (build_options != nullptr) {
options = *build_options;
}
if (device_assignment) {
if (device_assignment->replica_count() != options.num_replicas()) {
return InvalidArgument(
"Mismatched number of replicas for device "
"assignment and computation (%d vs %d).",
device_assignment->replica_count(), options.num_replicas());
} else if (device_assignment->computation_count() != 1) {
return Unimplemented(
"Only 1 computation per replica supported, %d requested.",
device_assignment->computation_count());
}
} else {
TF_ASSIGN_OR_RETURN(device_assignment, client->GetDefaultDeviceAssignment(
options.num_replicas()));
}
CHECK_GE(options.num_replicas(), 1);
CHECK_EQ(options.num_replicas(), device_assignment->replica_count());
CHECK(!argument_layouts.has_value());
// TODO(henrytan): an area for optimization with less buffer copy.
xla::HloProto hlo_proto;
*hlo_proto.mutable_hlo_module() = computation.proto();
// TODO(henrytan): in the future, we want to consider argument Layout
// information e.g. for linearization.
std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program =
client->driver()->CompileProgram(hlo_proto, options.num_replicas(), {});
::xla::Shape result_layout;
if (options.result_layout()) {
result_layout = *options.result_layout();
} else {
xla::ProgramShapeProto program_shape_proto;
auto fetch_metadata_status =
compiled_program->program_shape(&program_shape_proto);
if (!fetch_metadata_status.ok()) {
return Status(
static_cast<tensorflow::error::Code>(fetch_metadata_status.code()),
fetch_metadata_status.error_message());
}
result_layout = ::xla::Shape(program_shape_proto.result());
}
VLOG(1) << "Got result shape: " << result_layout.DebugString();
std::vector<std::unique_ptr<tpu_driver::LoadedProgramHandle>> loaded_programs;
loaded_programs.resize(options.num_replicas());
for (int replica = 0; replica < options.num_replicas(); ++replica) {
const int device_id = (*device_assignment)(replica, 0);
std::shared_ptr<Device> device = LookupDevice(*client, device_id);
CHECK_EQ(device->host_id(), client->host_id());
int device_ordinal = device->local_device_ordinal();
loaded_programs[replica] = client->driver()->LoadProgram(
device_ordinal, compiled_program.get(), {});
}
return absl::make_unique<PyTpuExecutable>(
std::move(loaded_programs), std::move(*device_assignment),
std::move(client), std::move(result_layout));
}
} // namespace xla