blob: 7e4407d4b273e0ce9cde892b130b219c8874b46c [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
#define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
#include <map>
#include "absl/base/macros.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/plugin.h"
#include "tensorflow/stream_executor/rng.h"
namespace stream_executor {
// Forward declaration only: the factory signatures in this header take a
// pointer to StreamExecutorInterface, so its full definition is not needed here.
namespace internal {
class StreamExecutorInterface;
}  // namespace internal
// The PluginRegistry is a singleton that maintains the set of registered
// "support library" plugins. Currently, there are four kinds of plugins:
// BLAS, DNN, FFT, and RNG. Each plugin interface (blas::BlasSupport,
// dnn::DnnSupport, fft::FftSupport, rng::RngSupport) is declared in the
// corresponding tensorflow/stream_executor/{kind}.h header.
//
// At runtime, a StreamExecutor object will query the singleton registry to
// retrieve the plugin kind that StreamExecutor was configured with (refer to
// the StreamExecutor and PluginConfig declarations).
//
// Plugin libraries are best registered using REGISTER_MODULE_INITIALIZER,
// but can be registered at any time. When registering a DSO-backed plugin, it
// is usually a good idea to load the DSO at registration time, to prevent
// late-loading from distorting performance/benchmarks as much as possible.
class PluginRegistry {
 public:
  // Factory function-pointer types, one per plugin kind. Each factory takes
  // the StreamExecutorInterface the plugin is being created for and returns a
  // pointer to the plugin's support object.
  // NOTE(review): ownership of the returned object and whether nullptr
  // signals failure are not visible in this header — confirm against the
  // callers before relying on either.
  using BlasFactory =
      blas::BlasSupport* (*)(internal::StreamExecutorInterface*);
  using DnnFactory = dnn::DnnSupport* (*)(internal::StreamExecutorInterface*);
  using FftFactory = fft::FftSupport* (*)(internal::StreamExecutorInterface*);
  using RngFactory = rng::RngSupport* (*)(internal::StreamExecutorInterface*);

  // Gets (and creates, if necessary) the singleton PluginRegistry instance.
  static PluginRegistry* Instance();

  // Registers the specified factory with the specified platform.
  // Returns a non-successful status if the factory has already been registered
  // with that platform (but execution should be otherwise unaffected).
  template <typename FactoryT>
  port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id,
                               const string& name, FactoryT factory);

  // Registers the specified factory as usable by _all_ platform types.
  // Reports errors just as RegisterFactory.
  template <typename FactoryT>
  port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id,
                                              const string& name,
                                              FactoryT factory);

  // TODO(b/22689637): Setter for temporary mapping until all users are using
  // MultiPlatformManager / PlatformId.
  void MapPlatformKindToId(PlatformKind platform_kind,
                           Platform::Id platform_id);

  // Potentially sets the plugin identified by plugin_id to be the default
  // for the specified platform and plugin kind. If this routine is called
  // multiple times for the same PluginKind, the PluginId given in the last
  // call will be used.
  // NOTE(review): the meaning of the bool result is not documented in this
  // header — presumably success/failure; confirm in plugin_registry.cc.
  bool SetDefaultFactory(Platform::Id platform_id, PluginKind plugin_kind,
                         PluginId plugin_id);

  // Returns true if the factory/id has been registered for the
  // specified platform and plugin kind, and false otherwise.
  bool HasFactory(Platform::Id platform_id, PluginKind plugin_kind,
                  PluginId plugin) const;

  // Retrieves the FactoryT factory registered under plugin_id for the
  // specified platform, or an error port::Status if it cannot be found.
  template <typename FactoryT>
  port::StatusOr<FactoryT> GetFactory(Platform::Id platform_id,
                                      PluginId plugin_id);

  // TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
  // on MultiPlatformManager / PlatformId.
  template <typename FactoryT>
  ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
  port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
                                      PluginId plugin_id);

 private:
  // Containers for the sets of registered factories, by plugin kind.
  struct PluginFactories {
    std::map<PluginId, BlasFactory> blas;
    std::map<PluginId, DnnFactory> dnn;
    std::map<PluginId, FftFactory> fft;
    std::map<PluginId, RngFactory> rng;
  };

  // Simple structure to hold the currently configured default plugins (for a
  // particular Platform).
  struct DefaultFactories {
    DefaultFactories();
    PluginId blas;
    PluginId dnn;
    PluginId fft;
    PluginId rng;
  };

  PluginRegistry();

  // Actually performs the work of registration.
  template <typename FactoryT>
  port::Status RegisterFactoryInternal(PluginId plugin_id,
                                       const string& plugin_name,
                                       FactoryT factory,
                                       std::map<PluginId, FactoryT>* factories);

  // Actually performs the work of factory retrieval. `factories` holds the
  // platform-specific registrations; `generic_factories` holds those
  // registered for all platforms.
  template <typename FactoryT>
  port::StatusOr<FactoryT> GetFactoryInternal(
      PluginId plugin_id, const std::map<PluginId, FactoryT>& factories,
      const std::map<PluginId, FactoryT>& generic_factories) const;

  // Returns true if the specified plugin has been registered with the specified
  // platform factories. Unlike the other overload of this method, this does
  // not implicitly examine the default factory lists.
  bool HasFactory(const PluginFactories& factories, PluginKind plugin_kind,
                  PluginId plugin) const;

  // The singleton itself.
  static PluginRegistry* instance_;

  // TODO(b/22689637): Temporary mapping until all users are using
  // MultiPlatformManager / PlatformId.
  std::map<PlatformKind, Platform::Id> platform_id_by_kind_;

  // The set of registered factories, keyed by platform ID.
  std::map<Platform::Id, PluginFactories> factories_;

  // Plugins supported for all platform kinds.
  PluginFactories generic_factories_;

  // The sets of default factories, keyed by platform ID.
  std::map<Platform::Id, DefaultFactories> default_factories_;

  // Lookup table for plugin names.
  std::map<PluginId, string> plugin_names_;

  SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry);
};
// Explicit specializations are defined in plugin_registry.cc.
//
// Declaring them here makes them visible to every translation unit that uses
// PluginRegistry::RegisterFactory / GetFactory. This matters because the
// primary templates are only declared in this header. Without these
// declarations, a use of one of these templates would implicitly instantiate
// the primary template rather than referring to the specialization, which is
// ill-formed.
#define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE)                          \
  template <>                                                                 \
  port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
      Platform::Id platform_id, PluginId plugin_id, const string& name,       \
      PluginRegistry::FACTORY_TYPE factory);                                  \
  template <>                                                                 \
  port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
      Platform::Id platform_id, PluginId plugin_id);                          \
  template <>                                                                 \
  port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
      PlatformKind platform_kind, PluginId plugin_id)

DECLARE_PLUGIN_SPECIALIZATIONS(BlasFactory);
DECLARE_PLUGIN_SPECIALIZATIONS(DnnFactory);
DECLARE_PLUGIN_SPECIALIZATIONS(FftFactory);
DECLARE_PLUGIN_SPECIALIZATIONS(RngFactory);

// The helper macro is only needed above; undefine it (by its actual name) so
// it does not leak into every file that includes this header.
#undef DECLARE_PLUGIN_SPECIALIZATIONS
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_