blob: 32bc8abca1828974cfd5493339503c17ac73f9b1 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Declares the XlaInterpreterExecutor class, which is a CPU-only implementation
// of the StreamExecutor interface. For now, this is used for testing and to
// examine the performance of host-based StreamExecutor code.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_
#include <functional>
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_description.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_options.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/host/host_stream.h"
#include "tensorflow/stream_executor/host/host_timer.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/kernel_spec.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/plugin.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/timer.h"
namespace stream_executor {
namespace interpreter {
using Args = absl::Span<const DeviceMemoryBase>;
class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
public:
explicit XlaInterpreterExecutor(const PluginConfig &plugin_config);
~XlaInterpreterExecutor() override;
port::Status Init(int device_ordinal, DeviceOptions device_options) override {
return ::tensorflow::OkStatus();
}
port::Status GetKernel(const MultiKernelLoaderSpec &spec,
KernelBase *kernel) override {
return port::UnimplementedError("Not Implemented");
}
port::Status Launch(Stream *stream, const ThreadDim &thread_dims,
const BlockDim &block_dims, const KernelBase &kernel,
const KernelArgsArrayBase &args) override {
return port::UnimplementedError("Not Implemented");
}
DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override;
void *GetSubBuffer(DeviceMemoryBase *parent, uint64_t offset_bytes,
uint64_t size_bytes) override;
void Deallocate(DeviceMemoryBase *mem) override;
void *HostMemoryAllocate(uint64_t size) override { return new char[size]; }
void HostMemoryDeallocate(void *mem) override {
delete[] static_cast<char *>(mem);
}
bool HostMemoryRegister(void *mem, uint64_t size) override { return true; }
bool HostMemoryUnregister(void *mem) override { return true; }
bool Memcpy(Stream *stream, void *host_dst, const DeviceMemoryBase &dev_src,
uint64_t size) override;
bool Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, const void *host_src,
uint64_t size) override;
bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *pop_dst,
const DeviceMemoryBase &host_src,
uint64_t size) override {
return false;
}
port::Status MemZero(Stream *stream, DeviceMemoryBase *location,
uint64_t size) override {
return port::InternalError("Interpreter can not memzero");
}
port::Status Memset(Stream *stream, DeviceMemoryBase *location,
uint8_t pattern, uint64_t size) override {
return port::InternalError("Interpreter can not memset");
}
port::Status Memset32(Stream *stream, DeviceMemoryBase *location,
uint32_t pattern, uint64_t size) override {
return port::InternalError("Interpreter can not memset");
}
// No "synchronize all activity" implemented for this platform at the moment.
bool SynchronizeAllActivity() override { return true; }
port::Status SynchronousMemZero(DeviceMemoryBase *location,
uint64_t size) override {
return port::InternalError("Interpreter can not memzero");
}
port::Status SynchronousMemSet(DeviceMemoryBase *location, int value,
uint64_t size) override {
return port::InternalError("Interpreter can not memset");
}
port::Status SynchronousMemcpy(DeviceMemoryBase *dev_dst,
const void *host_src, uint64_t size) override;
port::Status SynchronousMemcpy(void *host_dst,
const DeviceMemoryBase &dev_src,
uint64_t size) override;
port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase *pop_dst,
const DeviceMemoryBase &pop_src,
uint64_t size) override {
return port::Status{port::error::UNIMPLEMENTED, ""};
}
bool HostCallback(Stream *stream,
std::function<port::Status()> callback) override;
port::Status AllocateEvent(Event *event) override {
return ::tensorflow::OkStatus();
}
port::Status DeallocateEvent(Event *event) override {
return ::tensorflow::OkStatus();
}
port::Status RecordEvent(Stream *stream, Event *event) override {
return port::Status{port::error::UNIMPLEMENTED, "RecordEvent"};
}
port::Status WaitForEvent(Stream *stream, Event *event) override {
return port::Status{port::error::UNIMPLEMENTED, "WaitForEvent"};
}
Event::Status PollForEventStatus(Event *event) override {
return Event::Status::kError;
}
bool AllocateStream(Stream *stream) override { return true; }
void DeallocateStream(Stream *stream) override {}
bool CreateStreamDependency(Stream *dependent, Stream *other) override;
bool AllocateTimer(Timer *timer) override { return true; }
void DeallocateTimer(Timer *timer) override {}
bool StartTimer(Stream *stream, Timer *timer) override;
bool StopTimer(Stream *stream, Timer *timer) override;
port::Status BlockHostUntilDone(Stream *stream) override;
int PlatformDeviceCount() override { return 1; }
bool DeviceMemoryUsage(int64_t *free, int64_t *total) const override {
return false;
}
port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
const override {
return CreateDeviceDescription(0);
}
static port::StatusOr<std::unique_ptr<DeviceDescription>>
CreateDeviceDescription(int device_ordinal);
port::Status EnablePeerAccessTo(StreamExecutorInterface *other) override {
return ::tensorflow::OkStatus();
}
bool CanEnablePeerAccessTo(StreamExecutorInterface *other) override {
return true;
}
std::unique_ptr<internal::EventInterface> CreateEventImplementation()
override {
return nullptr;
}
std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
override {
return nullptr;
}
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
return std::unique_ptr<internal::StreamInterface>(
new host::HostStream(/*thread_stack_size=*/0));
}
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
return std::unique_ptr<internal::TimerInterface>(new host::HostTimer());
}
private:
DeviceMemoryBase AllocateSingleOutput(const xla::Shape &shape);
port::StatusOr<DeviceMemoryBase> AllocateOutputBuffer(
const xla::Shape &shape);
const PluginConfig plugin_config_;
};
} // namespace interpreter
} // namespace stream_executor
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_