[TensorPipe] Use Descriptor::Tensor::sourceDevice in tensorpipe_agent. (#55821)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55821

Test Plan: CI

Reviewed By: lw

Differential Revision: D27661608

fbshipit-source-id: fd241f073d8928528a749758c7d0f570dfeb677b
diff --git a/test/cpp/rpc/test_tensorpipe_serialization.cpp b/test/cpp/rpc/test_tensorpipe_serialization.cpp
index 5ad4e91..e43f65d 100644
--- a/test/cpp/rpc/test_tensorpipe_serialization.cpp
+++ b/test/cpp/rpc/test_tensorpipe_serialization.cpp
@@ -45,6 +45,7 @@
   for (auto& tpTensor : sendingTpMessage.tensors) {
     tensorpipe::Descriptor::Tensor t;
     t.length = tpTensor.length;
+    t.sourceDevice = tpTensor.buffer.device();
     t.metadata = tpTensor.metadata;
     recvingTpDescriptor.tensors.push_back(std::move(t));
   }
diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp
index b83c480..bb937ab 100644
--- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp
+++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp
@@ -199,14 +199,15 @@
   for (size_t tensorIdx = 0; tensorIdx < numTensors; ++tensorIdx) {
     const tensorpipe::Descriptor::Tensor& tensor =
         tpDescriptor.tensors[tensorIdx];
-    if (tensor.buffer.deviceType() == tensorpipe::DeviceType::kCpu) {
+    if (tensor.sourceDevice.type == tensorpipe::kCpuDeviceType) {
       buffers.tensors.emplace_back(
           at::getCPUAllocator()->allocate(tensor.length));
       tensorpipe::CpuBuffer buffer;
       buffer.ptr = buffers.tensors.back().get();
       tpAllocation.tensors[tensorIdx].buffer = buffer;
 #ifdef USE_CUDA_NOT_ROCM
-    } else if (tensor.buffer.deviceType() == tensorpipe::DeviceType::kCuda) {
+    } else if (tensor.sourceDevice.type == tensorpipe::kCudaDeviceType) {
+      // TODO: This could be simply `tensor.targetDevice.value().index`.
       auto deviceIndex = std::stoi(tensor.metadata);
       auto stream = at::cuda::CUDAStream(ctx->getStream(deviceIndex));
       // CUDACachingAllocator will call recordStream accordingly on the current