Allow passing cpu to CUDA RPC device maps (#57019)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57019

Based on https://github.com/pytorch/pytorch/pull/56043
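
This makes `cpu` a valid source and/or target in TensorPipe device maps, so CPU tensors can be sent straight to a remote GPU and results can come back on CPU. Concretely: `set_device_map` no longer rejects non-CUDA device pairs, `getDevicesForTensors` routes CPU tensors through the map (falling back to CPU when no entry exists), the device set collected from the maps skips CPU entries, and the outdated warnings claiming GPU tensors cannot be sent over the wire are dropped from the `remote`/`rpc_sync` docstrings.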

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D28169796

Pulled By: beauby

fbshipit-source-id: 7fcf623de07c74c4f1ab415b7e20b518876a567a
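
For context, a minimal end-to-end sketch of what this enables. The worker names, ranks, and single-GPU layout are illustrative assumptions, not part of the change:

```python
import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import TensorPipeRpcBackendOptions

# On the caller ("worker0"): map local CPU tensors to worker1's cuda:0.
options = TensorPipeRpcBackendOptions()
options.set_device_map("worker1", {"cpu": 0})  # raised ValueError before this change

rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=options)

# The CPU argument arrives on cuda:0 at the callee; the cuda:0 result is
# routed through the reverse map and lands back on CPU here.
ret = rpc.rpc_sync("worker1", torch.add, args=(torch.zeros(2), 1))
assert ret.device.type == "cpu"

rpc.shutdown()
```
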
diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp
index 4a0ba70..f940119 100644
--- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp
+++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp
@@ -55,10 +55,16 @@
       remoteName);
   std::vector<c10::Device> devices;
   devices.reserve(tensors.size());
-  bool hasCudaTensor = false;
+  bool hasMappedDevice = false;
   for (const auto& t : tensors) {
     if (t.device().is_cpu()) {
-      devices.emplace_back(c10::kCPU);
+      const auto deviceIter = deviceMap.find(c10::kCPU);
+      if (deviceIter == deviceMap.end()) {
+        devices.emplace_back(c10::kCPU);
+      } else {
+        devices.emplace_back(deviceIter->second);
+        hasMappedDevice = true;
+      }
     } else {
       const auto deviceIter = deviceMap.find(t.device());
       TORCH_CHECK(
@@ -68,10 +74,10 @@
           t.device(),
           " but received a tensor on that device.");
       devices.push_back(deviceIter->second);
-      hasCudaTensor = true;
+      hasMappedDevice = true;
     }
   }
-  if (!hasCudaTensor) {
+  if (!hasMappedDevice) {
     devices.clear();
   }
   return devices;
@@ -95,7 +101,9 @@
   std::unordered_set<c10::Device> deviceSet;
   for (const auto& entry : deviceMap) {
     for (const auto& device : entry.second) {
-      deviceSet.insert(device.first);
+      if (!device.first.is_cpu()) {
+        deviceSet.insert(device.first);
+      }
     }
   }
   return deviceSet;
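
The control flow in `getDevicesForTensors` above is easier to see in pseudocode. Below is an illustrative Python paraphrase of the updated logic, not a public API:

```python
import torch

def get_devices_for_tensors(tensors, device_map):
    """Illustrative paraphrase of the C++ getDevicesForTensors above:
    pick one target device per argument tensor."""
    devices, has_mapped_device = [], False
    cpu = torch.device("cpu")
    for t in tensors:
        if t.device.type == "cpu":
            # New: an explicit cpu entry remaps CPU tensors. Note that even a
            # cpu -> cpu entry counts as "mapped" (see test_device_map_cpu).
            if cpu in device_map:
                devices.append(device_map[cpu])
                has_mapped_device = True
            else:
                devices.append(cpu)
        else:
            # Non-CPU tensors still require an explicit entry, as before.
            if t.device not in device_map:
                raise RuntimeError(f"no device map entry for {t.device}")
            devices.append(device_map[t.device])
            has_mapped_device = True
    # Nothing was remapped: return an empty list so callers can skip device
    # handling entirely (the old all-CPU fast path).
    return devices if has_mapped_device else []
```
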
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index ce421a4..08a69f9 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -490,12 +490,6 @@
         to retrieve the result value locally.
 
     .. warning ::
-        Using GPU tensors as arguments or return values of ``func`` is not
-        supported since we don't support sending GPU tensors over the wire. You
-        need to explicitly copy GPU tensors to CPU before using them as
-        arguments or return values of ``func``.
-
-    .. warning ::
         The ``remote`` API does not copy storages of argument tensors until
         sending them over the wire, which could be done by a different thread
         depending on the RPC backend type. The caller should make sure that the
@@ -695,12 +689,6 @@
     Returns:
         Returns the result of running ``func`` with ``args`` and ``kwargs``.
 
-    .. warning ::
-        Using GPU tensors as arguments or return values of ``func`` is not
-        supported since we don't support sending GPU tensors over the wire. You
-        need to explicitly copy GPU tensors to CPU before using them as
-        arguments or return values of ``func``.
-
     Example::
         Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
         on both workers. Refer to :meth:`~torch.distributed.init_process_group`
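
Both warnings deleted above predate CUDA support in the TensorPipe agent; with a device map covering the relevant devices, GPU tensors are valid arguments and return values of ``func``. A hedged sketch using ``remote``, assuming RPC was already initialized with a `{0: 0}` map for a peer named "worker1":

```python
import torch
import torch.distributed.rpc as rpc

# Assumes init_rpc already ran with options.set_device_map("worker1", {0: 0});
# the peer name and device layout are illustrative.
x = torch.ones(2, device="cuda:0")
rref = rpc.remote("worker1", torch.sum, args=(x,))  # x arrives on worker1's cuda:0
print(rref.to_here())  # the fetched result follows the reverse device map back
```
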
diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py
index 0fb13e3..0c32a57 100644
--- a/torch/distributed/rpc/options.py
+++ b/torch/distributed/rpc/options.py
@@ -25,12 +25,6 @@
     for k in device_map:
         v = device_map[k]
         k, v = torch.device(k), torch.device(v)
-        if k.type != 'cuda' or v.type != 'cuda':
-            raise ValueError(
-                "`set_device_map` only supports CUDA devices, "
-                f"but got device pair {k}: {v}"
-
-            )
         if v in reverse_map:
             raise ValueError(
                 "`device_map` only supports 1-to-1 mapping, "
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index a32650c..5309796 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -4934,18 +4934,16 @@
         rpc.shutdown()
 
     @staticmethod
-    def _gpu_add_given_gpus(x, y, x_to, y_to, z_to):
-        if all([
-            x.is_cuda,
-            x.device.index == x_to,
-            y.is_cuda,
-            y.device.index == y_to
-        ]):
+    def _gpu_add_given_devices(x, y, x_to, y_to, z_to):
+        x_device = "cpu" if x.device.type == "cpu" else x.device.index
+        y_device = "cpu" if y.device.type == "cpu" else y.device.index
+        if x_device == x_to and y_device == y_to:
             return x.to(z_to) + y.to(z_to)
         else:
             raise ValueError("Wrong device affinity")
 
-    def _test_device_maps_gpu(self, x_from, y_from, z_to, device_map, dst=None):
+    def _test_device_maps_gpu(self, x_from, y_from, z_to, device_map, dst=None, fn=None):
+        fn = TensorPipeAgentCudaRpcTest._gpu_add_given_devices if fn is None else fn
         x_to = device_map[x_from]
         y_to = device_map[y_from]
 
@@ -4964,20 +4962,66 @@
         x = torch.zeros(2).to(x_from)
         y = torch.ones(2).to(y_from)
 
-        ret = rpc.rpc_sync(
-            dst,
-            TensorPipeAgentCudaRpcTest._gpu_add_given_gpus,
-            args=(x, y, x_to, y_to, z_to)
-        )
+        ret = rpc.rpc_sync(dst, fn, args=(x, y, x_to, y_to, z_to))
 
         reverse_device_map = {device_map[k] : k for k in device_map}
         z_from = reverse_device_map[z_to]
 
-        self.assertEqual(ret.device.index, z_from)
+        ret_device = "cpu" if ret.device.type == "cpu" else ret.device.index
+        self.assertEqual(ret_device, z_from)
         self.assertEqual(ret, torch.ones(2).to(z_from))
 
         rpc.shutdown()
 
+    def test_device_map_cpu(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to="cpu",
+            device_map={"cpu" : "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_cpu_to_gpu_default(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=0,
+            device_map={"cpu" : 0},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_cpu_to_gpu_non_default(self):
+        self._test_device_maps_gpu(
+            x_from="cpu",
+            y_from="cpu",
+            z_to=1,
+            device_map={"cpu" : 1},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_device_map_gpu_to_cpu_default(self):
+        self._test_device_maps_gpu(
+            x_from=0,
+            y_from=0,
+            z_to="cpu",
+            device_map={0 : "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_device_map_gpu_to_cpu_non_default(self):
+        self._test_device_maps_gpu(
+            x_from=1,
+            y_from=1,
+            z_to="cpu",
+            device_map={1 : "cpu"},
+            fn=TensorPipeAgentCudaRpcTest._gpu_add_given_devices,
+        )
+
     @skip_if_lt_x_gpu(2)
     def test_device_map_gpu_default(self):
         self._test_device_maps_gpu(