blob: d0780dc7b2cb48ffe2fcf39e79234a8cc335db7c [file] [log] [blame]
#pragma once
#include <c10/Device.h>
#include <c10/DeviceType.h>
#include <c10/Stream.h>
// Just for C10_ANONYMOUS_VARIABLE
#include <c10/util/Registry.h>
#include <atomic>
namespace c10 {
namespace impl {
/**
* DeviceGuardImplInterface is the virtual interface that backs DeviceGuard,
* an RAII class for device and stream switching. Every distinct device type,
* e.g., CUDA and HIP, is expected to implement and register an
* implementation of this interface. All classes which inherit from
* DeviceGuardImplInterface should be declared 'final'.
*
* This class exists because we provide a unified interface for performing
* device guards via DeviceGuard, but we cannot assume that we have actually
* compiled against, e.g., the CUDA library, which actually implements
* this guard functionality. In this case, a dynamic dispatch is required
* to cross the library boundary.
*
* If possible, you should directly use implementations of this interface;
* those uses will be devirtualized.
*
* NB: because implementations may be compiled into a different library than
* their callers, do not reorder or remove the virtual functions below;
* on common ABIs their declaration order determines the vtable layout.
*/
struct C10_API DeviceGuardImplInterface {
/**
* Return the type of device managed by this guard implementation.
*/
virtual DeviceType type() const = 0;
/**
* Set the current device to Device, and return the previous Device.
*/
virtual Device exchangeDevice(Device) const = 0;
// NB: Implementations of exchangeDevice can be a bit boilerplatey. You might
// consider replacing exchangeDevice with a non-virtual function with a baked
// in implementation; however, note that this will triple the number of
// virtual calls (when you implement exchangeDevice in a final subclass,
// the compiler gets to devirtualize everything; it won't do that if you don't
// define it in the subclass!) A common way to solve this problem is to use
// some sort of CRTP; however, we can't template DeviceGuardImplInterface since
// we really *do* need it to be virtual. A little boilerplate seems easiest
// to explain. (Another way around this problem is to provide inline
// functions that provide the default implementations, but this seems a little
// hard to explain. In any case, we're only going to have on the order of ten
// implementations of this anyway.)
/**
* Get the current device.
*/
virtual Device getDevice() const = 0;
/**
* Set the current device to Device.
*/
virtual void setDevice(Device) const = 0;
/**
* Set the current device to Device, without checking for errors
* (so, e.g., this can be called from a destructor). Declared noexcept.
*/
virtual void uncheckedSetDevice(Device) const noexcept = 0;
/**
* Get the current stream for a given device.
*/
virtual Stream getStream(Device) const noexcept = 0;
/**
* Set a stream to be the thread local current stream for its device.
* Return the previous stream for that device. You are NOT required
* to set the current device to match the device of this stream.
*/
virtual Stream exchangeStream(Stream) const noexcept = 0;
/**
* Implementations are intended to be leaked at program end:
* C10_REGISTER_GUARD_IMPL allocates them with `new` and never deletes them,
* so that they remain usable during static destruction. The destructor is
* defaulted (not deleted), so nothing stops a delete at compile time; do not
* delete a registered implementation.
*/
virtual ~DeviceGuardImplInterface() = default;
};
// Table of registered guard implementations, indexed by
// static_cast<size_t>(DeviceType). getDeviceGuardImpl treats a null entry
// as "not available" and asserts on it.
//
// The registry is NON-owning. Each stored pointer is std::atomic so
// that under all interleavings of registry calls the structure is
// race-free. This doesn't cost us anything on reads on x86. (An
// unsynchronized implementation is probably OK too, but I didn't want
// to prove that we never read from device_guard_impl_registry at the
// same time some registration is occurring. Shiver.)
//
// I'd like this registry to be valid even at program destruction time
// (in case someone uses a DeviceGuard in a destructor to do some cleanup
// in the CUDA API.) Since there are no direct accesses of the underlying
// owning objects which I can use to enforce initialization order (unlike
// in a Meyers singleton), it implies that you must *leak* objects when
// putting them in the registry. This is done by allocating each
// implementation with `new` in C10_REGISTER_GUARD_IMPL and never deleting
// it. (The destructor of DeviceGuardImplInterface is defaulted, not deleted;
// it is simply never invoked on registered implementations.)
extern C10_API std::atomic<const DeviceGuardImplInterface*>
device_guard_impl_registry[static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
// Why not c10/util/Registry.h? That registry only offers a comparatively
// slow way to Create a new object of some interface; it has no fast path for
// reaching an object that already exists. getDeviceGuardImpl is hit on every
// DeviceGuard, so an unordered_map lookup there is unacceptable. Instead,
// constructing a DeviceGuardImplRegistrar places its implementation directly
// into device_guard_impl_registry.
class C10_API DeviceGuardImplRegistrar {
 public:
  /**
   * Registration hook: per the design described above, `impl` is placed into
   * device_guard_impl_registry under `type`. The registry does not take
   * ownership of `impl`.
   */
  DeviceGuardImplRegistrar(DeviceType type, const DeviceGuardImplInterface* impl);
};
// Registers a default-constructed DeviceGuardImpl as the guard implementation
// for ::c10::DeviceType::DevType by defining a file-static
// DeviceGuardImplRegistrar. The implementation is allocated with `new` and is
// intentionally never deleted, so the registry entry stays valid even during
// static destruction. Use it at namespace scope in a .cpp file, e.g.
//   C10_REGISTER_GUARD_IMPL(CUDA, SomeCUDAGuardImpl);
// The variable name is built from the DevType macro parameter (previously it
// pasted the literal token `DeviceType`, which is not a parameter);
// C10_ANONYMOUS_VARIABLE makes it unique within the translation unit.
#define C10_REGISTER_GUARD_IMPL(DevType, DeviceGuardImpl)      \
  static ::c10::impl::DeviceGuardImplRegistrar                 \
      C10_ANONYMOUS_VARIABLE(g_##DevType)(                     \
          ::c10::DeviceType::DevType, new DeviceGuardImpl());
/**
 * Look up the guard implementation registered for `type`.
 * Fails via AT_ASSERTM when no implementation has been registered for that
 * device type; otherwise returns the stored (non-owned) pointer.
 */
inline const DeviceGuardImplInterface* getDeviceGuardImpl(DeviceType type) {
  const size_t slot = static_cast<size_t>(type);
  const DeviceGuardImplInterface* impl = device_guard_impl_registry[slot].load();
  AT_ASSERTM(impl, "DeviceGuardImpl for ", type, " is not available");
  return impl;
}
}} // namespace c10::impl