Adding JIT support for CUDA streams and events (#48020)

Summary:
=======

This PR addresses the following:

 * Adds JIT support for CUDA Streams
 * Adds JIT support for CUDA Events
 * Adds JIT support for CUDA Stream context manager
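
As a rough illustration, the new API can be used from TorchScript roughly as
follows (a minimal sketch adapted from the scripted model added in
test/cpp/jit/tests_setup.py; not part of this change):

    import torch

    @torch.jit.script
    def cat_on_user_stream():
        device_index = torch.cuda._current_device()
        # Create a user stream and a timing event on the current device
        s = torch.jit.cuda.Stream(device_index, 0)
        e = torch.jit.cuda.Event(True, False, False)
        a = torch.rand(3, 4, device="cuda")
        b = torch.rand(3, 4, device="cuda")
        with torch.jit.cuda.stream(s):
            # Work queued inside the context manager runs on the user stream
            on_user_stream = torch.cuda.current_stream(device_index).id() == s.id()
            c = torch.cat((a, b), 0)
        s.record_event(e)
        s.synchronize()
        return on_user_stream, e.query(), c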

Testing:
======

python test/test_jit.py -v TestCUDA

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48020

Reviewed By: navahgar

Differential Revision: D25725749

Pulled By: nikithamalgifb

fbshipit-source-id: b0addeb49630f8f0c430ed7badeca43bb9d2535c
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 8065300..f99dc3c 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -17,6 +17,7 @@
 #define FORALL_NS_SYMBOLS(_)         \
   _(namespaces, prim)                \
   _(namespaces, aten)                \
+  _(namespaces, cuda)                \
   _(namespaces, onnx)                \
   _(namespaces, attr)                \
   _(namespaces, scope)               \
@@ -284,6 +285,9 @@
   _(aten, zero_)                     \
   _(aten, fill_)                     \
   _(aten, masked_fill_)              \
+  _(cuda, _set_device)               \
+  _(cuda, set_stream)                \
+  _(cuda, _current_device)           \
   _(aten, swapaxes)                  \
   _(aten, swapaxes_)                 \
   _(aten, swapdims)                  \
@@ -383,6 +387,7 @@
 #define FORALL_NS_SYMBOLS(_) \
   _(namespaces, prim)              \
   _(namespaces, aten)              \
+  _(namespaces, cuda)              \
   _(namespaces, onnx)              \
   _(namespaces, attr)              \
   _(namespaces, scope)             \
@@ -453,6 +458,7 @@
   // (and if it's not, you should add it to the built-ins list above.)
   static Symbol attr(const std::string & s);
   static Symbol aten(const std::string & s);
+  static Symbol cuda(const std::string & s);
   static Symbol onnx(const std::string & s);
   static Symbol prim(const std::string & s);
   static Symbol user(const std::string & s);
@@ -463,6 +469,7 @@
 
   bool is_attr() const;
   bool is_aten() const;
+  bool is_cuda() const;
   bool is_prim() const;
   bool is_onnx() const;
   bool is_user() const;
@@ -523,6 +530,7 @@
 
 inline Symbol Symbol::attr(const std::string & s) { return Symbol::fromQualString("attr::" + s); }
 inline Symbol Symbol::aten(const std::string & s)  { return Symbol::fromQualString("aten::" + s); }
+inline Symbol Symbol::cuda(const std::string & s)  { return Symbol::fromQualString("cuda::" + s); }
 inline Symbol Symbol::onnx(const std::string & s)  { return Symbol::fromQualString("onnx::" + s); }
 inline Symbol Symbol::prim(const std::string & s)  { return Symbol::fromQualString("prim::" + s); }
 inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
@@ -531,6 +539,7 @@
 inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); }
 inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
 inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
+inline bool Symbol::is_cuda() const { return ns() == namespaces::cuda; }
 inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
 inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
 inline bool Symbol::is_user() const { return ns() == namespaces::user; }
diff --git a/test/cpp/jit/test_save_load.cpp b/test/cpp/jit/test_save_load.cpp
index 2e59358..e102a6f 100644
--- a/test/cpp/jit/test_save_load.cpp
+++ b/test/cpp/jit/test_save_load.cpp
@@ -120,5 +120,33 @@
   }
 }
 
+TEST(SerializationTest, TestJitStream_CUDA) {
+  torch::jit::Module model;
+  std::vector<torch::jit::IValue> inputs;
+  // Deserialize the ScriptModule from a file using torch::jit::load().
+  // The scripted model should have been generated by tests_setup.py.
+  // Refer: TorchSaveJitStream_CUDA in test/cpp/jit/tests_setup.py
+  model = torch::jit::load("saved_stream_model.pt");
+
+  auto output = model.forward(inputs);
+  auto list_of_elements = output.toTuple()->elements();
+  auto is_stream_s = list_of_elements[0].toBool();
+
+  // a, b: These are the two input tensors
+  // c: This is the output tensor generated by torch.cat((a, b), 0)
+  auto a = list_of_elements[1].toTensor();
+  auto b = list_of_elements[2].toTensor();
+  auto c = list_of_elements[3].toTensor();
+  // op: used to verify that the cat operation here produces the same result
+  // as the one computed on the GPU inside the scripted model
+  auto op = at::cat({a, b}, 0);
+
+  // Check if the stream is set
+  ASSERT_TRUE(is_stream_s);
+  // Check that the sizes of op (computed here) and c (computed in the model) match
+  ASSERT_EQ(op.sizes(), c.sizes());
+  // Check if both the output tensors are equal
+  ASSERT_TRUE(op.equal(c));
+}
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/jit/tests_setup.py b/test/cpp/jit/tests_setup.py
index 68871d1..928a06d 100644
--- a/test/cpp/jit/tests_setup.py
+++ b/test/cpp/jit/tests_setup.py
@@ -63,11 +63,38 @@
 
         torch.save(value, self.path, _use_new_zipfile_serialization=False)
 
+class TorchSaveJitStream_CUDA(FileSetup):
+    path = 'saved_stream_model.pt'
+
+    def setup(self):
+        if not torch.cuda.is_available():
+            return
+
+        class Model(torch.nn.Module):
+            def forward(self):
+                device_index = torch.cuda._current_device()
+                s = torch.jit.cuda.Stream(device_index, 0)
+                a = torch.rand(3, 4, device="cuda")
+                b = torch.rand(3, 4, device="cuda")
+
+                with torch.jit.cuda.stream(s):
+                    is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
+                    c = torch.cat((a, b), 0).to("cuda")
+                s.synchronize()
+                return is_stream_s, a, b, c
+
+        model = Model()
+
+        # Script the model and save
+        script_model = torch.jit.script(model)
+        torch.jit.save(script_model, self.path)
+
 
 tests = [
     EvalModeForLoadedModule(),
     SerializationInterop(),
     TorchSaveError(),
+    TorchSaveJitStream_CUDA()
 ]
 
 def setup():
diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py
new file mode 100644
index 0000000..f7af8e3
--- /dev/null
+++ b/test/jit/test_cuda.py
@@ -0,0 +1,476 @@
+import os
+import sys
+import gc
+import unittest
+
+import torch
+from typing import NamedTuple
+from torch.testing._internal.jit_utils import JitTestCase
+from torch.testing._internal.common_utils import skipIfRocm, skipCUDANonDefaultStreamIf
+
+# Make the helper files in test/ importable
+pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.append(pytorch_test_dir)
+
+# Check if GPU is available
+TEST_CUDA = torch.cuda.is_available()
+# Check if multiple GPUs are available
+TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
+
+# If GPU is not available, then do not run the tests
+if not TEST_CUDA:
+    print('CUDA not available, skipping tests', file=sys.stderr)
+    JitTestCase = object  # noqa: F811
+
+TEST_LARGE_TENSOR = TEST_CUDA
+
+# If a GPU is available, initialize the CUDA context and check
+# whether there is enough memory to allocate large tensors.
+if TEST_CUDA:
+    torch.ones(1).cuda()  # initialize cuda context
+    TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TESTNAME\n\n"
+        "instead."
+    )
+
+class TestCUDA(JitTestCase):
+    """
+    A suite of tests for the CUDA API in TorchScript.
+    """
+    def setUp(self):
+        super(TestCUDA, self).setUp()
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        super(TestCUDA, self).tearDown()
+
+    @skipIfRocm
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    def test_current_stream(self):
+        # Test the current stream on the device and check if the stream device index
+        # matches the device ID
+        @torch.jit.script
+        def fn():
+            device_index = torch.cuda._current_device()
+            s0 = torch.cuda.current_stream(device_index)
+            s1 = torch.cuda.current_stream(1)
+            s2 = torch.cuda.current_stream(0)
+
+            return s0.device_index(), s1.device_index(), s2.device_index()
+
+        d0, d1, d2 = fn()
+
+        # By default, the current device ID is 0.
+        self.assertEqual(0, d0)
+        self.assertEqual(1, d1)
+        self.assertEqual(0, d2)
+        self.assertEqual(d0, d2)
+
+    @skipIfRocm
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
+    @skipCUDANonDefaultStreamIf(True)
+    def test_streams_and_events(self):
+        # This test checks that the default stream ID is set to 0 on the device
+        @torch.jit.script
+        def test_default_streams():
+            s0 = torch.cuda.default_stream(0)
+            s1 = torch.cuda.default_stream(1)
+
+            d = torch.device('cuda:1')
+
+            # Check that the current stream id and the default stream id are the
+            # same on the current device. The current device id is 0 by default
+            s2 = torch.cuda.current_stream(0)
+            check_s2 = s2.id() == s0.id()
+            check_d0 = torch.cuda._current_device() == s2.device_index()
+
+            # Set the current device to d1 and check if the stream
+            # has been set to the default stream on d1
+            with torch.jit.cuda.device(d):
+                s3 = torch.cuda.current_stream(1)
+                check_s3 = s3.id() == s1.id()
+                check_d1 = torch.cuda._current_device() == s3.device_index()
+
+            # Check if the current device was reset to 0
+            is_device_d0 = torch.cuda._current_device() == s2.device_index()
+
+            return s0.device_index(), s1.device_index(), check_s2, check_s3, check_d0, check_d1, is_device_d0
+
+        d0, d1, check_s2, check_s3, check_d0, check_d1, is_device_d0 = test_default_streams()
+
+        self.assertEqual(d0, 0)
+        self.assertEqual(d1, 1)
+        self.assertTrue(check_s2)
+        self.assertTrue(check_s3)
+        self.assertTrue(check_d0)
+        self.assertTrue(check_d1)
+        self.assertTrue(is_device_d0)
+
+        # This test checks that the stream context manager is a no-op
+        # when the stream is None for `with torch.jit.cuda.stream`
+        @torch.jit.script
+        def test_set_none_stream():
+            device_index = torch.cuda._current_device()
+            current_stream = torch.cuda.current_stream(device_index)
+            default_stream = torch.cuda.default_stream(device_index)
+
+            # When stream is none, check if this operation is a no-op
+            with torch.jit.cuda.stream(None):
+                cur_device_index = torch.cuda._current_device()
+                is_device_index_same = cur_device_index == device_index
+                is_current_stream_same = torch.cuda.current_stream(cur_device_index).id() == current_stream.id()
+                is_default_stream_same = torch.cuda.default_stream(device_index).id() == default_stream.id()
+
+            # Check if the device index, current stream and default streams have not changed
+            are_streams_same = is_device_index_same and is_current_stream_same and is_default_stream_same
+            return are_streams_same
+        self.assertTrue(test_set_none_stream())
+
+        # This test checks that the device context manager is a no-op
+        # when the device is None for `with torch.jit.cuda.device`
+        @torch.jit.script
+        def test_set_device_none():
+            device_index = torch.cuda._current_device()
+            # When device is none, check if this operation is a no-op
+            with torch.jit.cuda.device(None):
+                # Check if the current device is the same
+                is_device_same = torch.cuda._current_device() == device_index
+            return is_device_same
+        self.assertTrue(test_set_device_none())
+
+        # Check if a CUDA JIT stream is created
+        # on the _current_device
+        @torch.jit.script
+        def test_simple_stream():
+            device_index = torch.cuda._current_device()
+            s = torch.jit.cuda.Stream(device_index, 0)
+            return device_index == s.device_index()
+
+        self.assertTrue(test_simple_stream(), "Could not create Stream!")
+
+        # Class used to store results for the test: test_get_stream.
+        class Result(NamedTuple):
+            t1 : torch.Tensor
+            t2 : torch.Tensor
+            is_current_and_default_stream_same : bool
+            is_default_and_user_stream_not_same : bool
+            is_stream_set : bool
+            is_stream_reset : bool
+            default_stream_query : bool
+            default_stream_id : int
+            user_stream_id : int
+
+        # This test checks various stream properties.
+        @torch.jit.script
+        def test_get_stream():
+            device_index = torch.cuda._current_device()
+            current_stream = torch.cuda.current_stream(device_index)
+            default_stream = torch.cuda.default_stream(device_index)
+            user_stream = torch.jit.cuda.Stream(device_index, 0)
+
+            # Check if the current and default streams are the same on the device
+            is_current_and_default_stream_same = current_stream.id() == default_stream.id()
+            # Check if user stream and default stream are not the same on the device
+            is_default_and_user_stream_not_same = default_stream.id() != user_stream.id()
+
+            with torch.jit.cuda.stream(user_stream):
+                is_stream_set = torch.cuda.current_stream(device_index).id() == user_stream.id()
+
+            # Check if the stream was reset to current_stream
+            is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id()
+
+            tensor1 = torch.rand(10000, 10000, device="cuda")
+            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
+            default_stream.synchronize()
+            default_stream_query = default_stream.query()
+
+            # Capture all the results in the class Result
+            res = Result(
+                tensor1, tensor2, is_current_and_default_stream_same,
+                is_default_and_user_stream_not_same, is_stream_set,
+                is_stream_reset, default_stream_query, default_stream.id(), user_stream.id())
+            return res
+
+        result = test_get_stream()
+
+        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
+        self.assertTrue(result.is_current_and_default_stream_same)
+        self.assertTrue(result.is_default_and_user_stream_not_same)
+        self.assertTrue(result.is_stream_set)
+        self.assertTrue(result.is_stream_reset)
+        self.assertTrue(result.default_stream_query)
+        self.assertEqual(result.default_stream_id, 0)  # Check if the default stream ID is always 0
+        self.assertNotEqual(result.user_stream_id, 0)  # Check if the user stream is always non zero
+
+        # Test the stream context manager. This test checks that the current stream
+        # is switched to the user stream when the context manager is used.
+        @torch.jit.script
+        def test_stream_context():
+            device_index = torch.cuda._current_device()
+            current_stream = torch.cuda.current_stream(device_index)
+            user_stream = torch.jit.cuda.Stream(device_index, 0)
+            A = torch.rand(1000, 1000, device="cuda")
+
+            with torch.jit.cuda.stream(user_stream):
+                check = torch.cuda.current_stream(device_index).id() == user_stream.id()
+                B = torch.mm(A, A).to("cuda")
+            # Wait for B to be computed
+            user_stream.synchronize()
+            # Check if the stream has been reset on the current device
+            is_stream_reset = torch.cuda.current_stream(device_index).id() == current_stream.id()
+
+            return A, B, check, is_stream_reset
+
+        A, B, is_stream_set, is_stream_reset = test_stream_context()
+        self.assertEqual(torch.matmul(A, A), B)
+        self.assertTrue(is_stream_set, "Error: Current stream was not set to user stream!")
+        self.assertTrue(is_stream_reset, "Error: The stream was not restored to previous stream!")
+
+        # Test multiple nested streams. Check if the operations are computed as expected on the streams
+        # This test has been adapted from the eager mode tests available at test/test_cuda.py
+        @torch.jit.script
+        def test_multiple_stream():
+            prev_device_index = torch.cuda._current_device()
+            prev_current_stream = torch.cuda.current_stream(prev_device_index)
+            s1 = torch.jit.cuda.Stream(0, 0)
+            s2 = torch.jit.cuda.Stream(1, 0)
+
+            A = torch.rand(1000, 1000, device="cuda")
+            B = torch.rand(1000, 1000, device="cuda")
+            with torch.jit.cuda.stream(s1):
+                C = torch.mm(A, A).to("cuda")
+                # Check if the stream and device have been set to s1
+                is_stream_s1 = torch.cuda.current_stream(s1.device_index()).id() == s1.id()
+                is_device_s1 = torch.cuda._current_device() == s1.device_index()
+                with torch.jit.cuda.stream(s2):
+                    # Check if the stream and device have been set to s2
+                    is_stream_s2 = torch.cuda.current_stream(s2.device_index()).id() == s2.id()
+                    is_device_s2 = torch.cuda._current_device() == s2.device_index()
+                    D = torch.mm(B, B).to("cuda")
+                # Check if the stream and device have been set to s1
+                is_stream_s1_after = torch.cuda.current_stream(s1.device_index()).id() == s1.id()
+                is_device_s1_after = torch.cuda._current_device() == s1.device_index()
+                # Wait for D to be computed
+                s2.synchronize()
+            # Wait for C to be computed on S1
+            s1.synchronize()
+
+            # Check if the stream and device have been restored to the previous stream and device
+            is_device_current = torch.cuda._current_device() == prev_device_index
+            is_stream_current = torch.cuda.current_stream(prev_device_index).id() == prev_current_stream.id()
+
+            check_stream = is_stream_s1 and is_stream_s2 and is_stream_s1_after and is_stream_current
+            check_device = is_device_s1 and is_device_s2 and is_device_s1_after and is_device_current
+            return A, B, C, D, check_stream, check_device
+        A, B, C, D, check_stream, check_device = test_multiple_stream()
+
+        self.assertEqual(torch.matmul(A, A), C)
+        self.assertEqual(torch.matmul(B, B), D)
+        self.assertTrue(check_stream)
+        self.assertTrue(check_device)
+
+        # Test multiple streams waiting on each other for the operations to be completed.
+        @torch.jit.script
+        def test_data_dependency_between_streams():
+            device_index = torch.cuda._current_device()
+            prev_current_stream = torch.cuda.current_stream(device_index)
+            s1 = torch.jit.cuda.Stream(0, 0)
+            s2 = torch.jit.cuda.Stream(0, 0)
+            event = torch.jit.cuda.Event(False, False, False)
+
+            A = torch.rand(1000, 1000, device="cuda")
+            with torch.jit.cuda.stream(s1):
+                is_stream_s1 = torch.cuda.current_stream(device_index).id() == s1.id()
+                B = torch.mm(A, A).to("cuda")
+            s1.record_event(event)
+            # Check if the current_stream is reset
+            is_current_stream_1 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id()
+            # Wait for ops on s1 to be computed
+            s2.wait_event(event)
+            with torch.jit.cuda.stream(s2):
+                is_stream_s2 = torch.cuda.current_stream(device_index).id() == s2.id()
+                C = torch.mm(B, B).to("cuda")
+            # Wait for C to be computed
+            s2.synchronize()
+            # Check if the current_stream is reset
+            is_current_stream_2 = torch.cuda.current_stream(device_index).id() == prev_current_stream.id()
+
+            check_stream = is_current_stream_1 and is_current_stream_2 and is_stream_s1 and is_stream_s2
+            return A, B, C, check_stream
+
+        A, B, C, check_stream = test_data_dependency_between_streams()
+        self.assertEqual(torch.matmul(A, A), B)
+        self.assertEqual(torch.matmul(B, B), C)
+        self.assertTrue(check_stream)
+
+        # Test a simple CUDA event. Test if the CUDA event was created successfully
+        @torch.jit.script
+        def test_simple_event():
+            e = torch.jit.cuda.Event(True, False, False)
+            return e is not None
+        self.assertTrue(test_simple_event(), "Could not create CUDA Event!")
+
+        # Record the CUDA event for the torch.mm operation on the current stream
+        # and then test if the elapsed time is greater than 0. This test is also
+        # an adaptation of the eager mode CUDA tests available at test/test_cuda.py
+        @torch.jit.script
+        def test_event():
+            device_index = torch.cuda._current_device()
+            stream = torch.cuda.current_stream(device_index)
+            event = torch.jit.cuda.Event(True, False, False)
+            is_true_event_query = event.query()
+            start_event = torch.jit.cuda.Event(True, False, False)
+            stream.record_event(start_event)
+            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
+            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
+            stream.record_event(event)
+            event.synchronize()
+            is_again_true_event_query = event.query()
+
+            if not (is_true_event_query and is_again_true_event_query):
+                return -1.0
+            return start_event.elapsed_time(event)
+
+        self.assertGreater(test_event(), 0)
+
+        # Check for stream synchronization, when a large tensor multiplication is
+        # computed on the stream. stream.query() should be true once the synchronization is done
+        @torch.jit.script
+        def test_stream_synchronize() -> float:
+            device_index = torch.cuda._current_device()
+            s = torch.jit.cuda.Stream(device_index, 0)
+            e_tik = torch.jit.cuda.Event(True, False, False)
+            e_tok = torch.jit.cuda.Event(True, False, False)
+
+            e_tik.record(s)
+            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
+            with torch.jit.cuda.stream(s):
+                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
+            s.synchronize()
+            e_tok.record(s)
+            e_tok.synchronize()
+
+            if not s.query():
+                return -1.0
+
+            # not necessary to check e_tik and e_tok, as elapsed_time would throw
+            # exception if otherwise.
+            return e_tik.elapsed_time(e_tok)
+        self.assertGreater(test_stream_synchronize(), 0)
+
+        # Test event synchronization for the event that records a stream doing
+        # a large tensor multiplication. Check if the elapsed time is greater than 0
+        # and the stream.query evaluates to true.
+        @torch.jit.script
+        def test_event_synchronize() -> float:
+            device_index = torch.cuda._current_device()
+            s = torch.jit.cuda.Stream(device_index, 0)
+            e_tik = torch.jit.cuda.Event(True, False, False)
+            e_tok = torch.jit.cuda.Event(True, False, False)
+
+            e_tik.record(s)
+            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
+            with torch.jit.cuda.stream(s):
+                tensor = torch.mm(tensor1, tensor1).to("cuda")
+            s.record_event(e_tok)
+            e_tok.synchronize()
+            s.synchronize()
+
+            if not s.query():
+                return -1.0
+
+            # not necessary to check e_tik and e_tok, as elapsed_time would throw
+            # exception if otherwise.
+            return e_tik.elapsed_time(e_tok)
+
+        self.assertGreater(test_event_synchronize(), 0)
+
+        # Test for event wait. Check if the event waits for all the operations on
+        # the stream to be done. Check for synchronizations and query on the streams
+        # and events. This test is adapted from the eager mode CUDA tests. Please refer
+        # to test/test_cuda.py
+        @torch.jit.script
+        def test_event_wait() -> float:
+            device_index = torch.cuda._current_device()
+            s0 = torch.cuda.current_stream(device_index)
+            s1 = torch.jit.cuda.Stream(device_index, 0)
+            e_tik = torch.jit.cuda.Event(True, True, False)
+            e_tok = torch.jit.cuda.Event(True, True, False)
+
+            e_tik.record(s0)
+            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
+            with torch.jit.cuda.stream(s0):
+                tensor2 = torch.mm(tensor1, tensor1).cuda()
+            e_sync = torch.jit.cuda.Event(True, False, False)
+            e_sync.record(torch.cuda.current_stream(device_index))
+            e_sync.wait(s1)
+            with torch.jit.cuda.stream(s1):
+                tensor3 = torch.rand(1000000000, 1000000000, device="cuda")
+                tensor4 = torch.mm(tensor3, tensor3).cuda()
+            s1.synchronize()
+            e_tok.record(torch.cuda.current_stream(device_index))
+            e_tok.synchronize()
+            s0.synchronize()
+
+            if not s0.query() or not s1.query() or not e_sync.query():
+                return -1.0
+
+            # not necessary to check e_tik and e_tok, as elapsed_time would throw
+            # exception if otherwise.
+            return e_tik.elapsed_time(e_tok)
+        self.assertGreater(test_event_wait(), 0)
+
+        # Test for stream wait_event. Checks if the stream waits on the event
+        @torch.jit.script
+        def test_wait_event():
+            d1 = torch.device('cuda:1')
+
+            with torch.jit.cuda.device(d1):
+                s0 = torch.cuda.current_stream(1)
+                tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
+                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
+                e0 = torch.jit.cuda.Event(False, False, False)
+                s0.record_event(e0)
+
+            s1 = torch.cuda.current_stream(0)
+            s1.wait_event(e0)
+            s1.synchronize()
+
+            return e0.query() and s0.query() and s1.query()
+        self.assertTrue(test_wait_event())
+
+    # Test if a scripted module with CUDA streams can be saved, loaded and executed
+    def test_save_load(self):
+        class Model(torch.nn.Module):
+            def forward(self):
+                device_index = torch.cuda._current_device()
+                s = torch.jit.cuda.Stream(device_index, 0)
+                a = torch.rand(3, 4, device="cuda")
+                b = torch.rand(3, 4, device="cuda")
+
+                with torch.jit.cuda.stream(s):
+                    is_stream_s = torch.cuda.current_stream(s.device_index()).id() == s.id()
+                    c = torch.cat((a, b), 0).cuda()
+                s.synchronize()
+                return is_stream_s, a, b, c
+
+        model = Model()
+
+        # Script the model and save
+        script_model = torch.jit.script(model)
+        is_stream_s, a, b, c = script_model()
+        # Verify if the output is correct
+        self.assertTrue(is_stream_s)
+        self.assertEqual(torch.cat((a, b), 0), c)
+
+        # Save and load the scripted model
+        load_model = self.getExportImportCopy(script_model)
+        is_stream_s, a_load, b_load, c_load = load_model()
+        self.assertTrue(is_stream_s)
+        self.assertEqual(torch.cat((a_load, b_load), 0), c_load)
diff --git a/test/test_jit.py b/test/test_jit.py
index ff89429..a683a8e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -35,6 +35,7 @@
 from jit.test_slice import TestSlice  # noqa: F401
 from jit.test_warn import TestWarn  # noqa: F401
 from jit.test_isinstance import TestIsinstance  # noqa: F401
+from jit.test_cuda import TestCUDA  # noqa: F401
 from jit.test_hash import TestHash  # noqa: F401
 
 # Torch
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index a214684..ec53f1d 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -408,6 +408,7 @@
     "torch/csrc/jit/codegen/cuda/transform_rfactor.cpp",
     "torch/csrc/jit/codegen/cuda/type.cpp",
     "torch/csrc/jit/tensorexpr/cuda_codegen.cpp",
+    "torch/csrc/jit/runtime/register_cuda_ops.cpp",
 ]
 
 libtorch_cuda_sources = libtorch_cuda_core_sources + [
diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h
new file mode 100644
index 0000000..fa92ce2
--- /dev/null
+++ b/torch/csrc/jit/cuda/cuda.h
@@ -0,0 +1,179 @@
+#include <aten/src/ATen/cuda/CUDAEvent.h>
+#include <c10/core/Device.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/custom_class.h>
+
+namespace torch {
+namespace jit {
+
+class CUDAEvent;
+// This class is a wrapper around c10::cuda::CUDAStream.
+// It is needed because TorchBind does not support all of the argument types
+// for c10::cuda::CUDAStream. For more details, please refer to
+// c10/cuda/CUDAStream.h.
+class CUDAStream final : public CustomClassHolder {
+ public:
+  CUDAStream(int64_t device = -1, int64_t priority = 0) {
+    constexpr int64_t PRIORITY_INDEX = 0;
+    stream_ = std::make_unique<c10::cuda::CUDAStream>(
+        c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device));
+  }
+
+  CUDAStream(c10::cuda::CUDAStream s) {
+    stream_ = std::make_unique<c10::cuda::CUDAStream>(s);
+  }
+
+  bool query() {
+    return stream_->query();
+  }
+
+  c10::intrusive_ptr<CUDAEvent> recordEvent(
+      c10::intrusive_ptr<CUDAEvent> event);
+
+  void synchronize() {
+    stream_->synchronize();
+  }
+
+  void waitEvent(c10::intrusive_ptr<CUDAEvent> event);
+
+  void waitStream(c10::intrusive_ptr<CUDAStream> stream);
+
+  /// Get the CUDA device index that this stream is associated with.
+  int64_t device_index() const {
+    return stream_->device_index();
+  }
+
+  /// Get the full Device that this stream is associated with.  The Device
+  /// is guaranteed to be a CUDA device.
+  c10::Device device() const {
+    return stream_->device();
+  }
+
+  /// Return the stream ID corresponding to this particular stream.
+  int64_t id() const {
+    return stream_->id();
+  }
+
+  /// Pack a CUDAStream to uint64_t representation.
+  /// The CUDAStream can be unpacked using unpack().  The format of
+  /// the uint64_t is unspecified and may be changed.
+  int64_t pack() const {
+    return stream_->pack();
+  }
+
+ private:
+  std::unique_ptr<c10::cuda::CUDAStream> stream_;
+  friend class CUDAEvent;
+};
+
+// This class is a wrapper around at::cuda::CUDAEvent.
+// It is needed because TorchBind does not support all of the argument types
+// for at::cuda::CUDAEvent. For more details, please refer to
+// aten/src/ATen/cuda/CUDAEvent.h.
+class CUDAEvent final : public CustomClassHolder {
+ public:
+  CUDAEvent(
+      bool enable_timing = false,
+      bool blocking = false,
+      bool interprocess = false) {
+    int flags = cudaEventDisableTiming;
+    if (enable_timing) {
+      flags = cudaEventDefault;
+    }
+    if (blocking) {
+      flags |= cudaEventBlockingSync;
+    }
+    if (interprocess) {
+      TORCH_CHECK(!enable_timing);
+      flags |= cudaEventInterprocess;
+    }
+
+    event_ = std::make_unique<at::cuda::CUDAEvent>(flags);
+  }
+
+  double elapsedTime(c10::intrusive_ptr<CUDAEvent> end) {
+    return event_->elapsed_time(*end->event_);
+  }
+
+  std::string ipcHandle() {
+    cudaIpcEventHandle_t handle;
+    event_->ipc_handle(&handle);
+    std::string str_handle((const char*)&handle, sizeof(handle));
+    return str_handle;
+  }
+
+  bool query() {
+    return event_->query();
+  }
+
+  void record(c10::intrusive_ptr<CUDAStream> stream);
+
+  void synchronize() {
+    event_->synchronize();
+  }
+  void wait(c10::intrusive_ptr<CUDAStream> stream);
+
+ private:
+  void recordInternal(CUDAStream* stream);
+  std::unique_ptr<at::cuda::CUDAEvent> event_;
+
+  friend class CUDAStream;
+};
+
+c10::intrusive_ptr<CUDAEvent> CUDAStream::recordEvent(
+    c10::intrusive_ptr<CUDAEvent> event) {
+  if (!event) {
+    event = c10::make_intrusive<CUDAEvent>();
+  }
+
+  event->recordInternal(this);
+  return event;
+}
+
+void CUDAStream::waitEvent(c10::intrusive_ptr<CUDAEvent> event) {
+  event->event_->block(*stream_);
+}
+
+void CUDAStream::waitStream(c10::intrusive_ptr<CUDAStream> stream) {
+  auto ev = c10::make_intrusive<CUDAEvent>();
+  stream->recordEvent(ev);
+  waitEvent(ev);
+}
+
+void CUDAEvent::record(c10::intrusive_ptr<CUDAStream> stream) {
+  event_->record(*stream->stream_);
+}
+
+void CUDAEvent::recordInternal(CUDAStream* stream) {
+  event_->record(*stream->stream_);
+}
+
+void CUDAEvent::wait(c10::intrusive_ptr<CUDAStream> stream) {
+  event_->block(*stream->stream_);
+}
+
+TORCH_LIBRARY(cuda, m) {
+  auto stream_class = m.class_<torch::jit::CUDAStream>("Stream").def(
+      torch::init<int64_t, int64_t>());
+  auto event_class = m.class_<torch::jit::CUDAEvent>("Event").def(
+      torch::init<bool, bool, bool>());
+
+  stream_class.def("query", &CUDAStream::query)
+      .def("record_event", &CUDAStream::recordEvent)
+      .def("synchronize", &CUDAStream::synchronize)
+      .def("wait_event", &CUDAStream::waitEvent)
+      .def("wait_stream", &CUDAStream::waitStream)
+      .def("device_index", &CUDAStream::device_index)
+      .def("device", &CUDAStream::device)
+      .def("pack", &CUDAStream::pack)
+      .def("id", &CUDAStream::id);
+
+  event_class.def("elapsed_time", &CUDAEvent::elapsedTime)
+      .def("query", &CUDAEvent::query)
+      .def("record", &CUDAEvent::record)
+      .def("synchronize", &CUDAEvent::synchronize)
+      .def("wait", &CUDAEvent::wait);
+};
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp
index 8b1aa58..f4c1fa2 100644
--- a/torch/csrc/jit/frontend/script_type_parser.cpp
+++ b/torch/csrc/jit/frontend/script_type_parser.cpp
@@ -211,6 +211,13 @@
       }
     }
 
+    // Check if the type is a custom class. This is done by checking
+    // if type_name starts with "torch.classes."
+    if (type_name.find("torch.classes.") == 0) {
+      auto custom_class_type = getCustomClass("__torch__." + type_name);
+      return custom_class_type;
+    }
+
     throw ErrorReport(expr) << "Unknown type name '" << type_name << "'";
   } else if (auto name = parseBaseTypeName(expr)) {
     auto itr = string_to_type_lut().find(*name);
diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp
index 0b3e4a4..1ca0f48 100644
--- a/torch/csrc/jit/ir/alias_analysis.cpp
+++ b/torch/csrc/jit/ir/alias_analysis.cpp
@@ -572,7 +572,8 @@
           !aliasAnalysisHasSpecialCaseFor(node->kind()),
       "Special cases should be handled already if we're here.");
 
-  if (node->kind().is_aten() || node->kind().is_prim()) {
+  if (node->kind().is_aten() || node->kind().is_prim() ||
+      node->kind().is_cuda()) {
     // TODO There is nothing in the system that relies on aten:: and prim::
     // ops using AliasAnalysisKind::FROM_SCHEMA or
     // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended
diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp
index 65b410d..eb75928 100644
--- a/torch/csrc/jit/ir/ir.cpp
+++ b/torch/csrc/jit/ir/ir.cpp
@@ -1079,6 +1079,11 @@
     case prim::rpc_sync: // It represents RPC message sent.
     case prim::rpc_remote: // It represents RPC message sent.
     case aten::wait: // It can represent RPC message received.
+#ifndef __HIP_PLATFORM_HCC__
+    case cuda::set_stream:
+    case cuda::_set_device:
+    case cuda::_current_device:
+#endif
     case prim::Enter:
     case prim::Exit:
       return true;
@@ -1094,7 +1099,7 @@
     return false;
   }
 
-  if (kind_.is_prim() || kind_.is_aten()) {
+  if (kind_.is_prim() || kind_.is_aten() || kind_.is_cuda()) {
     // TODO There is nothing in the system that relies on aten:: and prim::
     // ops using AliasAnalysisKind::FROM_SCHEMA,
     // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, or
diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h
index 21f172f..02867b8 100644
--- a/torch/csrc/jit/ir/ir.h
+++ b/torch/csrc/jit/ir/ir.h
@@ -72,6 +72,11 @@
 namespace aten {
 using namespace ::c10::aten;
 }
+namespace cuda {
+#ifndef __HIP_PLATFORM_HCC__
+using namespace ::c10::cuda;
+#endif
+} // namespace cuda
 
 struct Function;
 struct MatchedSchema;
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index 933d3bb..056e23d 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -217,6 +217,32 @@
   return toSugaredValue(member, m, loc, /*is_constant=*/true);
 }
 
+#ifndef __HIP_PLATFORM_HCC__
+std::shared_ptr<SugaredValue> CUDAPythonModuleValue::attr(
+    const SourceRange& loc,
+    Function& m,
+    const std::string& field) {
+  // List of all the cuda operators which are supported in JIT
+  const std::unordered_set<std::string> cuda_ops = {"current_stream",
+                                                    "default_stream",
+                                                    "_current_device",
+                                                    "_set_device",
+                                                    "device_index",
+                                                    "device_count",
+                                                    "set_stream"};
+
+  if (cuda_ops.find(field) != cuda_ops.end()) {
+    return std::make_shared<BuiltinFunction>(Symbol::cuda(field), c10::nullopt);
+  }
+
+  py::object member = getattr(loc, field);
+  // note: is_constant = true because we consider that global properties
+  // on modules like math.pi or torch.float to be constants
+  // even though it is possible, though rare, for someone to mutate them
+  return toSugaredValue(member, m, loc, /*is_constant=*/true);
+}
+#endif
+
 Value* ModuleValue::asValue(const SourceRange& loc, Function& m) {
   return self_;
 }
@@ -938,6 +964,12 @@
   if (auto callee = as_function(obj)) {
     return std::make_shared<FunctionValue>(callee->function_);
   } else if (py::isinstance<py::module>(obj)) {
+#ifndef USE_ROCM
+    std::string obj_name = py::cast<py::str>(py::getattr(obj, "__name__"));
+    if (obj_name.compare("torch.cuda") == 0) {
+      return std::make_shared<CUDAPythonModuleValue>(obj);
+    }
+#endif
     return std::make_shared<PythonModuleValue>(obj);
   } else if (
       obj.ptr() == py::module::import("torch.jit").attr("_fork").ptr() ||
diff --git a/torch/csrc/jit/python/python_sugared_value.h b/torch/csrc/jit/python/python_sugared_value.h
index b5d8f44..1edbc6c 100644
--- a/torch/csrc/jit/python/python_sugared_value.h
+++ b/torch/csrc/jit/python/python_sugared_value.h
@@ -91,6 +91,20 @@
       const std::string& field) override;
 };
 
+// Used for desugaring uses of the torch.cuda module. All the CUDA APIs with
+// torch.cuda.* are resolved using CUDAPythonModuleValue.
+#ifndef __HIP_PLATFORM_HCC__
+struct VISIBILITY_HIDDEN CUDAPythonModuleValue : public PythonValue {
+  explicit CUDAPythonModuleValue(py::object mod)
+      : PythonValue(std::move(mod)) {}
+
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange& loc,
+      Function& m,
+      const std::string& field) override;
+};
+#endif
+
 // Represents all the parameters of a module as a List[Tensor]
 struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue {
   ConstantParameterList(Value* the_list) : the_list_(the_list) {}
diff --git a/torch/csrc/jit/runtime/register_cuda_ops.cpp b/torch/csrc/jit/runtime/register_cuda_ops.cpp
new file mode 100644
index 0000000..5cf31d6
--- /dev/null
+++ b/torch/csrc/jit/runtime/register_cuda_ops.cpp
@@ -0,0 +1,87 @@
+// This file registers special JIT operators used to implement the PyTorch CUDA
+// API in TorchScript.
+#ifndef __HIP_PLATFORM_HCC__
+#include <torch/csrc/api/include/torch/utils.h>
+#include <torch/csrc/jit/cuda/cuda.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/runtime/custom_operator.h>
+#include <torch/csrc/jit/runtime/operator.h>
+
+namespace torch {
+namespace jit {
+
+namespace {
+
+c10::AliasAnalysisKind aliasAnalysisFromSchema() {
+  return c10::AliasAnalysisKind::FROM_SCHEMA;
+}
+
+RegisterOperators const reg({
+    Operator(
+        "cuda::current_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream",
+        [](Stack* stack) {
+          auto idx = uint16_t(pop(stack).toInt());
+          auto s = c10::cuda::getCurrentCUDAStream(idx);
+          auto st = make_custom_class<torch::jit::CUDAStream>(s);
+          push(stack, IValue(st));
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "cuda::default_stream(int64_t val) -> __torch__.torch.classes.cuda.Stream",
+        [](Stack* stack) {
+          auto idx = uint16_t(pop(stack).toInt());
+          auto s = c10::cuda::getDefaultCUDAStream(idx);
+          auto st = make_custom_class<torch::jit::CUDAStream>(s);
+          push(stack, IValue(st));
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "cuda::_current_device() -> int",
+        [](Stack* stack) {
+          auto v = c10::cuda::current_device();
+          push(stack, static_cast<int>(v));
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "cuda::_set_device(int64_t val) -> ()",
+        [](Stack* stack) {
+          int64_t idx = -1;
+          pop(stack, idx);
+          c10::cuda::set_device(static_cast<c10::DeviceIndex>(idx));
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "cuda::device_index(Device device) -> int",
+        [](Stack* stack) {
+          auto device = pop(stack);
+          auto idx = device.toDevice().index();
+          push(stack, idx);
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "cuda::device_count() -> int",
+        [](Stack* stack) { push(stack, at::cuda::device_count()); },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "cuda::set_stream(__torch__.torch.classes.cuda.Stream stream) -> ()",
+        [](Stack* stack) {
+          auto v = pop(stack);
+          auto s = v.toCustomClass<torch::jit::CUDAStream>();
+          // To set the current CUDA stream using
+          // c10::cuda::setCurrentCUDAStream, the jit::CUDAStream object needs
+          // to be converted to c10::cuda::CUDAStream. Since the latter cannot
+          // be returned from a class registered via TorchBind, this can only be
+          // achieved by packing the c10::cuda::CUDAStream instance contained
+          // inside the jit::CUDAStream object to a uint64_t representation, and
+          // unpacking it inside this operator. The unpacked stream is then used
+          // to set the current CUDA stream.
+          auto packed = s->pack();
+          auto unpacked = c10::cuda::CUDAStream::unpack(packed);
+          c10::cuda::setCurrentCUDAStream(unpacked);
+        },
+        aliasAnalysisFromSchema()),
+});
+} // namespace
+} // namespace jit
+} // namespace torch
+#endif
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index f2b0c5c..cfd3271 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -44,6 +44,7 @@
 from torch.jit._serialization import save, load
 from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
 
+from torch.jit.cuda import stream
 from torch.jit._freeze import freeze
 
 # For backwards compatibility
diff --git a/torch/jit/cuda.py b/torch/jit/cuda.py
new file mode 100644
index 0000000..1680530
--- /dev/null
+++ b/torch/jit/cuda.py
@@ -0,0 +1,182 @@
+# mypy: ignore-errors
+
+r"""
+This package adds support for JIT compilation of CUDA streams and events.
+It mirrors the APIs available in eager mode.
+:ref:`cuda-semantics` has more details about working with CUDA.
+"""
+
+import torch
+from typing import Optional, Any
+from torch import device as _device
+
+def get_current_device_index() -> int:
+    r"""Checks if there are CUDA devices available and
+    returns the device index of the current default CUDA device.
+    Returns -1 in case there are no CUDA devices available.
+
+    Arguments: ``None``
+    """
+    if torch.cuda.device_count() > 0:
+        return torch.cuda._current_device()
+    return -1
+
+def get_device_index(device: Optional[_device] = None, optional: bool = False, allow_cpu: bool = False) -> int:
+    r"""Gets the device index from :attr:`device`, which can be a torch.device
+    object, a Python integer, or ``None``.
+
+    If :attr:`device` is a torch.device object, returns the device index if it
+    is a CUDA device. Note that for a CUDA device without a specified index,
+    this will return the current default CUDA device if :attr:`optional` is ``True``.
+    If :attr:`allow_cpu` is ``True``, CPU devices will be accepted and ``-1`` will be
+    returned in this case.
+
+    If :attr:`device` is a Python integer, it is returned as is.
+
+    If :attr:`device` is ``None``, this will return the current default CUDA
+    device if :attr:`optional` is ``True``.
+    """
+    if device is None:
+        if optional:
+            return get_current_device_index()
+        else:
+            raise ValueError('Expected a torch.device with a specified index '
+                             f'or an integer, but got: {device}')
+    device_index = -1
+    if isinstance(device, str):
+        device = torch.device(device)
+
+    if isinstance(device, torch.device):
+        if not allow_cpu and device.type == 'cpu':
+            raise ValueError(f'Expected a non cpu device, but got: {device}')
+        device_index = -1 if device.type == 'cpu' else torch.cuda.device_index(device)
+
+    if isinstance(device, int):
+        device_index = device
+
+    return device_index
+
+class device(object):
+    r"""Context-manager that changes the selected device.
+    This is similar to the eager mode ``torch.cuda.device`` context manager,
+    but has been introduced for JIT compatibility.
+    Arguments:
+        device (torch.device or int): device index to select. It's a no-op if
+            this argument is a negative integer or ``None``.
+    """
+    def __init__(self, device: Optional[_device]):
+        self.idx = -1
+        self.prev_idx = -1
+        self.device = device
+
+    def __enter__(self):
+        self.idx = get_device_index(self.device, optional=True)
+
+        if self.idx == -1:
+            return
+        self.prev_idx = torch.cuda._current_device()
+
+        if self.prev_idx != self.idx:
+            torch.cuda._set_device(self.idx)
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        if self.prev_idx != self.idx:
+            torch.cuda._set_device(self.prev_idx)
+
+class StreamContext(object):
+    r"""Context-manager that selects a given stream.
+    All CUDA kernels queued within its context will be enqueued on a selected
+    stream.
+    Arguments:
+        stream (Stream): selected stream. This manager is a no-op if it's
+            ``None``.
+    .. note:: Streams are per-device. If the selected stream is not on the
+        current device, this function will also change the current device to
+        match the stream.
+    """
+    cur_stream : Optional['torch.classes.cuda.Stream']
+
+    def __init__(self, stream: Optional['torch.classes.cuda.Stream']):
+        self.idx = -1
+        self.stream = stream
+        # Initialize the below streams to default stream on the current device
+        self.device_index = get_current_device_index()
+        self.src_prev_stream = torch.cuda.default_stream(self.device_index)
+        self.dst_prev_stream = torch.cuda.default_stream(self.device_index)
+
+    def __enter__(self):
+        self.idx = get_device_index(device=None, optional=True)
+        # If there is no CUDA device available, return
+        if self.idx == -1:
+            return
+
+        # Local cur_stream variable for type refinement
+        cur_stream = self.stream
+        # Return if stream is None
+        if cur_stream is None:
+            return
+        self.src_prev_stream = torch.cuda.current_stream(self.idx)
+        # If the stream is not on the current device, then change the device
+        # and set the current stream on the device
+        if self.src_prev_stream.device_index() != cur_stream.device_index():
+            with device(cur_stream.device()):
+                self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device_index())
+            torch.cuda._set_device(cur_stream.device_index())
+        torch.cuda.set_stream(cur_stream)
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        # Local cur_stream variable for type refinement
+        cur_stream = self.stream
+        # If stream is None or no CUDA device available, return
+        if cur_stream is None or self.idx == -1:
+            return
+        # If the stream was not on the current device, restore the previous stream on
+        # the destination device and also reset the current device to the previous device.
+        # Set the current stream on the device to the src_prev_stream
+        if self.src_prev_stream.device_index() != cur_stream.device_index():
+            torch.cuda.set_stream(self.dst_prev_stream)
+            torch.cuda._set_device(self.idx)
+        torch.cuda.set_stream(self.src_prev_stream)
+
+def stream(stream: Optional['torch.classes.cuda.Stream']) -> StreamContext:
+    r"""Wrapper around the Context-manager that selects a given stream.
+    All CUDA kernels queued within its context will be enqueued on a selected
+    stream.
+    Arguments:
+        stream (Stream): selected stream. This manager is a no-op if it's
+            ``None``.
+    """
+    return StreamContext(stream)
+
+def Stream(device: int = -1, priority: int = 0) -> 'torch.classes.cuda.Stream':
+    r"""Wrapper around a CUDA stream.
+    A CUDA stream is a linear sequence of execution that belongs to a specific
+    device, independent from other streams.  See :ref:`cuda-semantics` for
+    details.
+    Arguments:
+        device(int, optional): the device on which to allocate
+            the stream. If :attr:`device` is a negative integer (the default is
+            ``-1``), the current device is used.
+        priority(int, optional): priority of the stream. Can be either
+            -1 (high priority) or 0 (low priority). By default, streams have
+            priority 0.
+    .. note:: Although CUDA versions >= 11 support more than two levels of
+        priorities, in PyTorch, we only support two levels of priorities.
+    """
+    return torch.classes.cuda.Stream(device, priority)
+
+def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> 'torch.classes.cuda.Event':
+    r"""Wrapper around a CUDA event.
+    CUDA events are synchronization markers that can be used to monitor the
+    device's progress, to accurately measure timing, and to synchronize CUDA
+    streams.
+    Arguments:
+        enable_timing (bool, optional): indicates if the event should measure time
+            (default: ``False``)
+        blocking (bool, optional): if ``True``, :meth:`wait` will be blocking (default: ``False``)
+        interprocess (bool): if ``True``, the event can be shared between processes
+            (default: ``False``)
+    .. _CUDA Event Documentation:
+       https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
+    """
+    return torch.classes.cuda.Event(enable_timing, blocking, interprocess)