|  | #pragma once | 
|  |  | 
|  | #include <c10/core/impl/InlineDeviceGuard.h> | 
|  |  | 
|  | namespace c10 { | 
|  |  | 
|  | /// RAII guard that sets a certain default device in its constructor, and | 
|  | /// changes it back to the device that was originally active upon destruction. | 
|  | /// | 
|  | /// The device is always reset to the one that was active at the time of | 
|  | /// construction of the guard. Even if you `set_device` after construction, the | 
|  | /// destructor will still reset the device to the one that was active at | 
|  | /// construction time. | 
|  | /// | 
|  | /// This device guard does NOT have an uninitialized state; it is guaranteed | 
|  | /// to reset a device on exit.  If you are in a situation where you *might* | 
|  | /// want to setup a guard (i.e., are looking for the moral equivalent | 
|  | /// of optional<DeviceGuard>), see OptionalDeviceGuard. | 
|  | class DeviceGuard { | 
|  | public: | 
|  | /// No default constructor; see Note [Omitted default constructor from RAII] | 
|  | explicit DeviceGuard() = delete; | 
|  |  | 
|  | /// Set the current device to the passed Device. | 
|  | explicit DeviceGuard(Device device) : guard_(device) {} | 
|  |  | 
|  | /// This constructor is for testing only. | 
|  | explicit DeviceGuard( | 
|  | Device device, | 
|  | const impl::DeviceGuardImplInterface* impl) | 
|  | : guard_(device, impl) {} | 
|  |  | 
|  | /// Copy is disallowed | 
|  | DeviceGuard(const DeviceGuard&) = delete; | 
|  | DeviceGuard& operator=(const DeviceGuard&) = delete; | 
|  |  | 
|  | /// Move is disallowed, as DeviceGuard does not have an uninitialized state, | 
|  | /// which is required for moves on types with nontrivial destructors. | 
|  | DeviceGuard(DeviceGuard&& other) = delete; | 
|  | DeviceGuard& operator=(DeviceGuard&& other) = delete; | 
|  |  | 
|  | /// Sets the device to the given one.  The specified device must be consistent | 
|  | /// with the device type originally specified during guard construction. | 
|  | /// | 
|  | /// TODO: The consistency check here is inconsistent with StreamGuard's | 
|  | /// behavior with set_stream, where a stream on a different device than | 
|  | /// the original one isn't an error; we just reset the stream and then | 
|  | /// switch devices. | 
|  | void reset_device(at::Device device) { | 
|  | guard_.reset_device(device); | 
|  | } | 
|  |  | 
|  | /// This method is for testing only. | 
|  | void reset_device( | 
|  | at::Device device, | 
|  | const impl::DeviceGuardImplInterface* impl) { | 
|  | guard_.reset_device(device, impl); | 
|  | } | 
|  |  | 
|  | /// Sets the device index to the given one.  The device type is inferred | 
|  | /// from the original device type the guard was constructed with. | 
|  | void set_index(DeviceIndex index) { | 
|  | guard_.set_index(index); | 
|  | } | 
|  |  | 
|  | /// Returns the device that was set at the time the guard was constructed. | 
|  | Device original_device() const { | 
|  | return guard_.original_device(); | 
|  | } | 
|  |  | 
|  | /// Returns the most recent device that was set using this device guard, | 
|  | /// either from construction, or via set_device. | 
|  | Device current_device() const { | 
|  | return guard_.current_device(); | 
|  | } | 
|  |  | 
|  | private: | 
|  | impl::InlineDeviceGuard<impl::VirtualGuardImpl> guard_; | 
|  | }; | 
|  |  | 
|  | /** | 
|  | * A OptionalDeviceGuard is an RAII class that sets a device to some value on | 
|  | * initialization, and resets the device to its original value on destruction. | 
|  | * Morally, a OptionalDeviceGuard is equivalent to optional<DeviceGuard>, but | 
|  | * with extra constructors and methods as appropriate. | 
|  | * | 
|  | * Besides its obvious use (optionally applying a DeviceGuard), | 
|  | * OptionalDeviceGuard is often also used for the following idiom: | 
|  | * | 
|  | *    OptionalDeviceGuard g; | 
|  | *    for (const auto& t : tensors) { | 
|  | *      g.set_device(t.device()); | 
|  | *      do_something_with(t); | 
|  | *    } | 
|  | * | 
|  | * This usage is marginally more efficient than constructing a DeviceGuard every | 
|  | * iteration of the for loop, as it avoids an unnecessary device reset. | 
|  | * | 
|  | * Unlike DeviceGuard, a OptionalDeviceGuard may be uninitialized.  This occurs | 
|  | * when you use the nullary constructor, or pass a nullopt to the constructor. | 
|  | * Uninitialized OptionalDeviceGuards do *nothing*; they do not know what the | 
|  | * original device was and they do not reset on destruction.  This is why | 
|  | * original_device() and current_device() return optional<Device> rather than | 
|  | * Device (as they do in DeviceGuard), and also is why we didn't just | 
|  | * provide OptionalDeviceGuard by default and hide DeviceGuard from users. | 
|  | * | 
|  | * The semantics of an OptionalDeviceGuard are exactly explained by thinking | 
|  | * of it as an optional<DeviceGuard>.  In particular, an initialized | 
|  | * OptionalDeviceGuard doesn't restore device to its value at construction; it | 
|  | * restores device to its value *at initialization*.  So if you have the | 
|  | * program: | 
|  | * | 
|  | *     setDevice(1); | 
|  | *     OptionalDeviceGuard g; | 
|  | *     setDevice(2); | 
|  | *     g.reset_device(Device(DeviceType::CUDA, 3));  // initializes! | 
|  | * | 
|  | * On destruction, g will reset device to 2, rather than 1. | 
|  | * | 
|  | * An uninitialized OptionalDeviceGuard is distinct from a (initialized) | 
|  | * DeviceGuard whose original_device_ and current_device_ match, since the | 
|  | * DeviceGuard will still reset the device to original_device_. | 
|  | */ | 
|  | class OptionalDeviceGuard { | 
|  | public: | 
|  | /// Create an uninitialized guard.  Set the guard later using reset_device. | 
|  | explicit OptionalDeviceGuard() : guard_() {} | 
|  |  | 
|  | /// Initialize the guard, setting the current device to the passed Device. | 
|  | explicit OptionalDeviceGuard(Device device) : guard_(device) {} | 
|  |  | 
|  | /// Initialize the guard if a Device is passed; otherwise leave the | 
|  | /// guard uninitialized. | 
|  | explicit OptionalDeviceGuard(optional<Device> device) : guard_(device) {} | 
|  |  | 
|  | /// Constructor for testing only. | 
|  | explicit OptionalDeviceGuard( | 
|  | Device device, | 
|  | const impl::DeviceGuardImplInterface* impl) | 
|  | : guard_(device, impl) {} | 
|  |  | 
|  | /// Copy is disallowed | 
|  | OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; | 
|  | OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; | 
|  |  | 
|  | /// Move is disallowed | 
|  | /// See Note [Explicit initialization of optional fields] | 
|  | /// and // Note [Move construction for RAII guards is tricky] | 
|  | /// for rationale. | 
|  | OptionalDeviceGuard(OptionalDeviceGuard&& other) = delete; | 
|  | OptionalDeviceGuard& operator=(OptionalDeviceGuard&& other) = delete; | 
|  |  | 
|  | /// Sets the device to the given one.  The specified device must be consistent | 
|  | /// with the device type originally specified during guard construction. | 
|  | void reset_device(at::Device device) { | 
|  | guard_.reset_device(device); | 
|  | } | 
|  |  | 
|  | /// For testing only | 
|  | void reset_device( | 
|  | at::Device device, | 
|  | const impl::DeviceGuardImplInterface* impl) { | 
|  | guard_.reset_device(device, impl); | 
|  | } | 
|  |  | 
|  | /// Returns the device that was set at the time the guard was constructed. | 
|  | optional<Device> original_device() const { | 
|  | return guard_.original_device(); | 
|  | } | 
|  |  | 
|  | /// Returns the most recent device that was set using this device guard, | 
|  | /// either from construction, or via reset_device. | 
|  | optional<Device> current_device() const { | 
|  | return guard_.current_device(); | 
|  | } | 
|  |  | 
|  | private: | 
|  | impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl> guard_; | 
|  | }; | 
|  |  | 
// Note [Whither the DeviceGuard boilerplate]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Design note: in principle, we could avoid these wrappers using:
//
//   using DeviceGuard = impl::InlineDeviceGuard<impl::VirtualGuardImpl>;
//   using OptionalDeviceGuard =
//       impl::InlineOptionalDeviceGuard<impl::VirtualGuardImpl>;
//
// But the error messages are worse, and our users can't just look at the
// header file to find out what's going on.  Furthermore, for specializations
// like CUDAStreamGuard, it is useful to replace some interfaces with
// refined types (e.g., return CUDAStream instead of Stream).  So, we accept
// the boilerplate and write out the API explicitly.
|  |  | 
|  | } // namespace c10 |