blob: 114014e4e132adf1f21f751a3132b393fb504ed7 [file] [log] [blame]
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <algorithm>
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/compiler/xla/pjrt/semaphore.h"
#include "tensorflow/compiler/xla/pjrt/worker_thread.h"
#include "tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tpu_driver {
namespace {
using xla::Status;
using xla::WorkerThread;
// Scheme prefix of the worker strings served by this driver. The remainder of
// the string is a comma-separated host list, e.g. "grpc+pod://host1,host2".
constexpr char kPodTpuDriverPrefix[] = "grpc+pod://";

// Defined below; the handle and event types only keep a pointer to it.
class PodTpuDriver;
// Event handed out by PodTpuDriver. It keeps no completion state of its own:
// every query is forwarded to the driver, keyed by the id of the scheduled
// operation this event stands for.
class PodEvent : public Event {
 public:
  explicit PodEvent(PodTpuDriver* owner, int64_t op_id)
      : driver_(owner), operation_id_(op_id) {}

  // Id of the scheduled operation tracked by this event.
  int64_t operation_id() const { return operation_id_; }

  // Blocks until the operation finishes and returns its status.
  xla::Status Await() override;

  // Like Await(), but yields absl::nullopt if `duration` elapses first.
  absl::optional<xla::Status> AwaitWithTimeout(
      absl::Duration duration) override;

  // Registers `callback` to be run with the operation's final status.
  void AddCallback(std::function<void(Status)> callback) override;

 private:
  PodTpuDriver* driver_;  // The driver that created this event.
  const int64_t operation_id_;
};
class CombinedEvent : public PodEvent {
public:
explicit CombinedEvent(PodTpuDriver* driver, int64_t operation_id,
std::vector<std::shared_ptr<Event>> events)
: PodEvent(driver, operation_id), events_(events) {}
xla::Status Await() override {
for (auto& event : events_) {
TF_RETURN_IF_ERROR(event->Await());
}
return Status::OK();
}
absl::optional<xla::Status> AwaitWithTimeout(
absl::Duration duration) override {
// TODO(frankchn): This might extend the timeout.
for (auto& event : events_) {
auto status = event->AwaitWithTimeout(duration);
if (status == absl::nullopt) {
return absl::nullopt;
} else {
TF_RETURN_IF_ERROR(status.value());
}
}
return Status::OK();
}
void AddCallback(std::function<void(Status)> callback) override {
// TODO(frankchn): This may return before every event is done.
events_[0]->AddCallback(std::move(callback));
}
private:
std::vector<std::shared_ptr<Event>> events_;
};
// BufferHandle handed out by PodTpuDriver. It records only the id of the
// operation that allocates the buffer, the pod-wide core it lives on and its
// size/shape; the real per-worker handle is kept by the driver, keyed by
// operation id.
class PodBufferHandle : public BufferHandle {
 public:
  explicit PodBufferHandle(PodTpuDriver* driver, int64_t operation_id,
                           int64_t size_in_bytes,
                           absl::optional<xla::ShapeProto> shape,
                           int64_t core_id)
      : driver_(driver),
        operation_id_(operation_id),
        size_in_bytes_(size_in_bytes),
        // `shape` is a by-value sink parameter: move the proto instead of
        // copying it a second time.
        shape_(std::move(shape)),
        event_(std::make_shared<PodEvent>(driver, operation_id)),
        core_id_(core_id) {}

  // Event that completes when the allocating operation completes.
  std::shared_ptr<Event> OnReady() override { return event_; }
  int64_t size_in_bytes() override { return size_in_bytes_; }
  absl::optional<xla::ShapeProto> shape() override { return shape_; }
  int64_t operation_id() const { return operation_id_; }
  int64_t core_id() const { return core_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  const int64_t size_in_bytes_;
  const absl::optional<xla::ShapeProto> shape_;
  std::shared_ptr<PodEvent> event_;
  const int64_t core_id_;
};
// CompiledProgramHandle handed out by PodTpuDriver::CompileProgram. It carries
// only the compile operation's id and an event tracking it; the per-worker
// compiled handles are kept by the driver.
class PodCompiledProgramHandle : public CompiledProgramHandle {
 public:
  explicit PodCompiledProgramHandle(PodTpuDriver* owner, int64_t op_id)
      : driver_(owner),
        operation_id_(op_id),
        event_(std::make_shared<PodEvent>(owner, op_id)) {}

  // Event that completes when the compile operation completes.
  std::shared_ptr<Event> OnReady() override { return event_; }

  // Fills `program_shape` by querying the driver (defined out of line).
  xla::Status program_shape(xla::ProgramShapeProto* program_shape) override;

  int64_t operation_id() const { return operation_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  std::shared_ptr<PodEvent> event_;
};
// LoadedProgramHandle handed out by PodTpuDriver::LoadProgram. It records the
// load operation's id and the pod-wide core the program was loaded onto; the
// per-worker loaded handle is kept by the driver.
class PodLoadedProgramHandle : public LoadedProgramHandle {
 public:
  explicit PodLoadedProgramHandle(PodTpuDriver* owner, int64_t op_id,
                                  int64_t core)
      : driver_(owner),
        operation_id_(op_id),
        core_id_(core),
        event_(std::make_shared<PodEvent>(owner, op_id)) {}

  // Event that completes when the load operation completes.
  std::shared_ptr<Event> OnReady() override { return event_; }
  int64_t operation_id() const { return operation_id_; }
  int64_t core_id() const { return core_id_; }

 private:
  PodTpuDriver* driver_;
  const int64_t operation_id_;
  const int64_t core_id_;
  std::shared_ptr<PodEvent> event_;
};
// Scheduler bookkeeping for one operation that has not finished yet.
struct EventInFlight {
  // Event returned by the per-worker driver once the operation is issued;
  // null while the operation still waits on dependencies.
  std::shared_ptr<Event> underlying_event;
  // Deferred request that issues the operation; reset to null once run.
  std::function<std::shared_ptr<Event>()> create_fn;
  // Ids of operations this one is still waiting for.
  absl::flat_hash_set<int64_t> incomplete_deps;
  // Callbacks to attach to `underlying_event` once it exists.
  std::vector<std::function<void(Status)>> callbacks;
};
// TpuDriver spanning a whole pod. The worker string
// "grpc+pod://host1,host2,..." is split into one gRPC TpuDriver per host, and
// the cores of all hosts are exposed as a single flat list of core ids.
//
// Every public operation gets a fresh operation id and is handed to
// ScheduleRequest(). It is issued to the owning per-host driver once all the
// operations it depends on (its `wait_for` events plus the buffers/programs
// it uses) have completed. EventCompleted(), running on `event_thread_`,
// propagates completion. Handles returned to callers carry only operation
// ids; the real per-host handles live in the `underlying_*` maps, keyed by id.
class PodTpuDriver : public TpuDriver {
 public:
  explicit PodTpuDriver(const TpuDriverConfig& config,
                        std::shared_ptr<::grpc::ChannelCredentials> creds)
      : config_(config),
        creds_(std::move(creds)),
        event_thread_(tensorflow::Env::Default(), "grpc_pod_event_thread") {
    std::vector<std::string> workers = absl::StrSplit(
        absl::StripPrefix(config_.worker(), kPodTpuDriverPrefix), ',');
    for (const auto& worker : workers) {
      TpuDriverConfig worker_config(config_);
      *(worker_config.mutable_worker()) = absl::StrCat("grpc://", worker);
      drivers_.push_back(
          CreateGrpcTpuDriver(worker_config, creds_).ConsumeValueOrDie());
    }
    for (int driver_num = 0; driver_num < static_cast<int>(drivers_.size());
         ++driver_num) {
      SystemInfo driver_info;
      drivers_[driver_num]->QuerySystemInfo(&driver_info);
      for (const auto& tpu_chip : driver_info.tpu_chip()) {
        *(pod_info_.add_tpu_chip()) = tpu_chip;
      }
      // Pod-wide core ids are assigned in host order; remember the driver
      // and the host-local core index each of them maps to.
      int core_num = 0;
      for (const auto& tpu_core : driver_info.local_core()) {
        *(pod_info_.add_local_core()) = tpu_core;
        core_to_driver_.push_back(drivers_[driver_num].get());
        core_to_driver_id_.push_back(driver_num);
        core_to_driver_core_.push_back(core_num++);
      }
      *(pod_info_.mutable_cpu()) = driver_info.cpu();
      pod_info_.set_host_count(pod_info_.host_count() + 1);
      pod_info_.set_chip_count(pod_info_.chip_count() +
                               driver_info.chip_count());
      pod_info_.set_core_count(pod_info_.core_count() +
                               driver_info.core_count());
    }
    pod_info_.set_host_id(0);
  }

  ~PodTpuDriver() override {
    // TODO(frankchn): Unload all handles, and wait for all events to finish.
  }

  void QuerySystemInfo(SystemInfo* system_info) override {
    *system_info = pod_info_;
  }

  xla::Status Reset() override {
    for (auto& driver : drivers_) {
      TF_RETURN_IF_ERROR(driver->Reset());
    }
    return xla::Status::OK();
  }

  std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, int64_t num_bytes,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    ScheduleRequest(
        operation_id,
        [this, core_id, region, num_bytes, operation_id]() {
          absl::MutexLock l(&mu_);
          auto inserted = underlying_buffers_.insert(
              {operation_id,
               core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id],
                                                  region, num_bytes, {})});
          return inserted.first->second->OnReady();
        },
        deps);
    return absl::make_unique<PodBufferHandle>(this, operation_id, num_bytes,
                                              absl::nullopt, core_id);
  }

  std::unique_ptr<BufferHandle> Allocate(
      int32_t core_id, MemoryRegion region, const xla::ShapeProto& shape,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    ScheduleRequest(
        operation_id,
        [this, core_id, region, shape, operation_id]() {
          absl::MutexLock l(&mu_);
          auto inserted = underlying_buffers_.insert(
              {operation_id,
               core_to_driver_[core_id]->Allocate(core_to_driver_core_[core_id],
                                                  region, shape, {})});
          return inserted.first->second->OnReady();
        },
        deps);
    return absl::make_unique<PodBufferHandle>(
        this, operation_id, ComputeBytesFromShape(shape), shape, core_id);
  }

  std::unique_ptr<BufferHandle> AllocateTuple(
      int32_t core_id, MemoryRegion region,
      absl::Span<BufferHandle* const> children,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    std::vector<int64_t> children_ids;
    children_ids.reserve(children.size());
    for (BufferHandle* child : children) {
      auto child_op_id = static_cast<PodBufferHandle*>(child)->operation_id();
      deps.insert(child_op_id);
      children_ids.push_back(child_op_id);
    }
    ScheduleRequest(
        operation_id,
        [this, core_id, region, children_ids, operation_id]() {
          absl::MutexLock l(&mu_);
          std::vector<BufferHandle*> child_buffers;
          child_buffers.reserve(children_ids.size());
          for (int64_t child_id : children_ids) {
            child_buffers.push_back(underlying_buffers_[child_id].get());
          }
          auto inserted = underlying_buffers_.insert(
              {operation_id,
               core_to_driver_[core_id]->AllocateTuple(
                   core_to_driver_core_[core_id], region, child_buffers, {})});
          return inserted.first->second->OnReady();
        },
        deps);
    return absl::make_unique<PodBufferHandle>(this, operation_id, 0,
                                              absl::nullopt, core_id);
  }

  std::shared_ptr<Event> Deallocate(
      std::unique_ptr<BufferHandle> handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto* buffer = static_cast<PodBufferHandle*>(handle.get());
    auto op_id = buffer->operation_id();
    auto core_id = buffer->core_id();
    deps.insert(op_id);
    ScheduleRequest(
        operation_id,
        [this, op_id, core_id]() {
          absl::MutexLock l(&mu_);
          auto buf_iter = underlying_buffers_.find(op_id);
          auto underlying_hn = std::move(buf_iter->second);
          underlying_buffers_.erase(buf_iter);
          return core_to_driver_[core_id]->Deallocate(std::move(underlying_hn),
                                                      {});
        },
        deps);
    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> TransferToDevice(
      const void* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto* dst_buffer = static_cast<PodBufferHandle*>(dst);
    auto op_id = dst_buffer->operation_id();
    auto core_id = dst_buffer->core_id();
    deps.insert(op_id);
    ScheduleRequest(
        operation_id,
        [this, src, op_id, core_id]() {
          absl::MutexLock l(&mu_);
          auto buf_iter = underlying_buffers_.find(op_id);
          return core_to_driver_[core_id]->TransferToDevice(
              src, buf_iter->second.get(), {});
        },
        deps);
    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> TransferFromDevice(
      const BufferHandle* src, void* dst,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    const auto* src_buffer = static_cast<const PodBufferHandle*>(src);
    auto op_id = src_buffer->operation_id();
    auto core_id = src_buffer->core_id();
    deps.insert(op_id);
    ScheduleRequest(
        operation_id,
        [this, dst, op_id, core_id]() {
          absl::MutexLock l(&mu_);
          auto buf_iter = underlying_buffers_.find(op_id);
          return core_to_driver_[core_id]->TransferFromDevice(
              buf_iter->second.get(), dst, {});
        },
        deps);
    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> TransferFromDeviceToDevice(
      const BufferHandle* src, BufferHandle* dst,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto src_op_id = static_cast<const PodBufferHandle*>(src)->operation_id();
    auto dst_op_id = static_cast<PodBufferHandle*>(dst)->operation_id();
    auto core_id = static_cast<PodBufferHandle*>(dst)->core_id();
    deps.insert(src_op_id);
    deps.insert(dst_op_id);
    ScheduleRequest(
        operation_id,
        [this, src_op_id, dst_op_id, core_id]() {
          absl::MutexLock l(&mu_);
          auto src_iter = underlying_buffers_.find(src_op_id);
          auto dst_iter = underlying_buffers_.find(dst_op_id);
          return core_to_driver_[core_id]->TransferFromDeviceToDevice(
              src_iter->second.get(), dst_iter->second.get(), {});
        },
        deps);
    return std::make_shared<PodEvent>(this, operation_id);
  }

  // Compiles `source` on every host driver; the resulting event completes
  // once all hosts have finished compiling.
  std::unique_ptr<CompiledProgramHandle> CompileProgram(
      const xla::HloProto& source, int32_t num_replicas,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    ScheduleRequest(
        operation_id,
        [this, operation_id, source, num_replicas]() {
          absl::MutexLock l(&mu_);
          auto cph_iterator =
              underlying_cph_
                  .insert(
                      {operation_id,
                       std::vector<std::unique_ptr<CompiledProgramHandle>>()})
                  .first;
          std::vector<std::shared_ptr<Event>> collected_events;
          collected_events.reserve(drivers_.size());
          for (auto& driver : drivers_) {
            cph_iterator->second.push_back(
                driver->CompileProgram(source, num_replicas, {}));
            collected_events.push_back(cph_iterator->second.back()->OnReady());
          }
          return std::make_shared<CombinedEvent>(this, operation_id,
                                                 std::move(collected_events));
        },
        deps);
    return absl::make_unique<PodCompiledProgramHandle>(this, operation_id);
  }

  // Loads the compiled program onto pod-wide core `core_id`, using the
  // compiled handle produced by that core's host driver.
  std::unique_ptr<LoadedProgramHandle> LoadProgram(
      int32_t core_id, const CompiledProgramHandle* handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto cph_op_id =
        static_cast<const PodCompiledProgramHandle*>(handle)->operation_id();
    deps.insert(cph_op_id);
    ScheduleRequest(
        operation_id,
        [this, operation_id, cph_op_id, core_id]() {
          absl::MutexLock l(&mu_);
          auto cph_iter = underlying_cph_.find(cph_op_id);
          auto inserted = underlying_lph_.insert(
              {operation_id,
               core_to_driver_[core_id]->LoadProgram(
                   core_to_driver_core_[core_id],
                   cph_iter->second[core_to_driver_id_[core_id]].get(), {})});
          return inserted.first->second->OnReady();
        },
        deps);
    return absl::make_unique<PodLoadedProgramHandle>(this, operation_id,
                                                     core_id);
  }

  std::shared_ptr<Event> UnloadProgram(
      std::unique_ptr<LoadedProgramHandle> handle,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto* program = static_cast<PodLoadedProgramHandle*>(handle.get());
    auto op_id = program->operation_id();
    auto core_id = program->core_id();
    deps.insert(op_id);
    ScheduleRequest(
        operation_id,
        [this, op_id, core_id]() {
          absl::MutexLock l(&mu_);
          auto lph_iter = underlying_lph_.find(op_id);
          auto event = core_to_driver_[core_id]->UnloadProgram(
              std::move(lph_iter->second), {});
          underlying_lph_.erase(lph_iter);
          return event;
        },
        deps);
    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::shared_ptr<Event> ExecuteProgram(
      LoadedProgramHandle* program, absl::Span<BufferHandle* const> inputs,
      absl::Span<BufferHandle* const> outputs,
      const xla::DeviceAssignmentProto& device_assignment,
      absl::Span<Event* const> wait_for) override {
    int64_t operation_id = GetOperationId();
    auto deps = GetDependencyOperationIds(wait_for);
    auto* loaded = static_cast<PodLoadedProgramHandle*>(program);
    auto op_id = loaded->operation_id();
    auto core_id = loaded->core_id();
    deps.insert(op_id);
    std::vector<int64_t> input_op_ids;
    std::vector<int64_t> output_op_ids;
    input_op_ids.reserve(inputs.size());
    output_op_ids.reserve(outputs.size());
    for (auto* input : inputs) {
      auto input_dep = static_cast<PodBufferHandle*>(input)->operation_id();
      input_op_ids.push_back(input_dep);
      deps.insert(input_dep);
    }
    for (auto* output : outputs) {
      auto output_dep = static_cast<PodBufferHandle*>(output)->operation_id();
      output_op_ids.push_back(output_dep);
      deps.insert(output_dep);
    }
    ScheduleRequest(
        operation_id,
        [this, core_id, op_id, input_op_ids, output_op_ids,
         device_assignment]() {
          absl::MutexLock l(&mu_);
          std::vector<BufferHandle*> underlying_inputs;
          std::vector<BufferHandle*> underlying_outputs;
          underlying_inputs.reserve(input_op_ids.size());
          for (auto input_op_id : input_op_ids) {
            underlying_inputs.push_back(underlying_buffers_[input_op_id].get());
          }
          underlying_outputs.reserve(output_op_ids.size());
          for (auto output_op_id : output_op_ids) {
            underlying_outputs.push_back(
                underlying_buffers_[output_op_id].get());
          }
          LoadedProgramHandle* handle = underlying_lph_[op_id].get();
          return core_to_driver_[core_id]->ExecuteProgram(
              handle, underlying_inputs, underlying_outputs, device_assignment,
              {});
        },
        deps);
    return std::make_shared<PodEvent>(this, operation_id);
  }

  std::unique_ptr<TpuLinearizer> GetLinearizer() override {
    return drivers_[0]->GetLinearizer();
  }

  // Helper methods for Event scheduling

  // Waits for operation `event_id` to finish and returns its status, or
  // absl::nullopt if `duration` elapses first. A single deadline covers both
  // phases (waiting for the operation to be issued, then for it to finish).
  absl::optional<Status> WaitForEvent(int64_t event_id,
                                      absl::Duration duration) {
    const absl::Time deadline = absl::Now() + duration;
    std::shared_ptr<Event> underlying_event;
    {
      absl::MutexLock l(&event_mu_);
      // Ready once the operation has been issued (underlying_event set) or
      // has already finished (erased from events_). find() is used so that
      // evaluating the condition never re-inserts an erased entry.
      auto issued_or_done = [this, event_id]() {
        event_mu_.AssertHeld();
        auto it = events_.find(event_id);
        return it == events_.end() || it->second.underlying_event != nullptr;
      };
      if (!event_mu_.AwaitWithDeadline(absl::Condition(&issued_or_done),
                                       deadline)) {
        return absl::nullopt;
      }
      auto event = events_.find(event_id);
      if (event == events_.end()) {
        return CompletedEventStatus(event_id);
      }
      underlying_event = event->second.underlying_event;
    }
    // Wait for the underlying event without holding event_mu_, or else
    // incoming events will not be processed.
    return underlying_event->AwaitWithTimeout(
        std::max(deadline - absl::Now(), absl::ZeroDuration()));
  }

  // Arranges for `fn` to run with the final status of operation `event_id`.
  // If the operation has already finished, `fn` runs on the calling thread.
  void AddCallbackForEvent(int64_t event_id, std::function<void(Status)> fn) {
    std::shared_ptr<Event> underlying_event;
    Status completed_status;
    {
      absl::MutexLock l(&event_mu_);
      auto event = events_.find(event_id);
      if (event == events_.end()) {
        // Already finished: report its recorded status below.
        completed_status = CompletedEventStatus(event_id);
      } else if (event->second.underlying_event == nullptr) {
        // Not issued yet: EventCompleted() attaches queued callbacks to the
        // underlying event once the operation is issued.
        event->second.callbacks.push_back(std::move(fn));
        return;
      } else {
        underlying_event = event->second.underlying_event;
      }
    }
    // Attach/invoke outside event_mu_ so a callback that re-enters the driver
    // cannot deadlock on it.
    if (underlying_event != nullptr) {
      underlying_event->AddCallback(std::move(fn));
    } else {
      fn(completed_status);
    }
  }

  // Blocks until the CompileProgram operation `op_id` has stored its
  // per-host handles, then returns the program shape from the first one.
  xla::Status GetCompiledProgramShape(int64_t op_id,
                                      xla::ProgramShapeProto* program_shape) {
    CompiledProgramHandle* handle = nullptr;
    {
      absl::MutexLock l(&mu_);
      auto compiled = [this, op_id]() {
        mu_.AssertHeld();
        return underlying_cph_.contains(op_id);
      };
      mu_.Await(absl::Condition(&compiled));
      auto& handles = underlying_cph_.find(op_id)->second;
      if (handles.empty()) {
        return tensorflow::errors::FailedPrecondition(
            "No compiled program handle for operation ", op_id);
      }
      // underlying_cph_ entries are never erased and each handle is
      // heap-allocated, so the raw pointer stays valid after unlocking.
      handle = handles.front().get();
    }
    // Query without holding mu_: the call goes to the worker handle and may
    // block (presumably until compilation finishes), which would otherwise
    // stall every other scheduled request.
    return handle->program_shape(program_shape);
  }

 private:
  // Stored by value: the caller's config need not outlive this driver.
  const TpuDriverConfig config_;
  std::shared_ptr<::grpc::ChannelCredentials> creds_;
  std::vector<std::unique_ptr<TpuDriver>> drivers_;
  // Indexed by pod-wide core id.
  std::vector<int32_t> core_to_driver_id_;
  std::vector<TpuDriver*> core_to_driver_;
  std::vector<int32_t> core_to_driver_core_;
  SystemInfo pod_info_;
  // Lock order: event_mu_ may be held while acquiring mu_ (create_fn runs
  // under event_mu_ and locks mu_), never the reverse.
  absl::Mutex mu_;
  absl::Mutex event_mu_;
  absl::flat_hash_map<int64_t, std::unique_ptr<BufferHandle>>
      underlying_buffers_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t,
                      std::vector<std::unique_ptr<CompiledProgramHandle>>>
      underlying_cph_ ABSL_GUARDED_BY(mu_);
  absl::flat_hash_map<int64_t, std::unique_ptr<LoadedProgramHandle>>
      underlying_lph_ ABSL_GUARDED_BY(mu_);
  absl::btree_map<int64_t, EventInFlight> events_ ABSL_GUARDED_BY(event_mu_);
  absl::flat_hash_map<int64_t, Status> abnormal_event_status_
      ABSL_GUARDED_BY(event_mu_);
  std::atomic<int64_t> operation_id_counter_{0};
  // Declared last, so destroyed first (presumably WorkerThread's destructor
  // drains its queue before the maps above go away — verify).
  WorkerThread event_thread_;

  int64_t GetOperationId() { return operation_id_counter_++; }

  absl::flat_hash_set<int64_t> GetDependencyOperationIds(
      absl::Span<Event* const> wait_for) {
    absl::flat_hash_set<int64_t> deps;
    for (auto* event : wait_for) {
      deps.insert(static_cast<PodEvent* const>(event)->operation_id());
    }
    return deps;
  }

  // Final status of an operation already erased from `events_`: its recorded
  // error, or OK if none was recorded.
  Status CompletedEventStatus(int64_t event_id)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(event_mu_) {
    auto it = abnormal_event_status_.find(event_id);
    return it == abnormal_event_status_.end() ? Status::OK() : it->second;
  }

  // EventCompleted is executed on the event_thread_ worker thread. We want
  // to propagate the fact that the event is completed to any subsequent events
  // that might depend on this event.
  void EventCompleted(int64_t event_id, Status status) {
    // (underlying event, queued callbacks) pairs to wire up after event_mu_
    // is released, so no user callback runs while the lock is held.
    std::vector<std::pair<std::shared_ptr<Event>,
                          std::vector<std::function<void(Status)>>>>
        pending_callbacks;
    {
      absl::MutexLock l(&event_mu_);
      if (!status.ok()) abnormal_event_status_.insert({event_id, status});
      auto curr_event = events_.find(event_id);
      DCHECK(curr_event != events_.end());
      DCHECK(curr_event->second.callbacks.empty());
      DCHECK(curr_event->second.incomplete_deps.empty());
      for (auto& event : events_) {
        event.second.incomplete_deps.erase(event_id);
        // Fire an event's create_fn when
        // - all of its dependencies have completed (incomplete_deps.empty())
        // - it has not been fired yet (create_fn != nullptr).
        if (event.second.incomplete_deps.empty() &&
            event.second.create_fn != nullptr) {
          event.second.underlying_event = event.second.create_fn();
          event.second.create_fn = nullptr;
          pending_callbacks.emplace_back(event.second.underlying_event,
                                         std::move(event.second.callbacks));
          event.second.callbacks.clear();
        }
      }
      // We erase the current event to signal that it has finished.
      events_.erase(curr_event);
    }
    for (auto& pending : pending_callbacks) {
      for (auto& fn : pending.second) {
        pending.first->AddCallback(std::move(fn));
      }
    }
  }

  // Registers operation `operation_id`. `fn` issues it to a host driver and
  // returns the resulting event; it runs now if none of `deps` are still in
  // flight, otherwise later from EventCompleted().
  void ScheduleRequest(int64_t operation_id,
                       std::function<std::shared_ptr<Event>(void)> fn,
                       const absl::flat_hash_set<int64_t>& deps) {
    absl::MutexLock l(&event_mu_);
    auto event = events_.insert({operation_id, {}}).first;
    absl::flat_hash_set<int64_t> incomplete_deps;
    for (const auto& dep : deps) {
      if (events_.count(dep) > 0) incomplete_deps.insert(dep);
    }
    // When this operation's event completes, hand off to the event thread so
    // dependents are released by EventCompleted().
    auto notify_completion = [this, operation_id](Status status) {
      event_thread_.Schedule([this, operation_id, status]() {
        EventCompleted(operation_id, status);
      });
    };
    if (incomplete_deps.empty()) {
      // All dependencies have been fulfilled: execute the request now.
      event->second.create_fn = nullptr;
      event->second.underlying_event = fn();
      event->second.underlying_event->AddCallback(std::move(notify_completion));
    } else {
      // Some dependencies are still in flight: park the request; it is run by
      // EventCompleted() once they have all finished.
      event->second.create_fn = std::move(fn);
      event->second.incomplete_deps = std::move(incomplete_deps);
      event->second.callbacks.push_back(std::move(notify_completion));
    }
  }
};
// Waits with no deadline. value() is kept as a guard: WaitForEvent signals a
// timeout with absl::nullopt, which should not occur with an infinite wait.
xla::Status PodEvent::Await() {
  const absl::Duration no_deadline = absl::InfiniteDuration();
  return driver_->WaitForEvent(operation_id(), no_deadline).value();
}
// Delegates to the driver; absl::nullopt means `duration` elapsed first.
absl::optional<xla::Status> PodEvent::AwaitWithTimeout(
    absl::Duration duration) {
  auto* const owner = driver_;
  return owner->WaitForEvent(operation_id(), duration);
}
void PodEvent::AddCallback(std::function<void(Status)> callback) {
driver_->AddCallbackForEvent(operation_id_, std::move(callback));
}
// Creates a PodTpuDriver for `config`, whose worker() is a
// "grpc+pod://host1,host2,..." list. `creds` is used for every per-host gRPC
// driver the pod driver creates.
xla::StatusOr<std::unique_ptr<TpuDriver>> CreatePodTpuDriver(
    const TpuDriverConfig& config,
    std::shared_ptr<::grpc::ChannelCredentials> creds) {
  return std::unique_ptr<TpuDriver>(
      absl::make_unique<PodTpuDriver>(config, std::move(creds)));
}
// Asks the driver for the shape of the program compiled by this handle's
// operation; the driver waits until that compilation has been recorded.
xla::Status PodCompiledProgramHandle::program_shape(
    xla::ProgramShapeProto* program_shape) {
  const int64_t compile_op = operation_id_;
  return driver_->GetCompiledProgramShape(compile_op, program_shape);
}
} // namespace
// Registers the pod driver for worker strings starting with
// kPodTpuDriverPrefix ("grpc+pod://"). The per-host gRPC connections are
// created with insecure channel credentials.
REGISTER_TPU_DRIVER(kPodTpuDriverPrefix,
[](const TpuDriverConfig& config)
-> xla::StatusOr<std::unique_ptr<TpuDriver>> {
return CreatePodTpuDriver(
config,
::grpc::InsecureChannelCredentials()); // NOLINT
});
} // namespace tpu_driver