improve the restore device test, and relax the assertion (#14734)
Summary:
Only compare the device index if device has it.
Test the tensor restore with some computation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14734
Reviewed By: dzhulgakov
Differential Revision: D13317949
Pulled By: houseroad
fbshipit-source-id: 26b2f2912a9bbc3b660a62283fb403ddab437e49
diff --git a/test/test_jit.py b/test/test_jit.py
index 4cc6bcd..b75b133 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -505,12 +505,20 @@
@unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
def test_restore_device_cuda(self):
- m = torch.jit.ScriptModule()
+ class MyModule(torch.jit.ScriptModule):
+ def __init__(self):
+ super(MyModule, self).__init__(False)
+ self.register_buffer('b0', torch.randn(1, 3))
+ self.p0 = nn.Parameter(torch.randn(2, 3))
+
+ @torch.jit.script_method
+ def forward(self, x):
+ return x + self.b0 + self.p0
+
+ m = MyModule()
+ m.cuda(torch.cuda.device_count() - 1)
cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)
- m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
- device=cuda_device_str))
- m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
- device=cuda_device_str))
+
self.assertTrue(m.p0.is_cuda)
self.assertTrue(m.b0.is_cuda)
@@ -533,6 +541,13 @@
self.assertEqual(str(m4.p0.device), 'cuda:0')
self.assertEqual(str(m4.b0.device), 'cuda:0')
+ # compute and compare the results
+ input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
+ origin_result = m(input)
+ self.assertEqual(origin_result, m2(input))
+ self.assertEqual(origin_result, m3(input.cpu()))
+ self.assertEqual(origin_result, m4(input.cuda(0)))
+
def test_typeas_trace_check(self):
a = torch.tensor([0.4], requires_grad=True)
b = torch.tensor([0.7], requires_grad=True)
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 20f4312..5cc139f 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -153,7 +153,16 @@
at::DeviceTypeName(device.type(), false));
}
}
- AT_ASSERT(storage_it->second.device() == device);
+ if (storage_it->second.device().type() != device.type() ||
+ (device.has_index() &&
+ storage_it->second.device().index() != device.index())) {
+ std::stringstream oss;
+ oss << "storage previously was specified with device "
+ << storage_it->second.device()
+ << "but now is specified with device "
+ << device << std::endl;
+ AT_ERROR(oss.str());
+ }
at::Tensor result;
if (device.type() == at::DeviceType::CPU) {