Enable sharing meta tensors between processes (#129520)
Fixes #129436
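
A minimal sketch of what this enables (the shapes, dtype, and spawn context below are illustrative, not part of the change): a meta tensor can now be put on a torch.multiprocessing queue and received in another process. Only metadata travels (size, stride, storage offset, dtype, untyped storage size, requires_grad), since meta tensors carry no data.

    import torch
    import torch.multiprocessing as mp

    def consumer(q):
        t = q.get()
        # The rebuilt tensor is still on the meta device and has no real storage.
        assert t.device.type == "meta"
        assert t.size() == (2, 3)
        assert t.dtype == torch.float32

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        q = ctx.Queue()
        p = ctx.Process(target=consumer, args=(q,))
        p.start()
        q.put(torch.empty(2, 3, device="meta"))
        p.join()
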
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129520
Approved by: https://github.com/ezyang
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 336850b..4308295 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1132,6 +1132,7 @@
"fd_id",
"init_reductions",
"rebuild_cuda_tensor",
+ "rebuild_meta_tensor",
"rebuild_event",
"rebuild_nested_tensor",
"rebuild_sparse_coo_tensor",
diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py
index bcd6657..b12b54d 100644
--- a/test/test_multiprocessing.py
+++ b/test/test_multiprocessing.py
@@ -309,8 +309,9 @@
is_set = e.is_set()
self.assertTrue(is_set)
- self.assertTrue(data[0].eq(4).all())
- self.assertTrue(data[1].eq(4).all())
+ if device != "meta":
+ self.assertTrue(data[0].eq(4).all())
+ self.assertTrue(data[1].eq(4).all())
p.join(100)
self.assertFalse(p.is_alive())
@@ -326,12 +327,18 @@
t1 = q.get()
t2 = q.get()
- self.assertTrue(t1.eq(1).all())
+ if device == "meta":
+ self.assertEqual(t1.size(), t2.size())
+ else:
+ self.assertTrue(t1.eq(1).all())
s1 = t1.storage()
s2 = t2.storage()
self.assertEqual(type(s1), type(s2))
self.assertEqual(s1.data_ptr(), s1.data_ptr())
- self.assertEqual(s1, s2)
+ if device == "meta":
+ self.assertEqual(s1.size(), s2.size())
+ else:
+ self.assertEqual(s1, s2)
# We need to delete these tensors to allow the producer (child process)
# to collect them properly
@@ -857,6 +864,22 @@
self._test_empty_tensor_sharing(torch.float32, torch.device("cuda"))
self._test_empty_tensor_sharing(torch.int64, torch.device("cuda"))
+ def test_empty_tensor_sharing_meta(self):
+ self._test_empty_tensor_sharing(torch.float32, torch.device("meta"))
+ self._test_empty_tensor_sharing(torch.int64, torch.device("meta"))
+
+ def test_tensor_sharing_meta(self):
+ dtype = torch.float32
+ device = torch.device("meta")
+ q = mp.Queue()
+ empty = torch.tensor([1], dtype=dtype, device=device)
+ q.put(empty)
+ out = q.get(timeout=1)
+ self.assertEqual(out, empty)
+
+ def test_meta_simple(self):
+ self._test_sharing(mp.get_context("spawn"), "meta", torch.float)
+
def _test_autograd_sharing(self, var, ctx=mp, is_parameter=False):
device = "cuda" if var.is_cuda else "cpu"
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index cdb4e53..2b37b34 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -1925,8 +1925,10 @@
example_value = wrap_to_fake_tensor_and_record(
example_value, tx=tx, **kwargs
)
- if isinstance(example_value, torch.Tensor) and (
- maybe_get_fake_mode(example_value) is not tx.fake_mode
+ if (
+ isinstance(example_value, torch.Tensor)
+ and example_value.device.type != "meta"
+ and (maybe_get_fake_mode(example_value) is not tx.fake_mode)
):
raise InternalTorchDynamoError(
"`example_value` needs to be a `FakeTensor`"
diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py
index 727b4f3..0a6d3c8 100644
--- a/torch/multiprocessing/reductions.py
+++ b/torch/multiprocessing/reductions.py
@@ -120,6 +120,38 @@
return t
+def rebuild_meta_tensor(
+ tensor_cls,
+ tensor_size,
+ tensor_stride,
+ tensor_offset,
+ dtype,
+ storage_size_bytes,
+ requires_grad,
+):
+ untyped_storage = torch.UntypedStorage(storage_size_bytes, device="meta")
+
+ typed_storage = torch.TypedStorage(
+ wrap_storage=untyped_storage, dtype=dtype, _internal=True
+ )
+
+ t = torch._utils._rebuild_tensor(
+ typed_storage,
+ tensor_offset,
+ tensor_size,
+ tensor_stride,
+ )
+
+ if tensor_cls == torch.nn.parameter.Parameter:
+ # It is crucial for integer tensors to receive
+ # requires_grad=False as an argument in the constructor
+ t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+ else:
+ t.requires_grad = requires_grad
+
+ return t
+
+
def rebuild_cuda_tensor(
tensor_cls,
tensor_size,
@@ -344,6 +376,19 @@
event_sync_required,
),
)
+ elif storage._untyped_storage.device.type == "meta":
+ return (
+ rebuild_meta_tensor,
+ (
+ type(tensor),
+ tensor.size(),
+ tensor.stride(),
+ tensor.storage_offset(),
+ tensor.dtype,
+ tensor.untyped_storage().size(),
+ tensor.requires_grad,
+ ),
+ )
# _backward_hooks purposely omitted here, see Note [Don't serialize hooks]
metadata = (
@@ -554,6 +599,10 @@
raise RuntimeError(
"Cannot pickle CUDA storage; try pickling a CUDA tensor instead"
)
+ elif storage.device.type == "meta":
+ raise RuntimeError(
+ "Cannot pickle meta storage; try pickling a meta tensor instead"
+ )
elif get_sharing_strategy() == "file_system":
metadata = storage._share_filename_cpu_()
cache_key = metadata[1]
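
For reference, a hedged sketch of the round trip implemented above, driven directly in a single process (assumes this patch is applied; reduce_tensor and rebuild_meta_tensor are the functions touched in torch/multiprocessing/reductions.py):

    import torch
    from torch.multiprocessing.reductions import reduce_tensor

    t = torch.empty(4, 5, dtype=torch.float16, device="meta")
    # With this change, reduce_tensor takes the new meta branch and returns
    # rebuild_meta_tensor plus the tensor's metadata (no storage handle is shared).
    rebuild_fn, args = reduce_tensor(t)
    rebuilt = rebuild_fn(*args)
    assert rebuilt.device.type == "meta"
    assert rebuilt.size() == t.size() and rebuilt.stride() == t.stride()
    assert rebuilt.dtype == t.dtype and rebuilt.requires_grad == t.requires_grad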