blob: b90e5bdc1a2bf411cd26f7fb0351a3f4107017de [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file extends/implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface and a C
// struct called "SP_Something".
//
// This file also contains stream_executor::Platform registration for pluggable
// device.
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include <string>
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"
using tensorflow::StatusFromTF_Status;
namespace stream_executor {
using tensorflow::StringPiece;
using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
namespace {
#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \
do { \
if (STRUCT_OBJ.struct_size == 0) { \
return port::FailedPreconditionError( \
"struct_size field in " #STRUCT_NAME \
" must be set to " #SIZE_VALUE_NAME "."); \
} \
} while (0)
#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \
do { \
if (STRUCT_OBJ.NAME == 0) { \
return port::FailedPreconditionError( \
"'" #NAME "' field in " #STRUCT_NAME " must be set."); \
} \
} while (0)
port::Status ValidateDeviceType(StringPiece type) {
// Validate device type. Device type must start with a capital letter and
// consist of capital letters and underscores. Reasoning behind this decision:
// * At the minimum we want to disallow '/' and ':' since
// these characters are used in device spec, for e.g.
// /job:foo/replica:12/device:GPU:1.
// * Underscores seem useful, for e.g. XLA_GPU uses underscores.
// * Allowing lowercase might get confusing. For example, say someone
// registers a new type called "Gpu". It might be confusing for users that
// "Gpu" is not the same device type as "GPU".
// Note that lowercase "cpu" and "gpu" are currently supported only for
// legacy reasons:
// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd
static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"};
bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx);
if (!matches) {
return port::FailedPreconditionError(
tensorflow::strings::StrCat("Device name/type '", type, "' must match ",
kTfDeviceTypeRegEx->pattern(), "."));
}
return port::Status::OK();
}
port::Status ValidateSPPlatform(const SP_Platform& platform) {
VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
VALIDATE_MEMBER(SP_Platform, platform, name);
VALIDATE_MEMBER(SP_Platform, platform, type);
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.name));
TF_RETURN_IF_ERROR(ValidateDeviceType(platform.type));
// `visible_device_count` could be 0 at initialization time.
return port::Status::OK();
}
port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns,
SP_PLATFORM_FNS_STRUCT_SIZE);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_stream_executor);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_stream_executor);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_timer_fns);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_timer_fns);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, create_device_fns);
VALIDATE_MEMBER(SP_PlatformFns, platform_fns, destroy_device_fns);
return port::Status::OK();
}
port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
VALIDATE_MEMBER(SP_TimerFns, timer_fns, nanoseconds);
return port::Status::OK();
}
port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPDevice(const SP_Device& device) {
VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
// All other fields could theoretically be zero/null.
return port::Status::OK();
}
port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
const SP_Platform& platform) {
VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE);
VALIDATE_MEMBER(SP_StreamExecutor, se, allocate);
VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate);
VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats);
VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_allocate);
VALIDATE_MEMBER(SP_StreamExecutor, se, host_memory_deallocate);
if (platform.supports_unified_memory) {
VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_allocate);
VALIDATE_MEMBER(SP_StreamExecutor, se, unified_memory_deallocate);
}
VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream);
VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream_dependency);
VALIDATE_MEMBER(SP_StreamExecutor, se, get_stream_status);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, get_event_status);
VALIDATE_MEMBER(SP_StreamExecutor, se, record_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, wait_for_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, create_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, start_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, stop_timer);
VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_dtoh);
VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_htod);
VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_dtoh);
VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_htod);
VALIDATE_MEMBER(SP_StreamExecutor, se, block_host_for_event);
VALIDATE_MEMBER(SP_StreamExecutor, se, synchronize_all_activity);
VALIDATE_MEMBER(SP_StreamExecutor, se, host_callback);
return port::Status::OK();
}
port::Status ValidateSEPlatformRegistrationParams(
const SE_PlatformRegistrationParams& params) {
VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform);
VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform_fns);
return port::Status::OK();
}
#undef VALIDATE_MEMBER
// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
switch (s) {
case SE_EVENT_ERROR:
return Event::Status::kError;
case SE_EVENT_PENDING:
return Event::Status::kPending;
case SE_EVENT_COMPLETE:
return Event::Status::kComplete;
default:
return Event::Status::kUnknown;
}
}
// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
// `opaque` field inside SP_DeviceMemoryBase is not const.
// Therefore, we need to cast away the constness before setting it.
device_memory_base.opaque = const_cast<void*>(mem->opaque());
device_memory_base.size = mem->size();
device_memory_base.payload = mem->payload();
return device_memory_base;
}
DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
DeviceMemoryBase base(mem.opaque, mem.size);
base.SetPayload(mem.payload);
return base;
}
// Wrapper that allows passing std::function across C API.
struct HostCallbackContext {
std::function<port::Status()> callback;
};
// This wrapper allows calling `HostCallbackContext::callback` across C API.
// This function matches `SE_StatusCallbackFn` signature and will be passed as
// `callback_fn` to `host_callback` in `SP_StreamExecutor`.
void HostCallbackTrampoline(void* ctx, TF_Status* status) {
HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
port::Status s = host_ctx->callback();
Set_TF_Status_from_Status(status, s);
delete host_ctx;
}
class CStreamExecutor : public internal::StreamExecutorInterface {
public:
explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
SP_StreamExecutor* stream_executor,
SP_Platform* platform, SP_PlatformFns* platform_fns,
SP_TimerFns* timer_fns, const std::string& name,
int visible_device_count)
: device_(std::move(device)),
device_fns_(device_fns),
stream_executor_(stream_executor),
platform_(platform),
platform_fns_(platform_fns),
timer_fns_(timer_fns),
platform_name_(name),
visible_device_count_(visible_device_count) {}
~CStreamExecutor() override {
platform_fns_->destroy_device(platform_, &device_);
}
port::Status Init(int device_ordinal, DeviceOptions device_options) override {
return port::Status::OK();
}
DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override {
SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
stream_executor_->allocate(&device_, size, memory_space, &mem);
port::Status status = ValidateSPDeviceMemoryBase(mem);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
}
return DeviceMemoryBaseFromC(mem);
}
DeviceMemoryBase Allocate(uint64 size) {
return Allocate(size, /*memory_space=*/0);
}
void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
uint64 size) override {
LOG(FATAL) << "GetSubBuffer is not supported by pluggable device.";
}
void Deallocate(DeviceMemoryBase* mem) override {
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem);
stream_executor_->deallocate(&device_, &device_memory_base);
}
void* HostMemoryAllocate(uint64 size) override {
return stream_executor_->host_memory_allocate(&device_, size);
}
void HostMemoryDeallocate(void* mem) override {
stream_executor_->host_memory_deallocate(&device_, mem);
}
bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
bool HostMemoryUnregister(void* mem) override { return false; }
void* UnifiedMemoryAllocate(uint64 size) override {
CHECK(stream_executor_->unified_memory_allocate);
return stream_executor_->unified_memory_allocate(&device_, size);
}
void UnifiedMemoryDeallocate(void* mem) override {
CHECK(stream_executor_->unified_memory_deallocate);
stream_executor_->unified_memory_deallocate(&device_, mem);
}
absl::optional<AllocatorStats> GetAllocatorStats() override {
SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
TF_Bool has_stats =
stream_executor_->get_allocator_stats(&device_, &c_stats);
if (!has_stats) {
return absl::nullopt;
}
port::Status status = ValidateSPAllocatorStats(c_stats);
if (!status.ok()) {
LOG(ERROR) << status.error_message();
return absl::nullopt;
}
::stream_executor::AllocatorStats stats;
stats.num_allocs = c_stats.num_allocs;
stats.bytes_in_use = c_stats.bytes_in_use;
stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
stats.largest_alloc_size = c_stats.largest_alloc_size;
if (c_stats.has_bytes_limit) {
stats.bytes_limit = c_stats.bytes_limit;
}
stats.bytes_reserved = c_stats.bytes_reserved;
stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
if (c_stats.has_bytes_reservable_limit) {
stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
}
stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
return stats;
}
bool SynchronizeAllActivity() override {
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->synchronize_all_activity(&device_, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
port::Status SynchronousMemZero(DeviceMemoryBase* location,
uint64 size) override {
// TODO(annarev): figure out if we should support memzero/memset
// functionality by allocating on host and then copying to device.
return port::UnimplementedError(
"SynchronousMemZero is not supported by pluggable device.");
}
port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
uint64 size) override {
return port::UnimplementedError(
"SynchronousMemSet is not supported by pluggable device.");
}
port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
const void* host_src, uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status SynchronousMemcpy(void* host_dst,
const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst,
&device_mem_src, size, c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
uint64 size) override {
return port::UnimplementedError(
"MemZero is not supported by pluggable device.");
}
port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
uint64 size) override {
return port::UnimplementedError(
"Memset is not supported by pluggable device.");
}
port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
uint32 pattern, uint64 size) override {
return port::UnimplementedError(
"Memset32 is not supported by pluggable device.");
}
bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst,
&device_mem_src, size, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst,
host_src, size, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
const DeviceMemoryBase& gpu_src,
uint64 size) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst,
&device_mem_src, size, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool HostCallback(Stream* stream,
std::function<port::Status()> callback) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
HostCallbackContext* ctx = new HostCallbackContext{callback};
return stream_executor_->host_callback(&device_, stream_handle,
&HostCallbackTrampoline, ctx);
}
port::Status AllocateEvent(Event* event) override {
DCHECK(event != nullptr);
return static_cast<CEvent*>(event->implementation())->Create();
}
port::Status DeallocateEvent(Event* event) override {
static_cast<CEvent*>(event->implementation())->Destroy();
return port::Status::OK();
}
port::Status RecordEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
}
port::Status WaitForEvent(Stream* stream, Event* event) override {
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
OwnedTFStatus c_status(TF_NewStatus());
stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
return s;
}
Event::Status PollForEventStatus(Event* event) override {
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
SE_EventStatus event_status =
stream_executor_->get_event_status(&device_, event_handle);
return SEEventStatusToEventStatus(event_status);
}
bool AllocateStream(Stream* stream) override {
DCHECK(stream != nullptr);
port::Status status =
static_cast<CStream*>(stream->implementation())->Create();
// TODO(annarev): update AllocateStream to return status instead
// (similar to AllocateEvent).
return status.ok();
}
void DeallocateStream(Stream* stream) override {
static_cast<CStream*>(stream->implementation())->Destroy();
}
bool CreateStreamDependency(Stream* dependent, Stream* other) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream dependent_handle =
static_cast<CStream*>(dependent->implementation())->Handle();
SP_Stream other_handle =
static_cast<CStream*>(other->implementation())->Handle();
stream_executor_->create_stream_dependency(&device_, dependent_handle,
other_handle, c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool AllocateTimer(Timer* timer) override {
port::Status status =
static_cast<CTimer*>(timer->implementation())->Create();
// TODO(annarev): change return value of AllocateTimer
// to status (similar to AllocateEvent).
return status.ok();
}
void DeallocateTimer(Timer* timer) override {
static_cast<CTimer*>(timer->implementation())->Destroy();
}
bool StartTimer(Stream* stream, Timer* timer) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_Timer timer_handle =
static_cast<CTimer*>(timer->implementation())->Handle();
stream_executor_->start_timer(&device_, stream_handle, timer_handle,
c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
bool StopTimer(Stream* stream, Timer* timer) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
SP_Timer timer_handle =
static_cast<CTimer*>(timer->implementation())->Handle();
stream_executor_->stop_timer(&device_, stream_handle, timer_handle,
c_status.get());
if (TF_GetCode(c_status.get()) != TF_OK) {
LOG(ERROR) << TF_Message(c_status.get());
return false;
}
return true;
}
port::Status BlockHostForEvent(Stream* stream, Event* event) {
OwnedTFStatus c_status(TF_NewStatus());
SP_Event event_handle =
static_cast<CEvent*>(event->implementation())->Handle();
stream_executor_->block_host_for_event(&device_, event_handle,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
port::Status BlockHostUntilDone(Stream* stream) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
// If `block_host_until_done` is set, use it.
if (stream_executor_->block_host_until_done != nullptr) {
stream_executor_->block_host_until_done(&device_, stream_handle,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
// Create and record an event and then wait for it.
SP_Event event_handle;
stream_executor_->create_event(&device_, &event_handle, c_status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
stream_executor_->record_event(&device_, stream_handle, event_handle,
c_status.get());
port::Status s = StatusFromTF_Status(c_status.get());
if (!s.ok()) {
stream_executor_->destroy_event(&device_, event_handle);
return s;
}
stream_executor_->block_host_for_event(&device_, event_handle,
c_status.get());
stream_executor_->destroy_event(&device_, event_handle);
return StatusFromTF_Status(c_status.get());
}
port::Status GetStatus(Stream* stream) override {
OwnedTFStatus c_status(TF_NewStatus());
SP_Stream stream_handle =
static_cast<CStream*>(stream->implementation())->Handle();
stream_executor_->get_stream_status(&device_, stream_handle,
c_status.get());
return StatusFromTF_Status(c_status.get());
}
int PlatformDeviceCount() override { return visible_device_count_; }
port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
return port::UnimplementedError(
"EnablePeerAccessTo is not supported by pluggable device.");
}
bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
return false;
}
bool DeviceMemoryUsage(int64* free, int64* total) const override {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
return stream_executor_->device_memory_usage(
&device_, reinterpret_cast<int64_t*>(free),
reinterpret_cast<int64_t*>(total));
}
// Creates a new DeviceDescription object.
// Ownership is transferred to the caller.
port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
const override {
OwnedTFStatus c_status(TF_NewStatus());
internal::DeviceDescriptionBuilder builder;
if (device_.hardware_name != nullptr) {
builder.set_name(device_.hardware_name);
}
if (device_.device_vendor != nullptr) {
builder.set_device_vendor(device_.device_vendor);
}
if (device_.pci_bus_id != nullptr) {
builder.set_pci_bus_id(device_.pci_bus_id);
}
if (device_fns_->get_numa_node != nullptr) {
int32_t numa_node = device_fns_->get_numa_node(&device_);
if (numa_node >= 0) {
builder.set_numa_node(numa_node);
}
}
if (device_fns_->get_memory_bandwidth != nullptr) {
int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_);
if (memory_bandwidth >= 0) {
builder.set_memory_bandwidth(memory_bandwidth);
}
}
// TODO(annarev): Add gflops field in DeviceDescription and set it here.
// TODO(annarev): Perhaps add `supports_unified_memory` in
// DeviceDescription.
return builder.Build();
}
// Each call creates a new instance of the platform-specific implementation of
// the corresponding interface type.
std::unique_ptr<internal::EventInterface> CreateEventImplementation()
override {
return std::unique_ptr<internal::EventInterface>(
new CEvent(&device_, stream_executor_));
}
std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
override {
LOG(FATAL)
<< "CreateKernelImplementation is not supported by pluggable device.";
}
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
return std::unique_ptr<internal::StreamInterface>(
new CStream(&device_, stream_executor_));
}
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
return std::unique_ptr<internal::TimerInterface>(
new CTimer(&device_, stream_executor_, timer_fns_));
}
private:
SP_Device device_;
SP_DeviceFns* device_fns_;
SP_StreamExecutor* stream_executor_;
SP_Platform* platform_;
SP_PlatformFns* platform_fns_;
SP_TimerFns* timer_fns_;
std::string platform_name_;
int visible_device_count_;
};
} // namespace
CPlatform::CPlatform(SP_Platform platform,
void (*destroy_platform)(SP_Platform*),
SP_PlatformFns platform_fns,
void (*destroy_platform_fns)(SP_PlatformFns*),
SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
SP_TimerFns timer_fns)
: platform_(std::move(platform)),
destroy_platform_(destroy_platform),
platform_fns_(std::move(platform_fns)),
destroy_platform_fns_(destroy_platform_fns),
device_fns_(std::move(device_fns)),
stream_executor_(std::move(stream_executor)),
timer_fns_(std::move(timer_fns)),
name_(platform.name) {}
CPlatform::~CPlatform() {
executor_cache_.DestroyAllExecutors();
platform_fns_.destroy_device_fns(&platform_, &device_fns_);
platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
destroy_platform_(&platform_);
destroy_platform_fns_(&platform_fns_);
}
port::StatusOr<std::unique_ptr<DeviceDescription>>
CPlatform::DescriptionForDevice(int ordinal) const {
// TODO(annarev): see if we can get StreamExecutor instance
// and call GetDeviceDescription. executor_cache_.Get would need
// to be made const for it to work.
internal::DeviceDescriptionBuilder builder;
builder.set_name(name_);
return builder.Build();
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
stream_executor::StreamExecutorConfig config;
config.ordinal = ordinal;
return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
int ordinal, const PluginConfig& plugin_config) {
StreamExecutorConfig config;
config.ordinal = ordinal;
config.plugin_config = plugin_config;
return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
const StreamExecutorConfig& config) {
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
// Fill device creation params
SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
SP_Device device{SP_DEVICE_STRUCT_SIZE};
device_params.device = &device;
device_params.ext = nullptr;
device_params.ordinal = config.ordinal;
OwnedTFStatus c_status(TF_NewStatus());
// Create Device
platform_fns_.create_device(&platform_, &device_params, c_status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPDevice(device));
auto executor = absl::make_unique<CStreamExecutor>(
std::move(device), &device_fns_, &stream_executor_, &platform_,
&platform_fns_, &timer_fns_, name_, platform_.visible_device_count);
auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
config.ordinal);
return result;
}
port::Status InitStreamExecutorPlugin(void* dso_handle, string* device_type,
string* platform_name) {
tensorflow::Env* env = tensorflow::Env::Default();
// Step 1: Load symbol for `TF_InitPlugin`
void* dso_symbol;
TF_RETURN_IF_ERROR(
env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));
// Step 2: Call `TF_InitPlugin`
auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
return InitStreamExecutorPlugin(init_fn, device_type, platform_name);
}
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
string* device_type,
string* platform_name) {
SE_PlatformRegistrationParams params{
SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE};
params.major_version = SE_MAJOR;
params.minor_version = SE_MINOR;
params.patch_version = SE_PATCH;
params.platform = &platform;
params.platform_fns = &platform_fns;
OwnedTFStatus c_status(TF_NewStatus());
init_fn(&params, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));
// Fill SP_DeviceFns creation params
SE_CreateDeviceFnsParams device_fns_params{
SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE};
SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE};
device_fns_params.device_fns = &device_fns;
// Create StreamExecutor
platform_fns.create_device_fns(&platform, &device_fns_params, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns));
// Fill stream executor creation params
SE_CreateStreamExecutorParams se_params{
SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
se_params.stream_executor = &se;
// Create StreamExecutor
platform_fns.create_stream_executor(&platform, &se_params, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform));
SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
// Register new platform
*device_type = std::string(platform.type);
*platform_name = std::string(platform.name);
std::unique_ptr<stream_executor::CPlatform> cplatform(
new stream_executor::CPlatform(
std::move(platform), params.destroy_platform, std::move(platform_fns),
params.destroy_platform_fns, std::move(device_fns), std::move(se),
std::move(timer_fns)));
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
std::move(cplatform)));
return port::Status::OK();
}
} // namespace stream_executor