Build device generic torch.Stream and torch.Event based on c10::Stream/Event (#123611)

This diff adds device-generic torch.Stream and torch.Event classes, built on top of c10::Stream/c10::Event, for newly added accelerators in PyTorch.
------------
**torch.Stream APIs**
```
# Defined in torch/csrc/Stream.cpp
class Stream(_StreamBase):
    stream_id: _int  # Stream id
    device_index: _int
    device_type: _int

    device: _device  # The device of the stream

    @overload
    def __new__(self, device: Optional[DeviceLikeType] = None, *, priority: _int = 0) -> Stream: ...
    @overload
    def __new__(self, stream_id: _int, device_index: _int, device_type: _int, *, priority: _int = 0) -> Stream: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    def wait_event(self, event: Event) -> None: ...
    def wait_stream(self, other: Stream) -> None: ...
    def record_event(self, event: Optional[Event] = None) -> Event: ...
    def __hash__(self) -> _int: ...
    def __repr__(self) -> str: ...
    def __eq__(self, other: object) -> _bool: ...
```
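
For reference, a minimal usage sketch of the device-generic stream API (not part of this diff; it assumes a backend, CUDA here, that overrides `getNewStream`):
```
import torch

# With no device argument the current accelerator is used, falling back to CPU;
# here a device is picked explicitly for illustration.
s = torch.Stream("cuda", priority=0)

ev = s.record_event()     # records and returns a torch.Event on this stream
s.wait_event(ev)          # make the stream wait on an event (here trivially its own)
s.synchronize()           # block the host until all work enqueued on the stream is done
assert s.query()          # True once all submitted work has completed

# Streams wrapping the same (stream_id, device_index, device_type) compare equal.
s2 = torch.Stream(stream_id=s.stream_id,
                  device_index=s.device_index,
                  device_type=s.device_type)
assert s == s2 and hash(s) == hash(s2)
```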
------------------
**torch.Event APIs**:
- IPC-related APIs are not implemented, since many device backends do not support them, but the interfaces are kept in place so torch.cuda.Event can adopt them in the future.
- Currently only enable_timing is supported, since it is the most commonly used flag across device backends. Supporting additional flags will require refactoring the event flag system in PyTorch.
- An elapsedTime API is added to c10::Event.

```
# Defined in torch/csrc/Event.cpp
class Event(_EventBase):

    device: _device  # The device of the Event
    event_id: _int  # The raw event handle created by the device backend

    def __new__(self,
        device: Optional[DeviceLikeType] = None,
        *,
        enable_timing: _bool = False,
        blocking: _bool = False,
        interprocess: _bool = False) -> Event: ...
    @classmethod
    def from_ipc_handle(self, device: DeviceLikeType, ipc_handle: bytes) -> Event: ...
    def record(self, stream: Optional[Stream] = None) -> None: ...
    def wait(self, stream: Optional[Stream] = None) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: Event) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...
    def __repr__(self) -> str: ...
```
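
A hedged timing sketch with the generic event API: this diff only adds the interface hooks, so the example assumes a backend whose DeviceGuardImpl overrides the new `elapsedTime`/`synchronizeEvent` methods and honors `enable_timing`; "cuda" is used as a placeholder device.
```
import torch

s = torch.Stream("cuda")                  # any accelerator backend; "cuda" is only an example
start = torch.Event(enable_timing=True)   # timing flag is required for elapsed_time
end = torch.Event(enable_timing=True)

start.record(s)     # record on an explicit stream (defaults to the current stream if omitted)
# ... enqueue work on s here ...
end.record(s)

end.synchronize()   # block the host until the event completes (needs synchronizeEvent)
assert end.query()
ms = start.elapsed_time(end)   # time between the two recorded events (needs elapsedTime)
```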

-----------

c10::Event provides new APIs to:
- calculate the **elapsedTime** between two events,
- get the raw event id,
- synchronize an event.

```
  double elapsedTime(const Event& event) const {
    return impl_.elapsedTime(event.impl_);
  }

  void* eventId() const {
    return impl_.eventId();
  }

  void synchronize() const {
    return impl_.synchronize();
  }
```
----------
TODO: we still need to find a good way to test these APIs in PyTorch with API mocks.

Differential Revision: [D55351839](https://our.internmc.facebook.com/intern/diff/D55351839/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123611
Approved by: https://github.com/albanD
diff --git a/build_variables.bzl b/build_variables.bzl
index 5716230..71b5c5e 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -795,6 +795,7 @@
     "torch/csrc/StorageMethods.cpp",
     "torch/csrc/StorageSharing.cpp",
     "torch/csrc/Stream.cpp",
+    "torch/csrc/Event.cpp",
     "torch/csrc/TypeInfo.cpp",
     "torch/csrc/api/src/python/init.cpp",
     "torch/csrc/autograd/functions/init.cpp",
diff --git a/c10/core/Event.h b/c10/core/Event.h
index 2cbaf18..b94db9f 100644
--- a/c10/core/Event.h
+++ b/c10/core/Event.h
@@ -118,6 +118,18 @@
     return impl_.query();
   }
 
+  double elapsedTime(const Event& event) const {
+    return impl_.elapsedTime(event.impl_);
+  }
+
+  void* eventId() const {
+    return impl_.eventId();
+  }
+
+  void synchronize() const {
+    return impl_.synchronize();
+  }
+
  private:
   impl::InlineEvent<impl::VirtualGuardImpl> impl_;
 };
diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h
index 1b168f7..59210a9 100644
--- a/c10/core/impl/DeviceGuardImplInterface.h
+++ b/c10/core/impl/DeviceGuardImplInterface.h
@@ -123,6 +123,16 @@
   }
 
   /**
+   * Return a new stream for a given device and priority. The stream will be
+   * copied and shared around, device backend should be able to correctly handle
+   * the lifetime of the stream.
+   */
+  virtual Stream getNewStream(Device, int priority = 0) const {
+    (void)priority;
+    TORCH_CHECK(false, "Backend doesn't support create a new Stream.")
+  }
+
+  /**
    * Set a stream to be the thread local current stream for its device.
    * Return the previous stream for that device. You are NOT required
    * to set the current device to match the device of this stream.
@@ -195,6 +205,14 @@
   }
 
   /**
+   * Wait (by blocking the calling thread) until all the work previously
+   * recorded on the event has completed running on the device.
+   */
+  virtual void synchronizeEvent(void* /*event*/) const {
+    TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
+  }
+
+  /**
    * Ensure the caching allocator (if any) is aware that the given DataPtr is
    * being used on the given stream, and that it should thus avoid recycling the
    * DataPtr until all work on that stream is done.
@@ -203,6 +221,13 @@
   }
 
   /**
+   * Fetch the elapsed time between two recorded events.
+   */
+  virtual double elapsedTime(void* /*event1*/, void* /*event2*/) const {
+    TORCH_CHECK(false, "Backend doesn't support elapsedTime.");
+  }
+
+  /**
    * Intended use of this class is to leak the DeviceGuardImpl at program end.
    * So you better not call the destructor, buster!
    */
@@ -234,6 +259,13 @@
     // no-op
     return Stream(Stream::DEFAULT, Device(D, -1));
   }
+
+  Stream getNewStream(Device, int priority = 0) const override {
+    // no-op
+    (void)priority;
+    return Stream(Stream::DEFAULT, Device(D, -1));
+  }
+
   // NB: These do NOT set the current device
   Stream exchangeStream(Stream) const noexcept override {
     // no-op
diff --git a/c10/core/impl/InlineEvent.h b/c10/core/impl/InlineEvent.h
index ef1e2c6..3485da3 100644
--- a/c10/core/impl/InlineEvent.h
+++ b/c10/core/impl/InlineEvent.h
@@ -101,6 +101,32 @@
     return backend_.queryEvent(event_);
   }
 
+  void* eventId() const {
+    return event_;
+  }
+
+  double elapsedTime(const InlineEvent& other) const {
+    TORCH_CHECK(
+        other.was_marked_for_recording(),
+        "other was not marked for recording.");
+    TORCH_CHECK(
+        was_marked_for_recording(), "self was not marked for recording.");
+    TORCH_CHECK(
+        other.device_type() == device_type_,
+        "Event device type ",
+        DeviceTypeName(device_type_),
+        " does not match other's device type ",
+        DeviceTypeName(other.device_type()),
+        ".");
+    return backend_.elapsedTime(event_, other.event_);
+  }
+
+  void synchronize() const {
+    if (!was_marked_for_recording_)
+      return;
+    backend_.synchronizeEvent(event_);
+  }
+
  private:
   void* event_ = nullptr;
   T backend_;
diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h
index ce32411..2065150 100644
--- a/c10/core/impl/VirtualGuardImpl.h
+++ b/c10/core/impl/VirtualGuardImpl.h
@@ -39,6 +39,9 @@
   Stream getStream(Device d) const noexcept override {
     return impl_->getStream(d);
   }
+  Stream getNewStream(Device d, int priority = 0) const override {
+    return impl_->getNewStream(d, priority);
+  }
   Stream getDefaultStream(Device d) const override {
     return impl_->getDefaultStream(d);
   }
@@ -84,6 +87,14 @@
     impl_->recordDataPtrOnStream(data_ptr, stream);
   }
 
+  double elapsedTime(void* event1, void* event2) const override {
+    return impl_->elapsedTime(event1, event2);
+  }
+
+  void synchronizeEvent(void* event) const override {
+    return impl_->synchronizeEvent(event);
+  }
+
  private:
   const DeviceGuardImplInterface* impl_ = nullptr;
 };
diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h
index 7c0ea21..2d983be 100644
--- a/c10/cuda/impl/CUDAGuardImpl.h
+++ b/c10/cuda/impl/CUDAGuardImpl.h
@@ -62,6 +62,9 @@
   Stream getDefaultStream(Device d) const override {
     return getDefaultCUDAStream(d.index());
   }
+  Stream getNewStream(Device d, int priority = 0) const override {
+    return getStreamFromPool(priority, d.index());
+  }
   Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
       const override {
     return getStreamFromPool(isHighPriority, d.index());
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 65aa339..18e373e 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -228,6 +228,7 @@
             "StaticModule",
             "Stream",
             "StreamObjType",
+            "Event",
             "StringType",
             "SUM",
             "SymFloat",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 1c6e40b..882055e 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -112,7 +112,44 @@
     device_index: _int
     device_type: _int
 
-    device: device  # The device of the stream
+    device: _device  # The device of the stream
+
+    @overload
+    def __new__(self, device: Optional[DeviceLikeType] = None, *, priority: _int = 0) -> Stream: ...
+    @overload
+    def __new__(self, stream_id: _int, device_index: _int, device_type: _int, *, priority: _int = 0) -> Stream: ...
+    def query(self) -> _bool: ...
+    def synchronize(self) -> None: ...
+    def wait_event(self, event: Event) -> None: ...
+    def wait_stream(self, other: Stream) -> None: ...
+    def record_event(self, event: Optional[Event] = None) -> Event: ...
+    def __hash__(self) -> _int: ...
+    def __repr__(self) -> str: ...
+    def __eq__(self, other: object) -> _bool: ...
+
+
+# Defined in torch/csrc/Event.cpp
+class Event:
+
+    device: _device  # The device of the Event
+    event_id: _int # The raw event created by device backend
+
+    def __new__(self,
+        device: Optional[DeviceLikeType] = None,
+        *,
+        enable_timing: _bool = False,
+        blocking: _bool = False,
+        interprocess: _bool = False) -> Event: ...
+    @classmethod
+    def from_ipc_handle(self, device: _device, ipc_handle: bytes) -> Event: ...
+    def record(self, stream: Optional[Stream] = None) -> None: ...
+    def wait(self, stream: Optional[Stream] = None) -> None: ...
+    def query(self) -> _bool: ...
+    def elapsed_time(self, other: Event) -> _float: ...
+    def synchronize(self) -> None: ...
+    def ipc_handle(self) -> bytes: ...
+    def __repr__(self) -> str: ...
+
 
 # Defined in torch/csrc/Size.cpp
 class Size(Tuple[_int, ...]):
diff --git a/torch/csrc/Event.cpp b/torch/csrc/Event.cpp
new file mode 100644
index 0000000..b8cf8b2
--- /dev/null
+++ b/torch/csrc/Event.cpp
@@ -0,0 +1,328 @@
+#include <pybind11/pybind11.h>
+#include <torch/csrc/Device.h>
+#include <torch/csrc/Event.h>
+#include <torch/csrc/Stream.h>
+#include <torch/csrc/THP.h>
+#include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/pycfunction_helpers.h>
+#include <torch/csrc/utils/python_arg_parser.h>
+
+#include <c10/core/Event.h>
+#include <c10/core/Stream.h>
+
+#include <c10/core/DeviceType.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <structmember.h>
+#include <string>
+
+PyObject* THPEventClass = nullptr;
+
+static PyObject* THPEvent_pynew(
+    PyTypeObject* type,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+
+  unsigned char enable_timing = 0;
+  unsigned char blocking = 0;
+  unsigned char interprocess = 0;
+
+  static torch::PythonArgParser parser({
+      "Event(Device device=None, *, bool enable_timing=True, bool blocking=False, bool interprocess=False)",
+  });
+
+  torch::ParsedArgs<4> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+
+  auto device = r.deviceOptional(0);
+
+  if (!device.has_value()) {
+    device = at::Device(at::getAccelerator(false).value_or(at::kCPU));
+  }
+  enable_timing = r.toBoolWithDefault(1, true);
+  blocking = r.toBoolWithDefault(2, false);
+  interprocess = r.toBoolWithDefault(3, false);
+
+  THPObjectPtr ptr(type->tp_alloc(type, 0));
+  if (!ptr) {
+    TORCH_CHECK(ptr, "Failed to allocate memory for Event");
+  }
+
+  THPEvent* self = (THPEvent*)ptr.get();
+
+  // TODO: blocking and interprocess are not supported yet. To support them, the
+  // flag system of c10::Event needs to be refactored. C10::Event should also
+  // provide a generic constructor to support blocking and interprocess events.
+  (void)blocking;
+  (void)interprocess;
+
+  new (&self->event) c10::Event(
+      device->type(),
+      (enable_timing ? c10::EventFlag::PYTORCH_DEFAULT
+                     : c10::EventFlag::BACKEND_DEFAULT));
+
+  return (PyObject*)ptr.release();
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) {
+  auto type = (PyTypeObject*)&THPEventType;
+  auto self = THPObjectPtr{type->tp_alloc(type, 0)};
+  TORCH_CHECK(self, "Failed to allocate memory for Event");
+  auto self_ = reinterpret_cast<THPEvent*>(self.get());
+  new (&self_->event) c10::Event(device_type, flag);
+  return self.release();
+}
+
+static void THPEvent_dealloc(THPEvent* self) {
+  {
+    pybind11::gil_scoped_release no_gil{};
+    self->event.~Event();
+  }
+  Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+static PyObject* THPEvent_get_device(THPEvent* self, void* unused) {
+  HANDLE_TH_ERRORS
+  at::optional<at::Device> device = self->event.device();
+  if (!device) {
+    Py_RETURN_NONE;
+  }
+  return THPDevice_New(device.value());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_record(
+    PyObject* _self,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  auto self = (THPEvent*)_self;
+  PyObject* _stream = Py_None;
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+  constexpr const char* accepted_args[] = {"stream", nullptr};
+  if (!PyArg_ParseTupleAndKeywords(
+          args,
+          kwargs,
+          "|O",
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+          const_cast<char**>(accepted_args),
+          &_stream)) {
+    TORCH_WARN("Parsing THPEvent_record arg fails");
+    return nullptr;
+  }
+  if (_stream != Py_None) {
+    auto stream = (THPStream*)_stream;
+    self->event.record(c10::Stream::unpack3(
+        stream->stream_id,
+        stream->device_index,
+        static_cast<c10::DeviceType>(stream->device_type)));
+  } else {
+    c10::impl::VirtualGuardImpl impl{
+        static_cast<c10::DeviceType>(self->event.device_type())};
+    self->event.record(impl.getStream(impl.getDevice()));
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_from_ipc_handle(
+    PyObject* _type,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  auto type = (PyTypeObject*)_type;
+
+  static torch::PythonArgParser parser({
+      "from_ipc_handle(Device device, std::string ipc_handle)",
+  });
+  torch::ParsedArgs<2> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+
+  at::Device device = r.device(0);
+  std::string handle_string = r.string(1);
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "torch.Event ipc is not supported yet, please open an issue if you need this!");
+  THPObjectPtr ptr(type->tp_alloc(type, 0));
+  if (!ptr) {
+    return nullptr;
+  }
+  THPEvent* self = (THPEvent*)ptr.get();
+
+  // TODO: for constructing event from ipc handle, the c10::Event needs to have
+  // more general constructor to achieve that.
+  new (&self->event) c10::Event(device.type(), c10::EventFlag::PYTORCH_DEFAULT);
+
+  return (PyObject*)ptr.release();
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_ipc_handle(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  auto self = (THPEvent*)_self;
+  (void)self;
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "torch.Event ipc is not supported yet, please open an issue if you need this!");
+  std::string handle = "0";
+  return PyBytes_FromStringAndSize((const char*)&handle, sizeof(handle));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_wait(
+    PyObject* _self,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS {
+    auto self = (THPEvent*)_self;
+    PyObject* _stream = Py_None;
+    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+    constexpr const char* accepted_args[] = {"stream", nullptr};
+    if (!PyArg_ParseTupleAndKeywords(
+            args,
+            kwargs,
+            "|O",
+            // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+            const_cast<char**>(accepted_args),
+            &_stream)) {
+      TORCH_WARN("Parsing THPEvent_wait arg fails");
+      return nullptr;
+    }
+    if (_stream != Py_None) {
+      auto stream = (THPStream*)_stream;
+      self->event.block(c10::Stream::unpack3(
+          stream->stream_id,
+          stream->device_index,
+          static_cast<c10::DeviceType>(stream->device_type)));
+    } else {
+      c10::impl::VirtualGuardImpl impl{
+          static_cast<c10::DeviceType>(self->event.device_type())};
+      self->event.block(impl.getStream(impl.getDevice()));
+    }
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_query(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  auto self = (THPEvent*)_self;
+  return PyBool_FromLong(self->event.query());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_elapsed_time(PyObject* _self, PyObject* _other) {
+  HANDLE_TH_ERRORS
+  auto self = (THPEvent*)_self;
+  auto other = (THPEvent*)_other;
+  return PyFloat_FromDouble(self->event.elapsedTime(other->event));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_synchronize(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS {
+    pybind11::gil_scoped_release no_gil{};
+    auto self = (THPEvent*)_self;
+    self->event.synchronize();
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_evend_id(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  auto self = (THPEvent*)_self;
+  return PyLong_FromVoidPtr(self->event.eventId());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPEvent_repr(THPEvent* self) {
+  HANDLE_TH_ERRORS
+  return THPUtils_packString(
+      "torch.Event device_type=" +
+      c10::DeviceTypeName(
+          static_cast<c10::DeviceType>(self->event.device_type()), true) +
+      ", device_index=" + std::to_string(self->event.device_index()) +
+      ", event_flag=" +
+      std::to_string(static_cast<int64_t>(self->event.flag())) + ", event_id=" +
+      std::to_string(reinterpret_cast<int64_t>(self->event.eventId())));
+  END_HANDLE_TH_ERRORS
+}
+
+// NOLINTNEXTLINE(*c-arrays*, *global-variables)
+static struct PyGetSetDef THPEvent_properties[] = {
+    {"device", (getter)THPEvent_get_device, nullptr, nullptr, nullptr},
+    {"event_id", (getter)THPEvent_evend_id, nullptr, nullptr, nullptr},
+    {nullptr}};
+
+// NOLINTNEXTLINE(*c-arrays*, *global-variables)
+static PyMethodDef THPEvent_methods[] = {
+    {(char*)"from_ipc_handle",
+     castPyCFunctionWithKeywords(THPEvent_from_ipc_handle),
+     METH_CLASS | METH_VARARGS | METH_KEYWORDS,
+     nullptr},
+    {(char*)"record",
+     castPyCFunctionWithKeywords(THPEvent_record),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
+    {(char*)"wait",
+     castPyCFunctionWithKeywords(THPEvent_wait),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
+    {(char*)"query", THPEvent_query, METH_NOARGS, nullptr},
+    {(char*)"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr},
+    {(char*)"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
+    {(char*)"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
+    {nullptr}};
+
+PyTypeObject THPEventType = {
+    PyVarObject_HEAD_INIT(nullptr, 0) "torch.Event", /* tp_name */
+    sizeof(THPEvent), /* tp_basicsize */
+    0, /* tp_itemsize */
+    (destructor)THPEvent_dealloc, /* tp_dealloc */
+    0, /* tp_vectorcall_offset */
+    nullptr, /* tp_getattr */
+    nullptr, /* tp_setattr */
+    nullptr, /* tp_reserved */
+    (reprfunc)THPEvent_repr, /* tp_repr */
+    nullptr, /* tp_as_number */
+    nullptr, /* tp_as_sequence */
+    nullptr, /* tp_as_mapping */
+    nullptr, /* tp_hash  */
+    nullptr, /* tp_call */
+    nullptr, /* tp_str */
+    nullptr, /* tp_getattro */
+    nullptr, /* tp_setattro */
+    nullptr, /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
+    nullptr, /* tp_doc */
+    nullptr, /* tp_traverse */
+    nullptr, /* tp_clear */
+    nullptr, /* tp_richcompare */
+    0, /* tp_weaklistoffset */
+    nullptr, /* tp_iter */
+    nullptr, /* tp_iternext */
+    THPEvent_methods, /* tp_methods */
+    nullptr, /* tp_members */
+    THPEvent_properties, /* tp_getset */
+    nullptr, /* tp_base */
+    nullptr, /* tp_dict */
+    nullptr, /* tp_descr_get */
+    nullptr, /* tp_descr_set */
+    0, /* tp_dictoffset */
+    nullptr, /* tp_init */
+    nullptr, /* tp_alloc */
+    THPEvent_pynew, /* tp_new */
+};
+
+void THPEvent_init(PyObject* module) {
+  THPEventClass = (PyObject*)&THPEventType;
+  if (PyType_Ready(&THPEventType) < 0) {
+    throw python_error();
+  }
+  Py_INCREF(&THPEventType);
+  if (PyModule_AddObject(module, "Event", (PyObject*)&THPEventType) < 0) {
+    throw python_error();
+  }
+}
diff --git a/torch/csrc/Event.h b/torch/csrc/Event.h
new file mode 100644
index 0000000..745610d
--- /dev/null
+++ b/torch/csrc/Event.h
@@ -0,0 +1,21 @@
+#ifndef THP_EVENT_INC
+#define THP_EVENT_INC
+
+#include <c10/core/Event.h>
+#include <torch/csrc/python_headers.h>
+
+struct TORCH_API THPEvent {
+  PyObject_HEAD c10::Event event;
+};
+extern PyObject* THPEventClass;
+TORCH_API extern PyTypeObject THPEventType;
+
+TORCH_API void THPEvent_init(PyObject* module);
+TORCH_API PyObject* THPEvent_new(
+    c10::DeviceType device_type,
+    c10::EventFlag flag);
+inline bool THPEvent_Check(PyObject* obj) {
+  return THPEventClass && PyObject_IsInstance(obj, THPEventClass);
+}
+
+#endif // THP_EVENT_INC
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 9343a48..8aff730 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -39,6 +39,7 @@
 #include <torch/csrc/Device.h>
 #include <torch/csrc/Dtype.h>
 #include <torch/csrc/DynamicTypes.h>
+#include <torch/csrc/Event.h>
 #include <torch/csrc/Generator.h>
 #include <torch/csrc/Layout.h>
 #include <torch/csrc/MemoryFormat.h>
@@ -1603,6 +1604,7 @@
   THPQScheme_init(module);
   THPDevice_init(module);
   THPStream_init(module);
+  THPEvent_init(module);
   ASSERT_TRUE(THPVariable_initModule(module));
   ASSERT_TRUE(THPFunction_initModule(module));
   ASSERT_TRUE(THPEngine_initModule(module));
diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp
index bd8abb0..06dac51 100644
--- a/torch/csrc/Stream.cpp
+++ b/torch/csrc/Stream.cpp
@@ -1,10 +1,19 @@
 #include <pybind11/pybind11.h>
 #include <torch/csrc/Device.h>
+#include <torch/csrc/Event.h>
+#include <torch/csrc/Stream.h>
 #include <torch/csrc/THP.h>
 #include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/pycfunction_helpers.h>
 #include <torch/csrc/utils/python_arg_parser.h>
 
+#include <c10/core/DeviceGuard.h>
+#include <c10/core/Stream.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <c10/util/Exception.h>
+#include <c10/util/hash.h>
 #include <structmember.h>
+#include <cstdint>
 
 PyTypeObject* THPStreamClass = nullptr;
 
@@ -13,22 +22,53 @@
     PyObject* args,
     PyObject* kwargs) {
   HANDLE_TH_ERRORS
-  int64_t stream_id = 0;
-  int64_t device_index = 0;
+
+  int64_t stream_id = -1;
   int64_t device_type = 0;
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-  constexpr const char* kwlist[] = {
-      "stream_id", "device_index", "device_type", nullptr};
-  if (!PyArg_ParseTupleAndKeywords(
-          args,
-          kwargs,
-          "|LLL",
-          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-          const_cast<char**>(kwlist),
-          &stream_id,
-          &device_index,
-          &device_type)) {
-    return nullptr;
+  int64_t device_index = 0;
+  int64_t priority = 0;
+
+  static torch::PythonArgParser parser({
+      "Steram(Device device=None, *, int64_t priority=0)",
+      "Stream(int64_t stream_id, int64_t device_index, int64_t device_type, *, int64_t priority=0)",
+  });
+
+  torch::ParsedArgs<4> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+
+  std::unique_ptr<c10::DeviceGuard> device_guard_ptr;
+
+  if (r.idx == 0) {
+    auto default_accelerator = at::getAccelerator(false);
+    auto device = r.deviceOptional(0);
+    if (device.has_value()) {
+      device_type = static_cast<int64_t>(device->type());
+      device_index = static_cast<int64_t>(device->index());
+      // Initialize device guard if device is not None.
+      device_guard_ptr = std::make_unique<c10::DeviceGuard>(device.value());
+    } else {
+      // If device is None, we will use the current accelerator and index.
+      // If the current accelerator is not set, we will use the CPU as device
+      // type.
+      device_type = static_cast<int64_t>(
+          default_accelerator.value_or(c10::DeviceType::CPU));
+      c10::impl::VirtualGuardImpl impl{
+          static_cast<c10::DeviceType>(device_type)};
+      const auto current_device = impl.getDevice();
+      device_index = current_device.index();
+    }
+    priority = r.toInt64WithDefault(1, 0);
+  } else if (r.idx == 1) {
+    stream_id = r.toInt64WithDefault(0, -1);
+    device_index = r.toInt64WithDefault(1, 0);
+    device_type =
+        r.toInt64WithDefault(2, static_cast<int64_t>(c10::DeviceType::CPU));
+    priority = r.toInt64WithDefault(3, 0);
+  } else {
+    TORCH_CHECK(
+        false,
+        "parse stream arg fails please check the usage: ",
+        parser.get_signatures());
   }
 
   THPObjectPtr ptr(type->tp_alloc(type, 0));
@@ -37,9 +77,29 @@
   }
 
   THPStream* self = (THPStream*)ptr.get();
-  self->stream_id = stream_id;
-  self->device_index = device_index;
-  self->device_type = device_type;
+
+  // If torch.Stream is not created from existing Stream, then create a new one.
+  // It requires other device backends override getNewStream method. How the new
+  // stream is created is backend specific. Backend should be able to correctly
+  // manage the lifetime of streams.
+  c10::optional<c10::Stream> stream_opt;
+  if (r.idx == 0) {
+    c10::impl::VirtualGuardImpl impl{static_cast<c10::DeviceType>(device_type)};
+    stream_opt = impl.getNewStream(
+        c10::Device(static_cast<c10::DeviceType>(device_type), device_index),
+        static_cast<int>(priority));
+  } else {
+    stream_opt = c10::Stream::unpack3(
+        stream_id,
+        static_cast<c10::DeviceIndex>(device_index),
+        static_cast<c10::DeviceType>(device_type));
+  }
+
+  TORCH_CHECK(stream_opt.has_value(), "Failed to create stream");
+  self->stream_id = static_cast<int64_t>(stream_opt->id());
+  self->device_index = static_cast<int64_t>(stream_opt->device_index());
+  self->device_type = static_cast<int64_t>(stream_opt->device_type());
+
   return (PyObject*)ptr.release();
   END_HANDLE_TH_ERRORS
 }
@@ -73,15 +133,167 @@
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  auto self = (THPStream*)_self;
+
+  return PyBool_FromLong(c10::Stream::unpack3(
+                             self->stream_id,
+                             self->device_index,
+                             static_cast<c10::DeviceType>(self->device_type))
+                             .query());
+
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS {
+    pybind11::gil_scoped_release no_gil;
+    auto self = (THPStream*)_self;
+
+    c10::Stream::unpack3(
+        self->stream_id,
+        self->device_index,
+        static_cast<c10::DeviceType>(self->device_type))
+        .synchronize();
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {
+  HANDLE_TH_ERRORS {
+    auto self = (THPStream*)_self;
+    auto event = (THPEvent*)_event;
+    c10::Stream::unpack3(
+        self->stream_id,
+        self->device_index,
+        static_cast<c10::DeviceType>(self->device_type))
+        .wait(event->event);
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_wait_stream(PyObject* _self, PyObject* _other) {
+  HANDLE_TH_ERRORS {
+    auto self = (THPStream*)_self;
+    auto other_stream = (THPStream*)_other;
+    c10::Event new_event(
+        static_cast<c10::DeviceType>(other_stream->device_type),
+        c10::EventFlag::PYTORCH_DEFAULT);
+    new_event.record(c10::Stream::unpack3(
+        other_stream->stream_id,
+        other_stream->device_index,
+        static_cast<c10::DeviceType>(other_stream->device_type)));
+    c10::Stream::unpack3(
+        self->stream_id,
+        self->device_index,
+        static_cast<c10::DeviceType>(self->device_type))
+        .wait(new_event);
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_record_event(
+    PyObject* _self,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  auto self = (THPStream*)_self;
+  PyObject* _new_event;
+  PyObject* _event = Py_None;
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
+  constexpr const char* accepted_args[] = {"event", nullptr};
+  if (!PyArg_ParseTupleAndKeywords(
+          args,
+          kwargs,
+          "|O",
+          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
+          const_cast<char**>(accepted_args),
+          &_event)) {
+    TORCH_CHECK(false, "parse record_event arg fails");
+  }
+  if (_event != Py_None) {
+    // Increase the refcount of the event to avoid it being destroyed.
+    Py_INCREF(_event);
+    _new_event = _event;
+  } else {
+    _new_event = THPEvent_new(
+        static_cast<c10::DeviceType>(self->device_type),
+        c10::EventFlag::PYTORCH_DEFAULT);
+  }
+  auto new_event = (THPEvent*)_new_event;
+  TORCH_CHECK(new_event, "event must not be null");
+  new_event->event.record(c10::Stream::unpack3(
+      self->stream_id,
+      self->device_index,
+      static_cast<c10::DeviceType>(self->device_type)));
+  return (PyObject*)new_event;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_repr(THPStream* self) {
+  HANDLE_TH_ERRORS
+  return THPUtils_packString(
+      "torch.Stream device_type=" +
+      c10::DeviceTypeName(
+          static_cast<c10::DeviceType>(self->device_type), true) +
+      ", device_index=" + std::to_string(self->device_index) +
+      ", stream_id=" + std::to_string(self->stream_id));
+  END_HANDLE_TH_ERRORS
+}
+
+static Py_hash_t THPStream_hash(THPStream* self) {
+  return static_cast<long>(at::hash_combine(
+      self->device_type,
+      (at::hash_combine(self->stream_id, self->device_index))));
+}
+
 static PyObject* THPStream_eq(THPStream* self, THPStream* other) {
   HANDLE_TH_ERRORS
   return PyBool_FromLong(
-      self->stream_id == other->stream_id &&
-      self->device_index == other->device_index &&
-      self->device_type == other->device_type);
+      (self->stream_id == other->stream_id) &&
+      (self->device_index == other->device_index) &&
+      (self->device_type == other->device_type));
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THPStream_ne(THPStream* self, THPStream* other) {
+  HANDLE_TH_ERRORS
+  return PyBool_FromLong(
+      (self->stream_id != other->stream_id) ||
+      (self->device_index != other->device_index) ||
+      (self->device_type != other->device_type));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPStream_richcompare(
+    PyObject* self,
+    PyObject* other,
+    int op) {
+  PyObject* result = NULL;
+  if (other == Py_None) {
+    result = Py_False;
+  } else {
+    switch (op) {
+      case Py_EQ:
+        result = THPStream_eq((THPStream*)self, (THPStream*)other);
+        break;
+      case Py_NE:
+        result = THPStream_ne((THPStream*)self, (THPStream*)other);
+        break;
+      default:
+        result = Py_False;
+        break;
+    }
+  }
+  Py_XINCREF(result);
+  return result;
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
 static struct PyMemberDef THPStream_members[] = {
     {"stream_id",
@@ -108,6 +320,14 @@
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
 static PyMethodDef THPStream_methods[] = {
+    {"query", THPStream_query, METH_NOARGS, nullptr},
+    {"synchronize", THPStream_synchronize, METH_NOARGS, nullptr},
+    {"wait_event", THPStream_wait_event, METH_O, nullptr},
+    {"wait_stream", THPStream_wait_stream, METH_O, nullptr},
+    {"record_event",
+     castPyCFunctionWithKeywords(THPStream_record_event),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
     {nullptr}};
 
@@ -120,11 +340,11 @@
     nullptr, /* tp_getattr */
     nullptr, /* tp_setattr */
     nullptr, /* tp_reserved */
-    nullptr, /* tp_repr */
+    (reprfunc)THPStream_repr, /* tp_repr */
     nullptr, /* tp_as_number */
     nullptr, /* tp_as_sequence */
     nullptr, /* tp_as_mapping */
-    nullptr, /* tp_hash  */
+    (hashfunc)THPStream_hash, /* tp_hash  */
     nullptr, /* tp_call */
     nullptr, /* tp_str */
     nullptr, /* tp_getattro */
@@ -135,7 +355,7 @@
     nullptr, /* tp_doc */
     nullptr, /* tp_traverse */
     nullptr, /* tp_clear */
-    nullptr, /* tp_richcompare */
+    THPStream_richcompare, /* tp_richcompare */
     0, /* tp_weaklistoffset */
     nullptr, /* tp_iter */
     nullptr, /* tp_iternext */