| #pragma once |
| |
| #include <c10/core/impl/DeviceGuardImplInterface.h> |
| #include <c10/macros/Macros.h> |
| #include <c10/util/Exception.h> |
| |
| #include <c10/cuda/CUDAException.h> |
| #include <c10/cuda/CUDAStream.h> |
| #include <c10/cuda/CUDAFunctions.h> |
| |
| #include <cuda_runtime_api.h> |
| |
| namespace c10 { |
| namespace cuda { |
| namespace impl { |
| |
| struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { |
| static constexpr DeviceType static_type = DeviceType::CUDA; |
| |
| CUDAGuardImpl() {} |
| explicit CUDAGuardImpl(DeviceType t) { |
| TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA); |
| } |
| DeviceType type() const override { |
| return DeviceType::CUDA; |
| } |
| Device exchangeDevice(Device d) const override { |
| TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA); |
| Device old_device = getDevice(); |
| if (old_device.index() != d.index()) { |
| C10_CUDA_CHECK(cudaSetDevice(d.index())); |
| } |
| return old_device; |
| } |
| Device getDevice() const override { |
| int device; |
| C10_CUDA_CHECK(cudaGetDevice(&device)); |
| return Device(DeviceType::CUDA, device); |
| } |
| void setDevice(Device d) const override { |
| TORCH_INTERNAL_ASSERT(d.type() == DeviceType::CUDA); |
| C10_CUDA_CHECK(cudaSetDevice(d.index())); |
| } |
| void uncheckedSetDevice(Device d) const noexcept override { |
| C10_CUDA_CHECK_WARN(cudaSetDevice(d.index())); |
| } |
| Stream getStream(Device d) const noexcept override { |
| return getCurrentCUDAStream(d.index()).unwrap(); |
| } |
| Stream getDefaultStream(Device d) const override { |
| return getDefaultCUDAStream(d.index()); |
| } |
| // NB: These do NOT set the current device |
| Stream exchangeStream(Stream s) const noexcept override { |
| CUDAStream cs(s); |
| auto old_stream = getCurrentCUDAStream(s.device().index()); |
| setCurrentCUDAStream(cs); |
| return old_stream.unwrap(); |
| } |
| DeviceIndex deviceCount() const noexcept override { |
| return device_count(); |
| } |
| |
| // Event-related functions |
| void createEvent( |
| cudaEvent_t* cuda_event, |
| const EventFlag flag) const { |
| // Maps PyTorch's Event::Flag to CUDA flag |
| auto cuda_flag = cudaEventDefault; |
| switch (flag) { |
| case EventFlag::PYTORCH_DEFAULT: |
| case EventFlag::CUDA_EVENT_DISABLE_TIMING: |
| cuda_flag = cudaEventDisableTiming; |
| break; |
| case EventFlag::BACKEND_DEFAULT: |
| case EventFlag::CUDA_EVENT_DEFAULT: |
| cuda_flag = cudaEventDefault; |
| break; |
| default: |
| TORCH_CHECK(false, "CUDA event received unknown flag"); |
| } |
| |
| C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); |
| } |
| |
| void destroyEvent( |
| void* event, |
| const DeviceIndex device_index) const noexcept override { |
| if (!event) return; |
| auto cuda_event = static_cast<cudaEvent_t>(event); |
| int orig_device; |
| C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device)); |
| C10_CUDA_CHECK_WARN(cudaSetDevice(device_index)); |
| C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); |
| C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device)); |
| } |
| |
| void record( |
| void** event, |
| const Stream& stream, |
| const DeviceIndex device_index, |
| const EventFlag flag) const override { |
| TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), |
| "Event device index ", |
| device_index, |
| " does not match recording stream's device index ", |
| stream.device_index(), |
| "."); |
| |
| cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event); |
| CUDAStream cuda_stream{stream}; |
| |
| // Moves to stream's device to record |
| const auto orig_device = getDevice(); |
| setDevice(stream.device()); |
| |
| // Creates the event (lazily) |
| if (!cuda_event) createEvent(&cuda_event, flag); |
| C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); |
| // Makes the void* point to the (possibly just allocated) CUDA event |
| *event = cuda_event; |
| |
| // Resets device |
| setDevice(orig_device); |
| } |
| |
| void block( |
| void* event, |
| const Stream& stream) const override { |
| if (!event) return; |
| cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); |
| CUDAStream cuda_stream{stream}; |
| const auto orig_device = getDevice(); |
| setDevice(stream.device()); |
| C10_CUDA_CHECK(cudaStreamWaitEvent( |
| cuda_stream, |
| cuda_event, |
| /*flags (must be zero)=*/ 0)); |
| setDevice(orig_device); |
| } |
| |
| // May be called from any device |
| bool queryEvent(void* event) const override { |
| if (!event) return true; |
| cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); |
| const cudaError_t err = cudaEventQuery(cuda_event); |
| if (err != cudaErrorNotReady) { |
| C10_CUDA_CHECK(err); |
| } |
| return (err == cudaSuccess); |
| } |
| }; |
| |
| }}} // namespace c10::cuda::impl |