Improve the restore device test and relax the device assertion (#14734)

Summary:
Only compare the device index if the device has one.

Test tensor restore by running some computation with the restored modules.
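
For illustration, the relaxed device comparison roughly corresponds to the
Python-level check below (a sketch only: `compatible`, `stored`, and
`requested` are hypothetical names; the real change lives in the C++
deserializer in import.cpp):

    import torch

    def compatible(stored: torch.device, requested: torch.device) -> bool:
        # Device types must always match (cpu vs. cuda).
        if stored.type != requested.type:
            return False
        # Only compare the index when the requested device specifies one;
        # torch.device('cuda') has index None and should match any cuda:N.
        if requested.index is None:
            return True
        return stored.index == requested.index

    assert compatible(torch.device('cuda:1'), torch.device('cuda'))        # no index: type match suffices
    assert compatible(torch.device('cuda:1'), torch.device('cuda:1'))      # indices match
    assert not compatible(torch.device('cuda:1'), torch.device('cuda:0'))  # index mismatch
    assert not compatible(torch.device('cuda:1'), torch.device('cpu'))     # type mismatch
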
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14734

Reviewed By: dzhulgakov

Differential Revision: D13317949

Pulled By: houseroad

fbshipit-source-id: 26b2f2912a9bbc3b660a62283fb403ddab437e49
diff --git a/test/test_jit.py b/test/test_jit.py
index 4cc6bcd..b75b133 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -505,12 +505,20 @@
 
     @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
     def test_restore_device_cuda(self):
-        m = torch.jit.ScriptModule()
+        class MyModule(torch.jit.ScriptModule):
+            def __init__(self):
+                super(MyModule, self).__init__(False)
+                self.register_buffer('b0', torch.randn(1, 3))
+                self.p0 = nn.Parameter(torch.randn(2, 3))
+
+            @torch.jit.script_method
+            def forward(self, x):
+                return x + self.b0 + self.p0
+
+        m = MyModule()
+        m.cuda(torch.cuda.device_count() - 1)
         cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)
-        m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
-                                         device=cuda_device_str))
-        m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
-                                             device=cuda_device_str))
+
         self.assertTrue(m.p0.is_cuda)
         self.assertTrue(m.b0.is_cuda)
 
@@ -533,6 +541,13 @@
         self.assertEqual(str(m4.p0.device), 'cuda:0')
         self.assertEqual(str(m4.b0.device), 'cuda:0')
 
+        # compute and compare the results
+        input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
+        origin_result = m(input)
+        self.assertEqual(origin_result, m2(input))
+        self.assertEqual(origin_result, m3(input.cpu()))
+        self.assertEqual(origin_result, m4(input.cuda(0)))
+
     def test_typeas_trace_check(self):
         a = torch.tensor([0.4], requires_grad=True)
         b = torch.tensor([0.7], requires_grad=True)
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 20f4312..5cc139f 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -153,7 +153,16 @@
           at::DeviceTypeName(device.type(), false));
     }
   }
-  AT_ASSERT(storage_it->second.device() == device);
+  if (storage_it->second.device().type() != device.type() ||
+      (device.has_index() &&
+       storage_it->second.device().index() != device.index())) {
+    std::stringstream oss;
+    oss << "storage previously was specified with device "
+      << storage_it->second.device()
+      << "but now is specified with device "
+      << device << std::endl;
+    AT_ERROR(oss.str());
+  }
 
   at::Tensor result;
   if (device.type() == at::DeviceType::CPU) {