Fix tensor deserialization of JIT script modules on CUDA (#16279)
Summary:
Previously, when restoring a script module onto CUDA, each tensor record was wrapped in a temporary CPU tensor using that tensor's offset, dims, and strides, so only the view was copied to the device and tensors sharing one storage ended up in separate CUDA storages. Now we create a temporary tensor for the whole record, so the full storage is moved to CUDA once and shared-storage relationships are preserved.
Fix https://github.com/pytorch/pytorch/issues/15271
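A minimal sketch (not part of this patch) of the behavior the change restores; the file name m.pt is illustrative and a CUDA device is assumed. Two views of one CPU storage should still share a single storage after the module is reloaded on CUDA:

    import torch
    import torch.nn as nn

    whole_tensor = torch.randn(4, 5)                    # one CPU storage backing several views
    m = torch.jit.ScriptModule()
    m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))   # view of row 0
    m.register_buffer('b0', whole_tensor.narrow(0, 3, 1))  # view of row 3

    m.save('m.pt')
    m2 = torch.jit.load('m.pt', map_location=torch.device('cuda:0'))

    # Before this fix each view was copied into its own CUDA storage; with it,
    # both restored tensors are views of the same CUDA storage, as on CPU.
    assert m2.p0.storage().data_ptr() == m2.b0.storage().data_ptr()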
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16279
Reviewed By: BIT-silence
Differential Revision: D13791442
Pulled By: houseroad
fbshipit-source-id: 6f52ca09627fb684f74121357cc42e4adadec36a
diff --git a/test/test_jit.py b/test/test_jit.py
index 6e0ad79..cd95236 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -648,6 +648,21 @@
         self.assertEqual(origin_result, m3(input.cpu()))
         self.assertEqual(origin_result, m4(input.cuda(0)))

+    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
+    def test_restore_shared_storage_on_cuda(self):
+        whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
+        m = torch.jit.ScriptModule()
+        m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
+        m.register_buffer('b0', whole_tensor.narrow(0, 3, 1))
+        m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
+        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
+        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
+        self.assertTrue(m2.p0.is_cuda)
+        self.assertTrue(m2.b0.is_cuda)
+        self.assertTrue(m2.p0.is_shared())
+        self.assertTrue(m2.b0.is_shared())
+        self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
+
     def test_typeas_trace_check(self):
         a = torch.tensor([0.4], requires_grad=True)
         b = torch.tensor([0.7], requires_grad=True)
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 5d129f2..729babc 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -157,7 +157,7 @@
     } else if (device.type() == at::DeviceType::CUDA) {
       at::Tensor cpu_tensor =
           at::empty({0}, at::CPU(type).options())
-              .set_(cpu_storage, tensor_proto.offset(), dims, strides);
+              .set_(cpu_storage);
       at::Storage cuda_storage =
           cpu_tensor.to(device, cpu_tensor.scalar_type()).storage();
       storage_it =
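
As an aside (not part of the patch), a rough Python analogue of the import.cpp change above, assuming a CUDA device: the old path wrapped only one view of the record before moving it to the device, so each view got its own CUDA storage; the new path wraps the whole record, so one CUDA storage can back every tensor that references it.

    import torch

    storage = torch.randn(4, 5).storage()  # stands in for one tensor record from the archive

    # Old behavior: wrap just a (1, 5) view at offset 15, then move it; the copy
    # holds only the view's 5 elements, so sharing between views is lost.
    view_only = torch.empty(0).set_(storage, 15, (1, 5), (5, 1)).to('cuda')

    # New behavior: wrap the whole record and move it once; every tensor that
    # references this record can be rebuilt as a view into the same CUDA storage.
    whole = torch.empty(0).set_(storage).to('cuda')

    print(view_only.storage().size())  # 5  -> a private CUDA storage per view
    print(whole.storage().size())      # 20 -> one CUDA storage for the whole record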