add additional stream priority for cuda streams (#101956)

Changes the StreamID encoding to use the last bit to distinguish between external and internal streams, 4 bits for IdType (DEFAULT, EXT or user-created streams possibly with high priority), and 5 bits for index. This allows us to have more stream priorities exposed to the user (I'm currently setting 4, but that's easy to change now). Note, we are pre-creating all 32 streams in the pool for each allowed priority; I don't know if that's a problem in practice. Currently cuda 11.8/A100 GPUs allow 6 different stream priorities; the number may differ for different cards/different cuda versions.

Previous callsites explicitly requesting a high priority stream (`isHighPriority=true`) are now getting the highest priority stream.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101956
Approved by: https://github.com/ezyang
diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp
index 00729d7..d004aec 100644
--- a/c10/cuda/CUDAStream.cpp
+++ b/c10/cuda/CUDAStream.cpp
@@ -11,7 +11,6 @@
 #include <mutex>
 #include <vector>
 
-#include <iostream>
 namespace c10 {
 namespace cuda {
 
@@ -23,11 +22,9 @@
 static constexpr int kStreamsPerPoolBits = 5;
 static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;
 static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
-static constexpr int kStreamTypeBits = 3;
+static constexpr int kStreamTypeBits = 4;
 
-// Note: lower numbers are higher priorities, zero is default priority
-static constexpr int kHighPriority = -1;
-static constexpr int kLowPriority = 0;
+static int max_stream_priorities;
 
 // Non-default streams
 // Note: the number of CUDA devices is determined at run time,
@@ -43,30 +40,29 @@
 // crash. It's likely an issue in CUDA, but to be safe - let's just "forget"
 // the destruction.
 static c10::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
-static std::atomic<uint32_t> low_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
-static std::atomic<uint32_t> high_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
-static cudaStream_t low_priority_streams[C10_COMPILE_TIME_MAX_GPUS]
-                                        [kStreamsPerPool];
-static cudaStream_t high_priority_streams[C10_COMPILE_TIME_MAX_GPUS]
-                                         [kStreamsPerPool];
+static std::atomic<uint32_t>
+    priority_counters[c10::cuda::max_compile_time_stream_priorities]
+                     [C10_COMPILE_TIME_MAX_GPUS];
+
+static cudaStream_t streams[c10::cuda::max_compile_time_stream_priorities]
+                           [C10_COMPILE_TIME_MAX_GPUS][kStreamsPerPool];
 
 // Note [StreamId assignment]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
 // How do we assign stream IDs?
 //
-// -- 57 bits --  -- 5 bits -----  -- 3 bits --
-// zeros          stream id index  StreamIdType
+// -- 54 bits --  -- 5 bits -----  -- 4 bits --     --1 bit --
+// zeros          stream id index  StreamIdType     Ext/native stream
+//                ignored for ext   ignored for ext
+// for external stream, StreamID is a cudaStream_t pointer
+// this means that last bit will always be 0
+// so when constructing StreamId for a native stream we set last bit to 1
+// to distinguish between native and external streams
 //
-// Where StreamIdType:
-//  000 = default stream or externally allocated if id[63:3] != 0
-//  001 = low priority stream
-//  010 = high priority stream
-//
-// This is not really for efficiency; it's just easier to write the code
-// to extract the index if we do this with bitmasks :)
 //
 // We are obligated to treat the stream ID 0 as the default stream, per the
-// invariant specified in c10::Stream.  However, all other numbers are entirely
+// invariant specified in c10::Stream, so this is one exception to
+// "last bit = 1 for native streams". However, all other numbers are entirely
 // an internal implementation detail, we reserve the right to renumber streams
 // however we like.
 //
@@ -79,32 +75,41 @@
 //
 // Also, external managed stream pointers (cudaStream_t) can be directly stored
 // in the Id field so in this case, we need to check the stream alignment.
-// The IdType uses an additional bit to match with the 64-bit address alignment
-// making easy to identify an external stream when its value (X & 7) > 0
-enum class StreamIdType : uint8_t {
-  DEFAULT = 0x0,
-  LOW = 0x1,
-  HIGH = 0x2,
-  EXT = 0x3,
+
+class StreamIdType {
+  // StreamIdType encodes whether this stream is DEFAULT, EXTernal or
+  // for all other native streams, the stream priority (higher value is higher
+  // priority)
+ private:
+  uint8_t stream_type;
+
+ public:
+  static const uint8_t DEFAULT = 0x0;
+  static const uint8_t EXT = 0xF;
+
+ public:
+  StreamIdType(const uint8_t _stream_type) : stream_type(_stream_type) {}
+
+  bool isExt() const {
+    return EXT == stream_type;
+  }
+
+  bool isDefault() const {
+    return DEFAULT == stream_type;
+  }
+
+  uint8_t getStreamType() const {
+    return stream_type;
+  }
 };
 
 std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
-  switch (s) {
-    case StreamIdType::DEFAULT:
-      stream << "DEFAULT";
-      break;
-    case StreamIdType::LOW:
-      stream << "LOW";
-      break;
-    case StreamIdType::HIGH:
-      stream << "HIGH";
-      break;
-    case StreamIdType::EXT:
-      stream << "EXT";
-      break;
-    default:
-      stream << static_cast<uint8_t>(s);
-      break;
+  if (s.isDefault()) {
+    stream << "DEFAULT";
+  } else if (s.isExt()) {
+    stream << "EXT";
+  } else {
+    stream << "PRIORITY " << int(s.getStreamType());
   }
   return stream;
 }
@@ -114,24 +119,30 @@
 // see Note [Hazard when concatenating signed integers]
 
 static inline StreamIdType streamIdType(StreamId s) {
-  int mask_for_type = (1 << kStreamTypeBits) - 1;
-  if (s && ((s & mask_for_type) == 0)) {
-    // Externally allocated streams have their id being the cudaStream_ptr
-    // so the bits corresponding to the type will be 0 and will collide with
-    // the default stream.
-    return StreamIdType::EXT;
+  // Externally allocated streams have their id being the cudaStream_ptr
+  // so the last bit will be 0
+  if ((!(s & 1)) && s) {
+    return StreamIdType(StreamIdType::EXT);
   }
-  return static_cast<StreamIdType>(s & mask_for_type);
+  // last bit is external/internal stream, the mask should start from second
+  // rightmost bit
+  int mask_for_type = (1 << kStreamTypeBits) - 1;
+  auto val = (s >> 1) & mask_for_type;
+  TORCH_INTERNAL_ASSERT(val || !(s & 1), "invalid StreamId", s);
+  return StreamIdType(val);
 }
 
 static inline size_t streamIdIndex(StreamId s) {
   return static_cast<size_t>(
-      (s >> kStreamTypeBits) & ((1 << kStreamsPerPoolBits) - 1));
+      (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1));
 }
 
 StreamId makeStreamId(StreamIdType st, size_t si) {
-  return (static_cast<StreamId>(si) << kStreamTypeBits) |
-      static_cast<StreamId>(st);
+  if (st.isDefault()) {
+    return static_cast<StreamId>(0);
+  }
+  return (static_cast<StreamId>(si) << (kStreamTypeBits + 1)) |
+      static_cast<StreamId>(st.getStreamType() << 1) | 1;
 }
 
 // Thread-local current streams
@@ -149,6 +160,14 @@
       "max number of gpus expected (",
       C10_COMPILE_TIME_MAX_GPUS,
       "). Increase that and recompile.");
+  int leastPriority = -1, greatestPriority = -1;
+  C10_CUDA_CHECK(
+      cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
+  // greatestPriority is negative
+  auto range = leastPriority - greatestPriority + 1;
+  max_stream_priorities = range >= c10::cuda::max_compile_time_stream_priorities
+      ? c10::cuda::max_compile_time_stream_priorities
+      : range;
 }
 
 // Creates the low and high priority stream pools for the specified device
@@ -157,27 +176,20 @@
   // Switches to the requested device so streams are properly associated
   // with it.
   CUDAGuard device_guard{device_index};
-
   for (const auto i : c10::irange(kStreamsPerPool)) {
-    auto& lowpri_stream = low_priority_streams[device_index][i];
-    auto& hipri_stream = high_priority_streams[device_index][i];
+    for (const auto p : c10::irange(max_stream_priorities)) {
+      auto& stream = streams[p][device_index][i];
+      auto pri = -p; // lower number is higher priority
 
-    C10_CUDA_CHECK(cudaStreamCreateWithPriority(
-        &lowpri_stream, kDefaultFlags, kLowPriority));
-    C10_CUDA_CHECK(cudaStreamCreateWithPriority(
-        &hipri_stream, kDefaultFlags, kHighPriority));
-
-    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
-    if (C10_UNLIKELY(interp)) {
-      (*interp)->trace_gpu_stream_creation(
-          reinterpret_cast<uintptr_t>(lowpri_stream));
-      (*interp)->trace_gpu_stream_creation(
-          reinterpret_cast<uintptr_t>(hipri_stream));
+      C10_CUDA_CHECK(cudaStreamCreateWithPriority(&stream, kDefaultFlags, pri));
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+        (*interp)->trace_gpu_stream_creation(
+            reinterpret_cast<uintptr_t>(stream));
+        priority_counters[p][device_index] = 0;
+      }
     }
   }
-
-  low_priority_counters[device_index] = 0;
-  high_priority_counters[device_index] = 0;
 }
 
 // Init front-end to ensure initialization only occurs once
@@ -225,59 +237,60 @@
   StreamId stream_id = stream_.id();
   StreamIdType st = streamIdType(stream_id);
   size_t si = streamIdIndex(stream_id);
-  switch (st) {
-    case StreamIdType::DEFAULT:
-      TORCH_INTERNAL_ASSERT(
-          si == 0,
-          "Unrecognized stream ",
-          stream_,
-          " (I think this should be the default stream, but I got a non-zero index ",
-          si,
-          ").",
-          " Did you manufacture the StreamId yourself?  Don't do that; use the",
-          " official API like c10::cuda::getStreamFromPool() to get a new stream.");
-      return nullptr;
-    case StreamIdType::LOW:
-      return low_priority_streams[device_index][si];
-    case StreamIdType::HIGH:
-      return high_priority_streams[device_index][si];
-    case StreamIdType::EXT:
-      return reinterpret_cast<cudaStream_t>(stream_id);
-    default:
-      TORCH_INTERNAL_ASSERT(
-          0,
-          "Unrecognized stream ",
-          stream_,
-          " (I didn't recognize the stream type, ",
-          st,
-          ")");
+  if (st.isDefault()) {
+    TORCH_INTERNAL_ASSERT(
+        si == 0,
+        "Unrecognized stream ",
+        stream_,
+        " (I think this should be the default stream, but I got a non-zero index ",
+        si,
+        ").",
+        " Did you manufacture the StreamId yourself?  Don't do that; use the",
+        " official API like c10::cuda::getStreamFromPool() to get a new stream.");
+    return nullptr;
+  } else if (st.isExt()) {
+    return reinterpret_cast<cudaStream_t>(stream_id);
+  } else {
+    auto streamType = st.getStreamType();
+    TORCH_INTERNAL_ASSERT(
+        streamType >= 1 && streamType <= max_stream_priorities,
+        "Unrecognized stream ",
+        stream_,
+        " (I didn't recognize the stream type, ",
+        streamType,
+        ")");
+    return streams[st.getStreamType() - 1][device_index][si];
   }
 }
 
 // Returns a stream from the requested pool
 // Note: when called the first time on a device, this will create the
 // stream pools for that device.
-CUDAStream getStreamFromPool(
-    const bool isHighPriority,
-    DeviceIndex device_index) {
+CUDAStream getStreamFromPool(const int priority, DeviceIndex device_index) {
   initCUDAStreamsOnce();
   if (device_index == -1) {
     device_index = current_device();
     c10::cuda::SetTargetDevice();
   }
+  TORCH_CHECK(
+      priority <= 0,
+      "Expected cuda stream priority to be less than or equal to 0, got ",
+      priority);
   check_gpu(device_index);
-
   // Initializes the stream pools (once)
   c10::call_once(
       device_flags[device_index], initDeviceStreamState, device_index);
+  auto pri_idx = -priority;
+  pri_idx =
+      std::min(pri_idx, max_stream_priorities - 1); // pri_idx is zero-based
+  const auto idx = get_idx(priority_counters[pri_idx][device_index]);
+  StreamIdType id_type = StreamIdType(pri_idx + 1);
+  return CUDAStreamForId(device_index, makeStreamId(id_type, idx));
+}
 
-  if (isHighPriority) {
-    const auto idx = get_idx(high_priority_counters[device_index]);
-    return CUDAStreamForId(device_index, makeStreamId(StreamIdType::HIGH, idx));
-  }
-
-  const auto idx = get_idx(low_priority_counters[device_index]);
-  return CUDAStreamForId(device_index, makeStreamId(StreamIdType::LOW, idx));
+CUDAStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) {
+  int priority = isHighPriority ? -max_stream_priorities + 1 : 0;
+  return getStreamFromPool(priority, device);
 }
 
 CUDAStream getStreamFromExternal(
diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h
index 094372a..7cc2a43 100644
--- a/c10/cuda/CUDAStream.h
+++ b/c10/cuda/CUDAStream.h
@@ -55,6 +55,8 @@
 namespace c10 {
 namespace cuda {
 
+static constexpr int max_compile_time_stream_priorities = 4;
+
 // Value object representing a CUDA stream.  This is just a wrapper
 // around c10::Stream, but it comes with a little extra CUDA-specific
 // functionality (conversion to cudaStream_t), and a guarantee that
@@ -174,16 +176,17 @@
   static std::tuple<int, int> priority_range() {
     // Note: this returns the range of priority **supported by PyTorch**, not
     // the range of priority **supported by CUDA**. The former is a subset of
-    // the latter. Currently PyTorch only supports 0 and -1, which are "low" and
-    // "high" priority.
+    // the latter.
     int least_priority, greatest_priority;
     C10_CUDA_CHECK(
         cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
     TORCH_INTERNAL_ASSERT(
-        least_priority >= 0, "Unexpected CUDA stream priority range");
+        least_priority == 0, "Unexpected CUDA stream priority range");
     TORCH_INTERNAL_ASSERT(
         greatest_priority <= -1, "Unexpected CUDA stream priority range");
-    return std::make_tuple(0, -1);
+    greatest_priority = std::max(
+        -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority);
+    return std::make_tuple(least_priority, greatest_priority);
   }
 
   // Deleted for now; use CUDAEvent::block instead
@@ -205,6 +208,9 @@
  */
 C10_API CUDAStream
 getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
+// no default priority to disambiguate overloads
+C10_API CUDAStream
+getStreamFromPool(const int priority, DeviceIndex device = -1);
 
 /**
  * Get a CUDAStream from a externally allocated one.
diff --git a/test/test_dlpack.py b/test/test_dlpack.py
index b50a082..7066035 100644
--- a/test/test_dlpack.py
+++ b/test/test_dlpack.py
@@ -148,6 +148,9 @@
     @onlyCUDA
     @skipCUDAIfRocm
     def test_dlpack_convert_default_stream(self, device):
+        # tests run on non-default stream, so _sleep call
+        # below will run on a non-default stream, causing
+        # default stream to wait due to inserted syncs
         torch.cuda.default_stream().synchronize()
         x = torch.zeros(1, device=device)
         torch.cuda._sleep(2**20)
diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp
index 18c30df..9d16361 100644
--- a/torch/csrc/cuda/Stream.cpp
+++ b/torch/csrc/cuda/Stream.cpp
@@ -57,15 +57,13 @@
     TORCH_CHECK(
         priority == 0, "Priority was explicitly set for a external stream")
   }
-
   at::cuda::CUDAStream stream = (stream_id || device_index || device_type)
       ? at::cuda::CUDAStream::unpack3(
             stream_id, device_index, static_cast<c10::DeviceType>(device_type))
       : stream_ptr
       ? at::cuda::getStreamFromExternal(
             reinterpret_cast<cudaStream_t>(stream_ptr), current_device)
-      : at::cuda::getStreamFromPool(
-            /* isHighPriority */ priority < 0 ? true : false);
+      : at::cuda::getStreamFromPool(priority);
 
   THCPStream* self = (THCPStream*)ptr.get();
   self->stream_id = static_cast<int64_t>(stream.id());
diff --git a/torch/csrc/jit/cuda/cuda.h b/torch/csrc/jit/cuda/cuda.h
index a6afc0a..e8a0d04 100644
--- a/torch/csrc/jit/cuda/cuda.h
+++ b/torch/csrc/jit/cuda/cuda.h
@@ -17,11 +17,10 @@
   CUDAStream(
       c10::optional<c10::Device> device = c10::nullopt,
       int64_t priority = 0) {
-    constexpr int64_t PRIORITY_INDEX = 0;
     c10::DeviceIndex device_index =
         device.has_value() ? device->index() : c10::cuda::current_device();
     stream_ = std::make_unique<c10::cuda::CUDAStream>(
-        c10::cuda::getStreamFromPool(priority < PRIORITY_INDEX, device_index));
+        c10::cuda::getStreamFromPool(static_cast<int>(priority), device_index));
   }
 
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py
index 0c125da..52eab59 100644
--- a/torch/cuda/streams.py
+++ b/torch/cuda/streams.py
@@ -20,12 +20,10 @@
         device(torch.device or int, optional): a device on which to allocate
             the stream. If :attr:`device` is ``None`` (default) or a negative
             integer, this will use the current device.
-        priority(int, optional): priority of the stream. Can be either
-            -1 (high priority) or 0 (low priority). By default, streams have
-            priority 0.
+        priority(int, optional): priority of the stream, should be 0 or
+            negative, where negative numbers indicate higher priority. By default,
+            streams have priority 0.
 
-    .. note:: Although CUDA versions >= 11 support more than two levels of
-        priorities, in PyTorch, we only support two levels of priorities.
     """
 
     def __new__(cls, device=None, priority=0, **kwargs):