Fix DeviceGuard usage in THD (#8622)
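
This replaces the hand-rolled DeviceGuard struct in DataChannelMPI.cpp with at::DeviceGuard and updates the DataChannelNccl.cpp call sites from setDevice to set_index to match that guard's interface. Below is a minimal sketch (not part of the patch; workOnDevices is a hypothetical helper) of the RAII pattern the guard provides, using only the calls that appear in the diff:

    #include <ATen/ATen.h>

    #include <vector>

    // Hypothetical helper illustrating how the guard is used in this patch.
    void workOnDevices(const std::vector<at::Tensor>& tensors) {
      if (tensors.empty()) return;
      const auto& t = tensors[0];
      // Constructing with -1 makes the guard a no-op for CPU tensors, as in
      // _newLikeFlat below; the previous device is restored on destruction.
      at::DeviceGuard gpu_guard(t.is_cuda() ? t.get_device() : -1);
      for (const auto& tensor : tensors) {
        if (tensor.is_cuda()) {
          // Switch the current CUDA device per tensor, mirroring the
          // gpuGuard.set_index(...) call sites in DataChannelNccl.cpp.
          gpu_guard.set_index(tensor.get_device());
        }
        // ... launch per-device work here ...
      }
    }  // gpu_guard goes out of scope and the original device is restored
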
diff --git a/torch/lib/THD/base/data_channels/DataChannelMPI.cpp b/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
index 2d9b917..8f9cbaa 100644
--- a/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
+++ b/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
@@ -1,6 +1,8 @@
#include "DataChannelMPI.hpp"
#include "DataChannelUtils.hpp"
+#include <ATen/ATen.h>
+
#include <algorithm>
#include <cstdint>
#include <cstring>
@@ -125,7 +127,6 @@
return true;
}
-
rank_type DataChannelMPI::getRank() {
return _rank;
}
@@ -135,31 +136,12 @@
return _num_processes;
}
-struct DeviceGuard {
- DeviceGuard(int new_device) {
- if (new_device == -1) return;
-#ifdef USE_CUDA
- cudaGetDevice(&device_);
- cudaSetDevice(new_device);
-#endif
- }
-
- ~DeviceGuard() {
- if (device_ == -1) return;
-#ifdef USE_CUDA
- cudaSetDevice(device_);
-#endif
- }
-
- int device_ = -1;
-};
-
at::Tensor DataChannelMPI::_newLikeFlat(std::vector<at::Tensor>& tensors) const {
// TODO: check if all outputs are contiguous in memory and skip this step if yes
if (tensors.size() == 0)
throw std::runtime_error("received an empty list");
auto & t = tensors[0];
- DeviceGuard gpu_guard { t.is_cuda() ? static_cast<int>(t.get_device()) : -1 };
+ at::DeviceGuard gpu_guard(t.is_cuda() ? t.get_device() : -1);
std::vector<int64_t> sizes { static_cast<int64_t>(tensors.size()) }; // sizes = [output.size()] + input.sizes()
sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
return t.type().tensor(sizes);
diff --git a/torch/lib/THD/base/data_channels/DataChannelNccl.cpp b/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
index 47c0840..691f87d 100644
--- a/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
+++ b/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
@@ -2,6 +2,8 @@
#include "DataChannelNccl.hpp"
#include "DataChannelUtils.hpp"
+#include <ATen/ATen.h>
+
#include <cuda.h>
#include <THC/THC.h>
@@ -189,7 +191,7 @@
// Destroy the CUDA events
size_t idx = 0;
for (auto& event : *(_groupNcclResources[groupId][i].ncclCudaEvents())) {
- gpuGuard.setDevice(devices[idx++]);
+ gpuGuard.set_index(devices[idx++]);
THCudaCheck(cudaEventSynchronize(event));
THCudaCheck(cudaEventDestroy(event));
}
@@ -304,14 +306,14 @@
// Now creating the CUDA events
for (size_t i = 0; i < input.size(); ++i) {
- gpuGuard.setDevice(input[i].get_device());
+ gpuGuard.set_index(input[i].get_device());
THCudaCheck(cudaEventCreate(&((*events)[i])));
}
// Create the communicator on each device of the input
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < input.size(); ++i) {
int nRanks = int(_numProcesses) * input.size();
- gpuGuard.setDevice(input[i].get_device());
+ gpuGuard.set_index(input[i].get_device());
NCCL_CHECK(ncclCommInitRank(&((*comms)[i]),
nRanks,
ncclId,
@@ -429,7 +431,7 @@
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < data.size(); ++i) {
- gpuGuard.setDevice(data[i].get_device());
+ gpuGuard.set_index(data[i].get_device());
auto stream = THCState_getCurrentStream(THDGetCudaState());
NCCL_CHECK(ncclAllReduce(data[i].data_ptr(),
@@ -480,7 +482,7 @@
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < input.size(); ++i) {
- gpuGuard.setDevice(input[i].get_device());
+ gpuGuard.set_index(input[i].get_device());
auto stream = THCState_getCurrentStream(THDGetCudaState());
NCCL_CHECK(ncclAllGather(input[i].data_ptr(),
@@ -532,7 +534,7 @@
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < data.size(); ++i) {
- gpuGuard.setDevice(data[i].get_device());
+ gpuGuard.set_index(data[i].get_device());
auto stream = THCState_getCurrentStream(THDGetCudaState());
NCCL_CHECK(ncclReduce(data[i].data_ptr(),
@@ -584,7 +586,7 @@
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < data.size(); ++i) {
- gpuGuard.setDevice(data[i].get_device());
+ gpuGuard.set_index(data[i].get_device());
auto stream = THCState_getCurrentStream(THDGetCudaState());
NCCL_CHECK(ncclBcast(data[i].data_ptr(),