Enable sharing meta tensors between processes (#129520)

Fixes #129436

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129520
Approved by: https://github.com/ezyang
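
For reference, a minimal usage sketch (not part of the patch itself) of what this enables, assuming the standard `torch.multiprocessing` spawn API; a meta tensor has no backing data, so only its metadata (size, stride, dtype, requires_grad) crosses the process boundary:

```python
import torch
import torch.multiprocessing as mp


def consumer(q):
    t = q.get()
    # The received tensor is still on the meta device with no backing data;
    # only its metadata was reconstructed in this process.
    assert t.device.type == "meta"
    assert t.shape == (4, 3) and t.dtype == torch.float32


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    p = ctx.Process(target=consumer, args=(q,))
    p.start()
    # Reducing a meta tensor for IPC now dispatches to rebuild_meta_tensor
    # (added below) instead of failing.
    q.put(torch.empty(4, 3, dtype=torch.float32, device="meta"))
    p.join()
```
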
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 336850b..4308295 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1132,6 +1132,7 @@
     "fd_id",
     "init_reductions",
     "rebuild_cuda_tensor",
+    "rebuild_meta_tensor",
     "rebuild_event",
     "rebuild_nested_tensor",
     "rebuild_sparse_coo_tensor",
diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py
index bcd6657..b12b54d 100644
--- a/test/test_multiprocessing.py
+++ b/test/test_multiprocessing.py
@@ -309,8 +309,9 @@
                 is_set = e.is_set()
 
             self.assertTrue(is_set)
-            self.assertTrue(data[0].eq(4).all())
-            self.assertTrue(data[1].eq(4).all())
+            if device != "meta":
+                self.assertTrue(data[0].eq(4).all())
+                self.assertTrue(data[1].eq(4).all())
 
             p.join(100)
             self.assertFalse(p.is_alive())
@@ -326,12 +327,18 @@
 
             t1 = q.get()
             t2 = q.get()
-            self.assertTrue(t1.eq(1).all())
+            if device == "meta":
+                self.assertEqual(t1.size(), t2.size())
+            else:
+                self.assertTrue(t1.eq(1).all())
             s1 = t1.storage()
             s2 = t2.storage()
             self.assertEqual(type(s1), type(s2))
             self.assertEqual(s1.data_ptr(), s1.data_ptr())
-            self.assertEqual(s1, s2)
+            if device == "meta":
+                self.assertEqual(s1.size(), s2.size())
+            else:
+                self.assertEqual(s1, s2)
 
             # We need to delete this tensors to allow producer (child process)
             # collect them properly
@@ -857,6 +864,22 @@
         self._test_empty_tensor_sharing(torch.float32, torch.device("cuda"))
         self._test_empty_tensor_sharing(torch.int64, torch.device("cuda"))
 
+    def test_empty_tensor_sharing_meta(self):
+        self._test_empty_tensor_sharing(torch.float32, torch.device("meta"))
+        self._test_empty_tensor_sharing(torch.int64, torch.device("meta"))
+
+    def test_tensor_sharing_meta(self):
+        dtype = torch.float32
+        device = torch.device("meta")
+        q = mp.Queue()
+        t = torch.tensor([1], dtype=dtype, device=device)
+        q.put(t)
+        out = q.get(timeout=1)
+        self.assertEqual(out, t)
+
+    def test_meta_simple(self):
+        self._test_sharing(mp.get_context("spawn"), "meta", torch.float)
+
     def _test_autograd_sharing(self, var, ctx=mp, is_parameter=False):
         device = "cuda" if var.is_cuda else "cpu"
 
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index cdb4e53..2b37b34 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1925,8 +1925,10 @@
             example_value = wrap_to_fake_tensor_and_record(
                 example_value, tx=tx, **kwargs
             )
-        if isinstance(example_value, torch.Tensor) and (
-            maybe_get_fake_mode(example_value) is not tx.fake_mode
+        if (
+            isinstance(example_value, torch.Tensor)
+            and example_value.device.type != "meta"
+            and (maybe_get_fake_mode(example_value) is not tx.fake_mode)
         ):
             raise InternalTorchDynamoError(
                 "`example_value` needs to be a `FakeTensor`"
diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py
index 727b4f3..0a6d3c8 100644
--- a/torch/multiprocessing/reductions.py
+++ b/torch/multiprocessing/reductions.py
@@ -120,6 +120,38 @@
     return t
 
 
+def rebuild_meta_tensor(
+    tensor_cls,
+    tensor_size,
+    tensor_stride,
+    tensor_offset,
+    dtype,
+    storage_size_bytes,
+    requires_grad,
+):
+    untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")
+
+    typed_storage = torch.TypedStorage(
+        wrap_storage=untyped_storage, dtype=dtype, _internal=True
+    )
+
+    t = torch._utils._rebuild_tensor(
+        typed_storage,
+        tensor_offset,
+        tensor_size,
+        tensor_stride,
+    )
+
+    if tensor_cls == torch.nn.parameter.Parameter:
+        # Integer tensors cannot require grad, so requires_grad must be passed
+        # to the Parameter constructor rather than assigned afterwards
+        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+    else:
+        t.requires_grad = requires_grad
+
+    return t
+
+
 def rebuild_cuda_tensor(
     tensor_cls,
     tensor_size,
@@ -344,6 +376,19 @@
                 event_sync_required,
             ),
         )
+    elif storage._untyped_storage.device.type == "meta":
+        return (
+            rebuild_meta_tensor,
+            (
+                type(tensor),
+                tensor.size(),
+                tensor.stride(),
+                tensor.storage_offset(),
+                tensor.dtype,
+                tensor.untyped_storage().size(),
+                tensor.requires_grad,
+            ),
+        )
 
     # _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
     metadata = (
@@ -554,6 +599,10 @@
         raise RuntimeError(
             "Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
         )
+    elif storage.device.type == "meta":
+        raise RuntimeError(
+            "Cannot pickle meta storage; try pickling a meta tensor instead"
+        )
     elif get_sharing_strategy() == "file_system":
         metadata = storage._share_filename_cpu_()
         cache_key = metadata[1]