Improves ATen CUDAEvent (#11293)

Summary:
After PR #9726 was submitted, PR #10581 created a different CUDAEvent class. The CUDAEvent proposed in #9726 was similar to the c10d::CUDAEvent class, with additional testing and functionality. In particular, it was movable but not copyable. The CUDAEvent created by #10581 is refcounted and copyable. This PR retains the refcounting of the latter PR while fixing several bugs, adding tests, and extending the functionality to support testing and usage like in PR #8354. In particular, this PR:

- Adds set_device() to CUDAContext
- Adds three CUDAEvent tests to stream_test.cpp
- Fixes three bugs:
  - Refcounting was broken. Destroying any of the RAIIs holding a particular CUDAEvent would destroy the event UNLESS it was the last RAII (the check was backwards).
  - Moving an event would cause a segfault.
  - Events were not destroyed on the device they were created on. See PR #9415 (pietern).
- Adds the happened() and recordOnce() functions
- Changes the record() functions to not be const
- Adds additional assertions to verify correctness

This PR does not:

- Make c10d use the ATen CUDAEvent (this is appropriate for a separate PR)

Whether events should be refcounted is an interesting question. It adds some atomic operations and makes event creation eager. Making events movable but not copyable (like the c10d events) avoids these costs and allows events to be lazily constructed. Lazy construction is preferable when working with containers (like std::array or std::vector) and because the event's device can be set automatically to the first stream it's recorded on. With eager construction the user is required to understand that events have a device and acquire the device of the stream the event will be recorded on upfront. This can be seen here:

https://github.com/pytorch/pytorch/blob/542aadd9a7609892e207c1e15de08a975b697752/aten/src/ATen/native/cudnn/RNN.cpp#L1130-L1132

and that file is the only one which currently uses the ATen CUDAEvent.

Refcounting does allow single-writer, multi-reader scenarios, although these scenarios can also be supported by providing indirect access to the underlying CUDAEvent. I believe all current and planned usage scenarios do not require refcounting, and if desired I can update this PR to remove refcounting and make the ATen event movable but not copyable like the c10d event. I think not refcounting is preferable because it can improve performance, ease usability, and simplify the code (as seen with two of the above bugs).

I have decided to separate this from PR #8354 because, although PR #8354 requires these changes, they are clearly of independent interest. PR #8354 now depends on this PR. I am closing PR #9726 in favor of this PR.

apaszke ezyang pietern
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11293

Differential Revision: D9665836

Pulled By: soumith

fbshipit-source-id: a1513fa4f9761e2f304d126e402f6b6950e1c1d2
diff --git a/aten/src/ATen/cuda/CUDAContext.cpp b/aten/src/ATen/cuda/CUDAContext.cpp
index e30162d..7f934ef 100644
--- a/aten/src/ATen/cuda/CUDAContext.cpp
+++ b/aten/src/ATen/cuda/CUDAContext.cpp
@@ -16,6 +16,10 @@
   return cur_device;
 }
 
+void set_device(int64_t device) {
+  AT_CUDA_CHECK(cudaSetDevice((int)device));
+}
+
 cudaDeviceProp* getCurrentDeviceProperties() {
   return THCState_getCurrentDeviceProperties(at::globalContext().getTHCState());
 }
diff --git a/aten/src/ATen/cuda/CUDAContext.h b/aten/src/ATen/cuda/CUDAContext.h
index fee250e..3a75483 100644
--- a/aten/src/ATen/cuda/CUDAContext.h
+++ b/aten/src/ATen/cuda/CUDAContext.h
@@ -39,6 +39,8 @@
 
 AT_API int64_t current_device();
 
+AT_API void set_device(int64_t device);
+
 AT_API cudaDeviceProp* getCurrentDeviceProperties();
 
 AT_API cudaDeviceProp* getDeviceProperties(int64_t device);
diff --git a/aten/src/ATen/cuda/CUDAEvent.cpp b/aten/src/ATen/cuda/CUDAEvent.cpp
deleted file mode 100644
index ab6c842..0000000
--- a/aten/src/ATen/cuda/CUDAEvent.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-#include "ATen/cuda/CUDAEvent.h"
-#include "ATen/cuda/CUDAContext.h"
-#include "ATen/cuda/CUDAStream.h"
-#include "ATen/cuda/Exceptions.h"
-#include "ATen/core/Error.h"
-
-#include <mutex>
-#include <atomic>
-
-// Internal implementation is entirely hidden
-struct CUDAEventInternals {
-  std::atomic<int> refcount;
-  int64_t device; // Note: cudaGetDevice works with int32_t, not int64_t
-  cudaEvent_t event;
-};
-
-namespace at {
-namespace cuda {
-
-namespace detail {
-
-/*
-* Pointer-based event API
-*/
-CUDAEventInternals* CUDAEvent_create(unsigned int flags) {
-  std::unique_ptr<CUDAEventInternals> internals { new CUDAEventInternals() };
-  internals->refcount = 1;
-  internals->device = current_device();
-  AT_CUDA_CHECK(cudaEventCreateWithFlags(&internals->event, flags));
-  return internals.release();
-}
-
-void CUDAEvent_retain(CUDAEventInternals* internals) {
-  internals->refcount++;
-}
-
-void CUDAEvent_uncheckedFree(CUDAEventInternals* internals) {
-  if (--internals->refcount) {
-    cudaEventDestroy(internals->event);
-  }
-}
-cudaEvent_t CUDAEvent_event(CUDAEventInternals* internals) {
-  return internals->event;
-}
-
-int64_t CUDAEvent_device(CUDAEventInternals* internals) {
-  return internals->device;
-}
-
-void CUDAEvent_record(CUDAEventInternals* internals, const CUDAStream& stream) {
-  AT_CUDA_CHECK(cudaEventRecord(internals->event, stream));
-}
-
-} // namespace detail
-
-void CUDAEvent::record() const {
-  record(getCurrentCUDAStream());
-}
-
-void CUDAEvent::record(const CUDAStream& stream) const {
-  detail::CUDAEvent_record(internals_, stream);
-}
-
-
-} // namespace cuda
-} // namespace at
diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h
index 7a711bf..04aba27 100644
--- a/aten/src/ATen/cuda/CUDAEvent.h
+++ b/aten/src/ATen/cuda/CUDAEvent.h
@@ -1,78 +1,116 @@
 #pragma once
 
-#include <cstdint>
-#include <utility>
+#include "ATen/cuda/ATenCUDAGeneral.h"
+#include "ATen/cuda/CUDAStream.h"
+#include "ATen/cuda/CUDAContext.h"
+#include "ATen/cuda/Exceptions.h"
+#include "ATen/core/Error.h"
+#include "ATen/DeviceGuard.h"
 
 #include "cuda_runtime_api.h"
 
-#include <ATen/core/ATenGeneral.h>
-#include <ATen/Error.h>
+#include <cstdint>
+#include <utility>
+
+namespace at { namespace cuda {
 
 /*
-* A CUDA event interface with no CUDA build dependency.
+* CUDAEvents are movable, non-copyable wrappers around CUDA's events.
 *
-* Includes the CUDAEvent RAII class and a pointer-based event API.
+* CUDAEvents are constructed lazily when recorded on streams. The events
+* have a device, and this device is acquired from the first recording stream.
+* Later streams that record to the event must share this device, but streams
+* on any device can wait on the event.
 */
-
-struct CUDAEventInternals;
-
-namespace at {
-namespace cuda {
-
-struct CUDAStream;
-
-namespace detail {
-
-// Pointer-based API (for internal use)
-// Note: ATen/Context is preferred to work with streams safely
-AT_API CUDAEventInternals* CUDAEvent_create(unsigned int flags);
-AT_API void CUDAEvent_retain(CUDAEventInternals* internals);
-AT_API void CUDAEvent_uncheckedFree(CUDAEventInternals* internals);
-AT_API cudaEvent_t CUDAEvent_event(CUDAEventInternals* internals);
-AT_API int64_t CUDAEvent_device(CUDAEventInternals* internals);
-
-} // namespace detail
-
-struct CUDAEvent {
+struct AT_CUDA_API CUDAEvent {
   // Constants
   static constexpr unsigned int DEFAULT_FLAGS = cudaEventDisableTiming;
 
   // Constructors
-  CUDAEvent(unsigned int flags = DEFAULT_FLAGS)
-    : internals_(detail::CUDAEvent_create(flags)) {}
+  CUDAEvent(unsigned int flags = DEFAULT_FLAGS) 
+  : flags_{flags} { }
 
-  ~CUDAEvent() { detail::CUDAEvent_uncheckedFree(internals_); }
-
-  CUDAEvent(const CUDAEvent& other) {
-    detail::CUDAEvent_retain(other.internals_);
-    internals_ = other.internals_;
+  // Note: event destruction done on creating device to avoid creating a 
+  // CUDA context on other devices.
+  ~CUDAEvent() { 
+    try {
+      if (is_created_) {
+        at::DeviceGuard device_guard{(int)device_};
+        cudaEventDestroy(event_);
+      }
+    } catch (...) { /* No throw */ }
   }
 
-  CUDAEvent(CUDAEvent&& other) {
-    std::swap(internals_, other.internals_);
-  }
+  CUDAEvent(const CUDAEvent&) = delete;
+  CUDAEvent& operator=(const CUDAEvent&) = delete;
 
-  CUDAEvent& operator=(CUDAEvent other) noexcept {
-    std::swap(internals_, other.internals_);
+  CUDAEvent(CUDAEvent&& other) { moveHelper(std::move(other)); } 
+  CUDAEvent& operator=(CUDAEvent&& other) {
+    moveHelper(std::move(other));
     return *this;
   }
 
-  operator cudaEvent_t() const { return detail::CUDAEvent_event(internals_); }
+  operator cudaEvent_t() const { return event(); }
 
   // Less than operator (to allow use in sets)
   friend bool operator<(const CUDAEvent& left, const CUDAEvent& right) {
-    return left.internals_ < right.internals_;
+    return left.event_ < right.event_;
   }
 
-  int64_t device() const { return detail::CUDAEvent_device(internals_); }
-  cudaEvent_t event() const { return detail::CUDAEvent_event(internals_); }
-  CUDAEventInternals* internals() const { return internals_; }
+  bool isCreated() const { return is_created_; }
+  int64_t device() const { return device_; }
+  cudaEvent_t event() const { return event_; }
+  
+  bool happened() const { 
+    return (was_recorded_ && cudaEventQuery(event_) == cudaSuccess);
+  }
 
-  void record() const; // Record on the current stream
-  void record(const CUDAStream& stream) const;
+  void record() { record(getCurrentCUDAStream()); }
+  
+  void recordOnce(const CUDAStream& stream) { 
+    if (!was_recorded_) record(stream);
+  }
+  
+  void record(const CUDAStream& stream) {
+    if (is_created_) {
+      AT_ASSERT(device_ == stream.device());
+    } else {
+      create(stream.device());
+    }
+
+    AT_CUDA_CHECK(cudaEventRecord(event_, stream));
+    was_recorded_ = true;
+  }
+
+  void block (const CUDAStream& stream) {
+    if (is_created_) {
+      AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0));
+    }
+  }
+  
 
 private:
-  CUDAEventInternals* internals_;
+  unsigned int flags_ = DEFAULT_FLAGS;
+  bool is_created_ = false;
+  bool was_recorded_ = false;
+  int64_t device_ = -1;
+  cudaEvent_t event_;
+
+  void moveHelper(CUDAEvent&& other) {
+    std::swap(flags_, other.flags_);
+    std::swap(is_created_, other.is_created_);
+    std::swap(was_recorded_, other.was_recorded_);
+    std::swap(device_, other.device_);
+    std::swap(event_, other.event_);
+  }
+
+  void create(const int64_t device) {
+    at::DeviceGuard device_guard{(int)device};
+    AT_CUDA_CHECK(cudaEventCreateWithFlags(&event_, flags_));
+
+    is_created_ = true;
+    device_ = device;
+  }
 };
 
 } // namespace cuda
diff --git a/aten/src/ATen/cuda/CUDAStream.cpp b/aten/src/ATen/cuda/CUDAStream.cpp
index 2cf29c9..7a5ff19 100644
--- a/aten/src/ATen/cuda/CUDAStream.cpp
+++ b/aten/src/ATen/cuda/CUDAStream.cpp
@@ -209,7 +209,8 @@
 }
 
 void CUDAStream_synchronize_with(CUDAStreamInternals* ptr, const CUDAEvent& event) {
-    AT_CUDA_CHECK(cudaStreamWaitEvent(ptr->stream, event, 0));
+    if (event.isCreated())
+      AT_CUDA_CHECK(cudaStreamWaitEvent(ptr->stream, event, 0));
 }
 
 } // namespace detail
diff --git a/aten/src/ATen/cuda/CUDAStream.h b/aten/src/ATen/cuda/CUDAStream.h
index 6802143..cd73922 100644
--- a/aten/src/ATen/cuda/CUDAStream.h
+++ b/aten/src/ATen/cuda/CUDAStream.h
@@ -77,7 +77,7 @@
 
 // RAII for a CUDA stream
 // Allows use as a cudaStream_t, copying, moving, and metadata access.
-struct CUDAStream {
+struct AT_CUDA_API CUDAStream {
 
   // Constructors
   CUDAStream() = default;
diff --git a/aten/src/ATen/test/stream_test.cpp b/aten/src/ATen/test/stream_test.cpp
index 18212f6..145c4f4 100644
--- a/aten/src/ATen/test/stream_test.cpp
+++ b/aten/src/ATen/test/stream_test.cpp
@@ -3,6 +3,7 @@
 
 #include "ATen/cuda/CUDAContext.h"
 #include "ATen/cuda/CUDAGuard.h"
+#include "ATen/cuda/CUDAEvent.h"
 
 #include "cuda_runtime.h"
 
@@ -211,7 +212,6 @@
   REQUIRE(hasDuplicates);
 }
 
-// Note: to be expanded once CUDAEvent PR is accepted
 TEST_CASE("Multi-GPU") {
   if (at::cuda::getNumGPUs() < 2) return;
 
@@ -226,3 +226,44 @@
   at::DeviceGuard device_guard{1};
   REQUIRE(s1 == at::cuda::getCurrentCUDAStream());
 }
+
+TEST_CASE("CUDAEvent Syncs") {
+  const auto stream = at::cuda::createCUDAStream();
+  at::cuda::CUDAEvent event;
+
+  REQUIRE(!event.happened());
+
+  event.recordOnce(stream);
+
+  const auto wait_stream0 = at::cuda::createCUDAStream();
+  const auto wait_stream1 = at::cuda::createCUDAStream();
+
+  wait_stream0.synchronize_with(event);
+  wait_stream1.synchronize_with(event);
+
+  cudaStreamSynchronize(wait_stream0);
+  REQUIRE(event.happened());
+}
+
+TEST_CASE("Cross-Device Events") {
+  if (at::cuda::getNumGPUs() < 2) return;
+
+  const auto stream0 = at::cuda::createCUDAStream();
+  at::cuda::CUDAEvent event0;
+
+  at::cuda::set_device(1);
+  const auto stream1 = at::cuda::createCUDAStream();
+  at::cuda::CUDAEvent event1;
+
+  event0.record(stream0);
+  event1.record(stream1);
+  
+  event0 = std::move(event1);
+  
+  REQUIRE(event0.device() == 1);
+
+  stream0.synchronize_with(event0);
+  
+  cudaStreamSynchronize(stream0);
+  REQUIRE(event0.happened());
+}