blob: ff75ba3a7e33949c45b3e575deba33059eaa9b0d [file] [log] [blame]
#pragma once
#include <ATen/CPUGeneral.h>
#include <ATen/Type.h>
#include <ATen/TypeExtendedInterface.h>
#include <ATen/Utils.h>
#include <ATen/LegacyTHDispatch.h>
#include <ATen/LegacyTHDispatcher.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/core/VariableHooksInterface.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/ComplexHooksInterface.h>
#include <c10/util/Exception.h>
#include <atomic>
#include <cerrno>
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <mutex>
namespace at {
class Tensor;
class CAFFE2_API Context {
public:
Context();
TypeExtendedInterface* getNonVariableTypeRaw(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface*>(globalLegacyTypeDispatch().getNonVariableTypeRaw(p, s));
}
TypeExtendedInterface * getNonVariableTypeOpt(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface*>(globalLegacyTypeDispatch().getNonVariableTypeOpt(p, s));
}
TypeExtendedInterface & getNonVariableType(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface&>(globalLegacyTypeDispatch().getNonVariableType(p, s));
}
TypeExtendedInterface & getVariableType(Backend p, ScalarType s) {
return static_cast<TypeExtendedInterface&>(globalLegacyTypeDispatch().getVariableType(p, s));
}
TypeExtendedInterface & getType(Backend p, ScalarType s, bool is_variable) {
return static_cast<TypeExtendedInterface&>(globalLegacyTypeDispatch().getType(p, s, is_variable));
}
LegacyTHDispatcher& getLegacyTHDispatcher(Backend p, ScalarType s) {
return globalLegacyTHDispatch().getLegacyTHDispatcher(p, s);
}
// The passed in Type must be delete'able
// TODO: Just make it take a unique_ptr
void registerType(Backend b, ScalarType s, Type* t) {
globalLegacyTypeDispatch().registerType(b, s,
LegacyTypeDispatch::TypeUniquePtr{t, LegacyTypeDeleter([](Type* p) { delete p; }) });
}
void registerLegacyTHDispatcher(Backend b, ScalarType s, LegacyTHDispatcher* t) {
globalLegacyTHDispatch().registerDispatcher(b, s,
LegacyTHDispatch::LegacyTHDispatcherUniquePtr{t, LegacyTHDispatcherDeleter([](LegacyTHDispatcher* p) { delete p; }) });
}
Generator & defaultGenerator(DeviceType device_type) {
initCUDAIfNeeded(device_type);
initHIPIfNeeded(device_type);
auto & generator = generator_registry[static_cast<int>(device_type)];
if(!generator)
AT_ERROR(DeviceTypeName(device_type), " backend type not enabled.");
return *generator;
}
bool hasOpenMP() const;
bool hasMKL() const;
bool hasLAPACK() const;
bool hasMAGMA() const {
return detail::getCUDAHooks().hasMAGMA();
}
bool hasCUDA() const {
return detail::getCUDAHooks().hasCUDA();
}
bool hasHIP() const {
return detail::getHIPHooks().hasHIP();
}
// defined in header so that getNonVariableType has ability to inline
// call_once check. getNonVariableType is called fairly frequently
THCState* lazyInitCUDA() {
std::call_once(thc_init,[&] {
thc_state = detail::getCUDAHooks().initCUDA();
generator_registry[static_cast<int>(DeviceType::CUDA)] =
detail::getCUDAHooks().initCUDAGenerator(this);
detail::getCUDAHooks().registerCUDATypes(this);
});
return thc_state.get();
}
THHState* lazyInitHIP() {
std::call_once(thh_init,[&] {
thh_state = detail::getHIPHooks().initHIP();
generator_registry[static_cast<int>(DeviceType::HIP)] =
detail::getHIPHooks().initHIPGenerator(this);
detail::getHIPHooks().registerHIPTypes(this);
});
return thh_state.get();
}
void lazyInitComplex() {
std::call_once(complex_init_, [&] {
detail::getComplexHooks().registerComplexTypes(this);
});
}
THCState* getTHCState() {
// AT_ASSERT(thc_state);
return thc_state.get();
}
THHState* getTHHState() {
return thh_state.get();
}
size_t freshTypeID() {
return next_id++;
}
bool setFlushDenormal(bool on);
// NB: This method is *purely* whether or not a user requested
// that CuDNN was enabled, it doesn't actually say anything about
// whether or not CuDNN is actually usable. Use cudnn_is_acceptable
// to test this instead
bool userEnabledCuDNN() const;
void setUserEnabledCuDNN(bool e);
bool benchmarkCuDNN() const;
void setBenchmarkCuDNN(bool);
bool deterministicCuDNN() const;
void setDeterministicCuDNN(bool);
std::unique_ptr<Generator>
generator_registry[static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
private:
void initCUDAIfNeeded(DeviceType p) {
if (p == DeviceType::CUDA) {
lazyInitCUDA();
}
}
void initHIPIfNeeded(DeviceType p) {
if (p == DeviceType::HIP) {
lazyInitHIP();
}
}
void initComplexIfNeeded(ScalarType s) {
if (isComplexType(s)) {
lazyInitComplex();
}
}
std::once_flag thc_init;
std::once_flag thh_init;
std::once_flag complex_init_;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;
bool benchmark_cudnn = false;
std::atomic<size_t> next_id;
std::unique_ptr<THCState, void(*)(THCState*)> thc_state;
std::unique_ptr<THHState, void(*)(THHState*)> thh_state;
friend struct Type;
};
// Returns the process-wide ATen Context (defined out of line).
CAFFE2_API Context& globalContext();

/// Eagerly initializes ATen. It constructs the global Context, then applies
/// the thread count requested via OMP_NUM_THREADS and MKL_NUM_THREADS.
///
/// Each variable is honored only when its value starts with a positive
/// integer that fits in an int. Text after the leading number is ignored,
/// as std::stoi did before, so OMP_NUM_THREADS="4,2" still yields 4.
/// Empty, non-numeric, non-positive and out-of-range values are skipped
/// instead of throwing, so a malformed environment cannot abort
/// initialization. When both variables are usable, MKL_NUM_THREADS is
/// applied last and wins.
static inline void init() {
  globalContext();
  // Reads a thread-count environment variable; returns 0 when it is unset
  // or does not hold a usable positive value.
  const auto read_thread_count = [](const char* name) -> int {
    const char* value = std::getenv(name);
    if (value == nullptr) {
      return 0;
    }
    char* end = nullptr;
    errno = 0;
    const long parsed = std::strtol(value, &end, 10);
    if (end == value || errno == ERANGE || parsed <= 0 || parsed > INT_MAX) {
      return 0;
    }
    return static_cast<int>(parsed);
  };
  if (const int threads = read_thread_count("OMP_NUM_THREADS")) {
    at::set_num_threads(threads);
  }
  if (const int threads = read_thread_count("MKL_NUM_THREADS")) {
    at::set_num_threads(threads);
  }
}
/// Looks up the non-Variable type for backend `p` and scalar type `s` in the
/// global context.
static inline TypeExtendedInterface& getNonVariableType(Backend p, ScalarType s) {
  Context& ctx = globalContext();
  return ctx.getNonVariableType(p, s);
}

/// DeviceType convenience overload. It maps `p` to a Backend with
/// deviceTypeToBackend and then defers to the Backend overload above.
static inline TypeExtendedInterface& getNonVariableType(DeviceType p, ScalarType s) {
  const Backend backend = deviceTypeToBackend(p);
  return getNonVariableType(backend, s);
}
// Type lookup entry points, keyed by TensorOptions, a TensorImpl, or a Tensor.
// Declared here; the definitions live out of line.
CAFFE2_API TypeExtendedInterface& getType(TensorOptions options);
CAFFE2_API TypeExtendedInterface& getType(const TensorImpl*);
CAFFE2_API TypeExtendedInterface& getType(const Tensor&);
// Declared here; defined out of line. NOTE(review): presumably returns the
// default allocator for CPU storage -- confirm against its definition.
CAFFE2_API Allocator* getCPUAllocator();
/// Shorthand for the non-Variable type of scalar type `s` on the CPU backend.
static inline TypeExtendedInterface& CPU(ScalarType s) {
  return globalContext().getNonVariableType(Backend::CPU, s);
}

/// Shorthand for the non-Variable type of scalar type `s` on the CUDA backend.
static inline TypeExtendedInterface& CUDA(ScalarType s) {
  return globalContext().getNonVariableType(Backend::CUDA, s);
}

/// Shorthand for the non-Variable type of scalar type `s` on the HIP backend.
static inline TypeExtendedInterface& HIP(ScalarType s) {
  return globalContext().getNonVariableType(Backend::HIP, s);
}
// LegacyTHDispatcher lookup, keyed by TensorOptions or by a Tensor.
// Declared here; the definitions live out of line.
CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(TensorOptions options);
CAFFE2_API LegacyTHDispatcher& getLegacyTHDispatcher(const Tensor&);
static inline bool hasCUDA() {
return globalContext().hasCUDA();
}
static inline bool hasHIP() {
return globalContext().hasHIP();
}
/// Number of GPUs reported by the active GPU backend, or 0 when neither CUDA
/// nor HIP is available. CUDA is consulted first; HIP is queried only when
/// CUDA is absent.
static inline size_t getNumGPUs() {
  size_t gpu_count = 0;
  if (hasCUDA()) {
    gpu_count = detail::getCUDAHooks().getNumGPUs();
  } else if (hasHIP()) {
    gpu_count = detail::getHIPHooks().getNumGPUs();
  }
  return gpu_count;
}
static inline bool hasOpenMP() {
return globalContext().hasOpenMP();
}
static inline bool hasMKL() {
return globalContext().hasMKL();
}
static inline bool hasLAPACK() {
return globalContext().hasLAPACK();
}
static inline bool hasMAGMA() {
return globalContext().hasMAGMA();
}
/// Seeds the default CPU generator with `seed`. When CUDA is available and at
/// least one GPU is visible, it also seeds every CUDA default generator
/// through manualSeedAll.
static inline void manual_seed(uint64_t seed) {
  globalContext().defaultGenerator(DeviceType::CPU).manualSeed(seed);
  // A CUDA-enabled build may still run on a machine with no GPUs. Touching
  // the CUDA generator there would fail, so bail out before doing so.
  const bool seed_cuda = hasCUDA() && detail::getCUDAHooks().getNumGPUs() > 0;
  if (!seed_cuda) {
    return;
  }
  globalContext().defaultGenerator(DeviceType::CUDA).manualSeedAll(seed);
}
} // namespace at