Delete default constructor from CUDAStream. (#13021)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13021
Let's make a nullptr CUDAStream an illegal state: delete the default constructor and assert in the implicit constructor that the internals pointer is non-null. Call sites that relied on default construction, directly or through std::vector::resize, now construct real streams from the pool and fill vectors with reserve plus push_back.
Reviewed By: gchanan
Differential Revision: D10520421
fbshipit-source-id: 723c1f5130b2c92ec97411a958707fac4a90173f
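A minimal sketch of the invariant this enforces (illustrative only; the factory function and the assert are exactly the ones touched in the diff below):

    #include <ATen/cuda/CUDAStream.h>

    void sketch() {
      at::cuda::CUDAStream s = at::cuda::getStreamFromPool(); // OK: constructed from the pool
      // at::cuda::CUDAStream bad;            // no longer compiles: the default ctor is deleted
      // at::cuda::CUDAStream worse(nullptr); // compiles, but trips AT_ASSERT(internals_in)
    }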
diff --git a/aten/src/ATen/cuda/CUDAStream.h b/aten/src/ATen/cuda/CUDAStream.h
index 8f43602..ab15324 100644
--- a/aten/src/ATen/cuda/CUDAStream.h
+++ b/aten/src/ATen/cuda/CUDAStream.h
@@ -6,6 +6,7 @@
#include "cuda_runtime_api.h"
#include <ATen/cuda/ATenCUDAGeneral.h>
+#include <c10/util/Exception.h>
/*
* A CUDAStream interface. See CUDAStream.cpp for implementation details.
@@ -80,9 +81,10 @@
struct AT_CUDA_API CUDAStream {
// Constructors
- CUDAStream() = default;
/* implicit */ CUDAStream(CUDAStreamInternals* internals_in)
- : internals_{internals_in} { }
+ : internals_{internals_in} {
+ AT_ASSERT(internals_in);
+ }
// Returns true if the CUDAStream is not null.
explicit operator bool() const noexcept { return internals_ != nullptr; }
diff --git a/aten/src/ATen/test/stream_test.cpp b/aten/src/ATen/test/stream_test.cpp
index f2f0d79..18d5250 100644
--- a/aten/src/ATen/test/stream_test.cpp
+++ b/aten/src/ATen/test/stream_test.cpp
@@ -31,7 +31,7 @@
cudaStream_t cuda_stream;
// Tests that copying works as expected and preserves the stream
- at::cuda::CUDAStream copyStream;
+ at::cuda::CUDAStream copyStream = at::cuda::getStreamFromPool();
{
auto s = at::cuda::getStreamFromPool();
device = s.device();
@@ -49,7 +49,7 @@
ASSERT_EQ_CUDA(copyStream.stream(), cuda_stream);
// Tests that moving works as expected and preserves the stream
- at::cuda::CUDAStream moveStream;
+ at::cuda::CUDAStream moveStream = at::cuda::getStreamFromPool();
{
auto s = at::cuda::getStreamFromPool();
device = s.device();
@@ -85,16 +85,16 @@
ASSERT_EQ_CUDA(curStream, defaultStream);
}
-void thread_fun(at::cuda::CUDAStream& cur_thread_stream) {
+void thread_fun(at::optional<at::cuda::CUDAStream>& cur_thread_stream) {
auto new_stream = at::cuda::getStreamFromPool();
at::cuda::setCurrentCUDAStream(new_stream);
- cur_thread_stream = at::cuda::getCurrentCUDAStream();
- ASSERT_EQ_CUDA(cur_thread_stream, new_stream);
+ cur_thread_stream = {at::cuda::getCurrentCUDAStream()};
+ ASSERT_EQ_CUDA(*cur_thread_stream, new_stream);
}
// Ensures streams are thread local
TEST(TestStream, MultithreadGetAndSetTest) {
- at::cuda::CUDAStream s0, s1;
+ at::optional<at::cuda::CUDAStream> s0, s1;
std::thread t0{thread_fun, std::ref(s0)};
std::thread t1{thread_fun, std::ref(s1)};
@@ -105,8 +105,8 @@
at::cuda::CUDAStream default_stream = at::cuda::getDefaultCUDAStream();
ASSERT_EQ_CUDA(cur_stream, default_stream);
- ASSERT_NE_CUDA(cur_stream, s0);
- ASSERT_NE_CUDA(cur_stream, s1);
+ ASSERT_NE_CUDA(cur_stream, *s0);
+ ASSERT_NE_CUDA(cur_stream, *s1);
ASSERT_NE_CUDA(s0, s1);
}
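Because the tests can no longer default-construct a placeholder stream and assign it later, cross-thread results are held in at::optional<at::cuda::CUDAStream> instead. A sketch of the pattern (the optional header path is an assumption; function names are hypothetical):

    #include <ATen/cuda/CUDAStream.h>
    #include <c10/util/Optional.h> // at::optional; exact header path is an assumption
    #include <thread>

    // The optional starts disengaged, standing in for the now-illegal null stream.
    void worker(at::optional<at::cuda::CUDAStream>& out) {
      out = at::cuda::getStreamFromPool(); // engage it on the worker thread
    }

    void sketch() {
      at::optional<at::cuda::CUDAStream> s; // no stream constructed yet
      std::thread t{worker, std::ref(s)};
      t.join();
      at::cuda::CUDAStream got = *s; // dereference only once engaged
    }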
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index 513cc0c..546177a 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -104,9 +104,10 @@
#ifdef USE_CUDA
std::vector<cudaStream_t> getStreamVector(AlgorithmEntry& entry) {
- std::vector<cudaStream_t> streams(entry.streams.size());
- for (size_t i = 0; i < entry.streams.size(); i++) {
- streams[i] = entry.streams[i].stream();
+ std::vector<cudaStream_t> streams;
+ streams.reserve(entry.streams.size());
+ for (auto s : entry.streams) {
+ streams.push_back(s);
}
return streams;
}
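The new loop leans on CUDAStream's implicit conversion to cudaStream_t, which the old code spelled explicitly as .stream(); the conversion operator is inferred from the push_back above, so treat this as a sketch:

    at::cuda::CUDAStream s = at::cuda::getStreamFromPool();
    cudaStream_t raw_implicit = s;          // what push_back(s) relies on
    cudaStream_t raw_explicit = s.stream(); // the spelling the old code used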
@@ -526,12 +527,12 @@
#ifdef USE_CUDA
// If these are CUDA tensors, create streams and events
if (key.type->is_cuda()) {
- entry->streams.resize(key.devices.size());
- entry->events.resize(key.devices.size());
+ entry->streams.reserve(key.devices.size());
+ entry->events.reserve(key.devices.size());
for (size_t i = 0; i < key.devices.size(); i++) {
deviceGuard.set_index(key.devices[i]);
- entry->streams[i] = at::cuda::getStreamFromPool();
- entry->events[i] = CUDAEvent::create();
+ entry->streams.push_back(at::cuda::getStreamFromPool());
+ entry->events.push_back(CUDAEvent::create());
}
}
#endif
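The resize-to-reserve rewrites in this and the following files all exist for the same reason: std::vector::resize default-constructs its new elements, which no longer compiles for CUDAStream, while reserve only allocates and lets push_back append fully constructed elements. A minimal sketch of the pattern with a hypothetical stand-in type:

    #include <vector>

    struct NoDefault {
      NoDefault() = delete;           // mirrors CUDAStream's deleted default ctor
      explicit NoDefault(int v) : v(v) {}
      int v;
    };

    void fill(std::vector<NoDefault>& xs, size_t n) {
      // xs.resize(n);                // would not compile: requires NoDefault()
      xs.reserve(n);                  // allocate capacity; construct nothing
      for (size_t i = 0; i < n; i++) {
        xs.push_back(NoDefault(static_cast<int>(i))); // append a constructed element
      }
    }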
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index ec87a70..7497611 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -245,8 +245,8 @@
std::vector<CUDAEvent> eventVal;
std::vector<at::cuda::CUDAStream> streamVal;
- eventVal.resize(devices.size());
- streamVal.resize(devices.size());
+ eventVal.reserve(devices.size());
+ streamVal.reserve(devices.size());
// Create the NCCL communicators for each GPU
C10D_NCCL_CHECK(ncclGroupStart());
@@ -260,12 +260,12 @@
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
// Also create the NCCL streams and events
- streamVal[i] = at::cuda::getStreamFromPool();
+ streamVal.push_back(at::cuda::getStreamFromPool());
// Events created with the cudaEventDisableTiming flag (and without
// cudaEventBlockingSync) give the best performance when used with
// cudaStreamWaitEvent() and cudaEventQuery(). Since we don't measure
// event timing here, the flag should be set.
- eventVal[i] = CUDAEvent::create(cudaEventDisableTiming);
+ eventVal.push_back(CUDAEvent::create(cudaEventDisableTiming));
}
C10D_NCCL_CHECK(ncclGroupEnd());
diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
index 71e18ec..fcea617 100644
--- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
@@ -88,10 +88,10 @@
// getters to retrieve the current stream).
//
at::DeviceGuard deviceGuard;
- streams_.resize(numDevices_);
+ streams_.reserve(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
deviceGuard.set_index(i);
- streams_[i] = at::cuda::getStreamFromPool();
+ streams_.push_back(at::cuda::getStreamFromPool());
}
}
diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
index 6d9c970..dc3dfd7 100644
--- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
@@ -67,10 +67,10 @@
// and pass this along to the collective (since it uses the THC
// getters to retrieve the current stream).
//
- streams_.resize(numDevices_);
+ streams_.reserve(numDevices_);
for (auto i = 0; i < numDevices_; i++) {
deviceGuard.set_index(i);
- streams_[i] = at::cuda::getStreamFromPool();
+ streams_.push_back(at::cuda::getStreamFromPool());
}
}