Fix DeviceGuard usage in THD (#8622)
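
Replace the hand-rolled `DeviceGuard` struct in `DataChannelMPI.cpp` with `at::DeviceGuard` from ATen (pulling in `<ATen/ATen.h>`), and switch `DataChannelNccl.cpp` over to the matching `at::DeviceGuard` method name, `set_index`, in place of the old `setDevice` calls. The local copy duplicated functionality ATen already provides and ignored the return codes of `cudaGetDevice`/`cudaSetDevice`.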

diff --git a/torch/lib/THD/base/data_channels/DataChannelMPI.cpp b/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
index 2d9b917..8f9cbaa 100644
--- a/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
+++ b/torch/lib/THD/base/data_channels/DataChannelMPI.cpp
@@ -1,6 +1,8 @@
 #include "DataChannelMPI.hpp"
 #include "DataChannelUtils.hpp"
 
+#include <ATen/ATen.h>
+
 #include <algorithm>
 #include <cstdint>
 #include <cstring>
@@ -125,7 +127,6 @@
   return true;
 }
 
-
 rank_type DataChannelMPI::getRank() {
   return _rank;
 }
@@ -135,31 +136,12 @@
   return _num_processes;
 }
 
-struct DeviceGuard {
-  DeviceGuard(int new_device) {
-    if (new_device == -1) return;
-#ifdef USE_CUDA
-    cudaGetDevice(&device_);
-    cudaSetDevice(new_device);
-#endif
-  }
-
-  ~DeviceGuard() {
-    if (device_ == -1) return;
-#ifdef USE_CUDA
-    cudaSetDevice(device_);
-#endif
-  }
-
-  int device_ = -1;
-};
-
 at::Tensor DataChannelMPI::_newLikeFlat(std::vector<at::Tensor>& tensors) const {
   // TODO: check if all outputs are contiguous in memory and skip this step if yes
   if (tensors.size() == 0)
     throw std::runtime_error("received an empty list");
   auto & t = tensors[0];
-  DeviceGuard gpu_guard { t.is_cuda() ? static_cast<int>(t.get_device()) : -1 };
+  at::DeviceGuard gpu_guard(t.is_cuda() ? t.get_device() : -1);
   std::vector<int64_t> sizes { static_cast<int64_t>(tensors.size()) };  // sizes = [output.size()] + input.sizes()
   sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
   return t.type().tensor(sizes);
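
For reference, the deleted struct was approximating the standard RAII device-guard idiom sketched below. This is an illustrative stand-in built on raw CUDA runtime calls, not ATen's actual implementation; `ScopedCudaDevice` is a hypothetical name. Unlike the deleted struct, it checks the CUDA return codes:

```cpp
#include <cuda_runtime.h>

// Illustrative RAII device guard (hypothetical; not ATen's implementation).
// Saves the current CUDA device on construction, switches to the requested
// one, and restores the saved device on destruction. An index of -1 makes
// the guard a no-op, matching how _newLikeFlat passes -1 for CPU tensors.
struct ScopedCudaDevice {
  explicit ScopedCudaDevice(int new_device) {
    if (new_device == -1) return;
    if (cudaGetDevice(&saved_) != cudaSuccess) { saved_ = -1; return; }
    if (cudaSetDevice(new_device) != cudaSuccess) saved_ = -1;
  }
  ~ScopedCudaDevice() {
    if (saved_ != -1) cudaSetDevice(saved_);  // restore the previous device
  }
  ScopedCudaDevice(const ScopedCudaDevice&) = delete;
  ScopedCudaDevice& operator=(const ScopedCudaDevice&) = delete;
  int saved_ = -1;  // -1 means "nothing to restore"
};
```

Using `at::DeviceGuard` keeps this save/restore logic in one place instead of being re-rolled per data channel.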
diff --git a/torch/lib/THD/base/data_channels/DataChannelNccl.cpp b/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
index 47c0840..691f87d 100644
--- a/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
+++ b/torch/lib/THD/base/data_channels/DataChannelNccl.cpp
@@ -2,6 +2,8 @@
 #include "DataChannelNccl.hpp"
 #include "DataChannelUtils.hpp"
 
+#include <ATen/ATen.h>
+
 #include <cuda.h>
 #include <THC/THC.h>
 
@@ -189,7 +191,7 @@
       // Destroy the CUDA events
       size_t idx = 0;
       for (auto& event : *(_groupNcclResources[groupId][i].ncclCudaEvents())) {
-        gpuGuard.setDevice(devices[idx++]);
+        gpuGuard.set_index(devices[idx++]);
         THCudaCheck(cudaEventSynchronize(event));
         THCudaCheck(cudaEventDestroy(event));
       }
@@ -304,14 +306,14 @@
 
   // Now creating the CUDA events
   for (size_t i = 0; i < input.size(); ++i) {
-    gpuGuard.setDevice(input[i].get_device());
+    gpuGuard.set_index(input[i].get_device());
     THCudaCheck(cudaEventCreate(&((*events)[i])));
   }
   // Create the communicator on each device of the input
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < input.size(); ++i) {
     int nRanks = int(_numProcesses) * input.size();
-    gpuGuard.setDevice(input[i].get_device());
+    gpuGuard.set_index(input[i].get_device());
     NCCL_CHECK(ncclCommInitRank(&((*comms)[i]),
                                 nRanks,
                                 ncclId,
@@ -429,7 +431,7 @@
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < data.size(); ++i) {
 
-    gpuGuard.setDevice(data[i].get_device());
+    gpuGuard.set_index(data[i].get_device());
     auto stream = THCState_getCurrentStream(THDGetCudaState());
 
     NCCL_CHECK(ncclAllReduce(data[i].data_ptr(),
@@ -480,7 +482,7 @@
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < input.size(); ++i) {
 
-    gpuGuard.setDevice(input[i].get_device());
+    gpuGuard.set_index(input[i].get_device());
     auto stream = THCState_getCurrentStream(THDGetCudaState());
 
     NCCL_CHECK(ncclAllGather(input[i].data_ptr(),
@@ -532,7 +534,7 @@
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < data.size(); ++i) {
 
-    gpuGuard.setDevice(data[i].get_device());
+    gpuGuard.set_index(data[i].get_device());
     auto stream = THCState_getCurrentStream(THDGetCudaState());
 
     NCCL_CHECK(ncclReduce(data[i].data_ptr(),
@@ -584,7 +586,7 @@
   NCCL_CHECK(ncclGroupStart());
   for (size_t i = 0; i < data.size(); ++i) {
 
-    gpuGuard.setDevice(data[i].get_device());
+    gpuGuard.set_index(data[i].get_device());
     auto stream = THCState_getCurrentStream(THDGetCudaState());
 
     NCCL_CHECK(ncclBcast(data[i].data_ptr(),
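
The remaining hunks are the mechanical `setDevice` → `set_index` rename. With `at::DeviceGuard`, `set_index` switches the current device mid-scope while the device that was current when the guard was created is still restored on destruction, so per-device loops like the ones above need only a single guard. A minimal usage sketch (assuming the guard is default-constructible, as the `gpuGuard` call sites suggest; `forEachDevice` is a hypothetical helper):

```cpp
#include <ATen/ATen.h>
#include <vector>

// Hypothetical per-device loop in the style of the NCCL hunks above.
void forEachDevice(const std::vector<at::Tensor>& tensors) {
  at::DeviceGuard gpuGuard;  // assumed default-constructible no-op guard
  for (const auto& t : tensors) {
    gpuGuard.set_index(t.get_device());  // make t's device current
    // ... per-device work: create events, launch a collective, etc. ...
  }
}  // gpuGuard restores the original device here
```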