blob: a3852b0edb0fb331cfcada4557e7ca3b9a6558a0 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
#include <map>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
namespace tensorflow {
// StreamExecutor platform for TPUs, holding a C-API platform handle
// (`SE_Platform`). Several Platform methods are not yet implemented and abort
// the process via LOG(FATAL) when called.
class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
 public:
  // Maps StreamExecutor stream interfaces to their C-API `SE_Stream` handles.
  using StreamMap =
      absl::flat_hash_map<stream_executor::internal::StreamInterface*,
                          SE_Stream*>;
  // Maps StreamExecutor event interfaces to their C-API `SE_Event` handles.
  using EventMap =
      absl::flat_hash_map<stream_executor::internal::EventInterface*,
                          SE_Event*>;

  // Identifier of this platform; defined out of line.
  static const ::stream_executor::Platform::Id kId;

  using Status = ::stream_executor::port::Status;
  template <typename T>
  using StatusOr = ::stream_executor::port::StatusOr<T>;

  TpuPlatform();
  ~TpuPlatform() override;

  // Presumably returns the instance registered by RegisterTpuPlatform();
  // defined elsewhere -- confirm against the implementation file.
  static TpuPlatform* GetRegisteredPlatform();

  Id id() const override;
  const std::string& Name() const override;
  int VisibleDeviceCount() const override;
  int64 TpuMemoryLimit() override;
  bool ShouldRegisterTpuDeviceToDeviceCopy() override;
  bool Initialized() const override;
  Status Initialize(
      const std::map<std::string, std::string>& platform_options) override;

  // Equivalent to Reset(/*only_tear_down=*/false).
  Status Reset() override { return Reset(false); }

  // Not implemented: aborts the process via LOG(FATAL).
  Status Reset(bool only_tear_down) override {
    LOG(FATAL) << "Not yet implemented";
  }

  // Not implemented: aborts the process via LOG(FATAL).
  StatusOr<std::unique_ptr<::stream_executor::DeviceDescription>>
  DescriptionForDevice(int ordinal) const override {
    LOG(FATAL) << "Not yet implemented";
  }

  // Returns the executor for device `ordinal` by building a
  // StreamExecutorConfig with only `ordinal` set and forwarding it to
  // GetExecutor().
  StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice(
      int ordinal) override {
    stream_executor::StreamExecutorConfig config;
    config.ordinal = ordinal;
    return GetExecutor(config);
  }

  // Like ExecutorForDevice(), but also sets `plugin_config` on the
  // StreamExecutorConfig before forwarding it to GetExecutor().
  StatusOr<::stream_executor::StreamExecutor*>
  ExecutorForDeviceWithPluginConfig(
      int ordinal,
      const ::stream_executor::PluginConfig& plugin_config) override {
    stream_executor::StreamExecutorConfig config;
    config.ordinal = ordinal;
    config.plugin_config = plugin_config;
    return GetExecutor(config);
  }

  // Returns an executor for `config`; defined out of line.
  StatusOr<::stream_executor::StreamExecutor*> GetExecutor(
      const ::stream_executor::StreamExecutorConfig& config) override;

  // Returns a newly owned executor for `config`; the name suggests it bypasses
  // `executor_cache_` -- confirm against the definition.
  StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
  GetUncachedExecutor(
      const ::stream_executor::StreamExecutorConfig& config) override;

  // Not implemented: aborts the process via LOG(FATAL).
  void RegisterTraceListener(
      std::unique_ptr<stream_executor::TraceListener> listener) override {
    LOG(FATAL) << "Not yet implemented";
  }

  // Not implemented: aborts the process via LOG(FATAL).
  void UnregisterTraceListener(
      stream_executor::TraceListener* listener) override {
    LOG(FATAL) << "Not yet implemented";
  }

  // Mutable access to the stream/event handle maps. These accessors perform
  // no locking; if the maps are shared across threads, callers must
  // synchronize access themselves.
  StreamMap* stream_map() { return &stream_map_; }
  EventMap* event_map() { return &event_map_; }

 private:
  // Underlying C-API platform handle. Default-initialized to nullptr so it is
  // never left indeterminate if a constructor path does not assign it.
  // Ownership presumably lies with this object (ctor/dtor defined elsewhere)
  // -- confirm.
  SE_Platform* platform_ = nullptr;
  // Executor cache; presumably backs GetExecutor() -- confirm.
  stream_executor::ExecutorCache executor_cache_;
  StreamMap stream_map_;
  EventMap event_map_;
};
// Declared here, defined elsewhere. Presumably registers TpuPlatform with the
// StreamExecutor platform registry; the meaning of the returned bool (e.g.
// success vs. already registered) is not visible from this header -- confirm
// against the definition.
bool RegisterTpuPlatform();
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_