Add guards for using named tensor with serialization and multiprocessing (#25345)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25345
Test Plan
- New tests [namedtensor ci]
Test Plan: Imported from OSS
Differential Revision: D17101486
Pulled By: zou3519
fbshipit-source-id: 58e803b042056ee6abab8551517f74078f2b81d5
diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py
index a0568d4..e352682 100644
--- a/test/test_namedtensor.py
+++ b/test/test_namedtensor.py
@@ -7,6 +7,9 @@
import torch
from torch import Tensor
import torch.nn.functional as F
+from multiprocessing.reduction import ForkingPickler
+import pickle
+import io
import sys
@@ -163,6 +166,23 @@
none_named_tensor = torch.zeros(2, 3).names_(None, None)
self.assertEqual(repr(none_named_tensor), expected)
+ def test_no_save_support(self):
+ named_tensor = torch.zeros(2, 3, names=('N', 'C'))
+ buf = io.BytesIO()
+ with self.assertRaisesRegex(RuntimeError, "NYI"):
+ torch.save(named_tensor, buf)
+
+ def test_no_pickle_support(self):
+ named_tensor = torch.zeros(2, 3, names=('N', 'C'))
+ with self.assertRaisesRegex(RuntimeError, "NYI"):
+ serialized = pickle.dumps(named_tensor)
+
+ def test_no_multiprocessing_support(self):
+ named_tensor = torch.zeros(2, 3, names=('N', 'C'))
+ buf = io.BytesIO()
+ with self.assertRaisesRegex(RuntimeError, "NYI"):
+ ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor)
+
def test_noncontig_contiguous(self):
# This type of contiguous is special-cased and therefore needs its own test
for device in torch.testing.get_all_device_types():
diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py
index e56004f..6273a05 100644
--- a/torch/multiprocessing/reductions.py
+++ b/torch/multiprocessing/reductions.py
@@ -1,5 +1,6 @@
import torch
import torch.utils.hooks
+from torch.namedtensor import _check_serializing_named_tensor
import os
import threading
import errno
@@ -137,6 +138,7 @@
"If you just want to transfer the data, call detach() on the tensor "
"before serializing (e.g., putting it on the queue).")
+ _check_serializing_named_tensor(tensor)
torch.utils.hooks.warn_if_has_hooks(tensor)
# Note [CUDA IPC and the caching allocator]
diff --git a/torch/namedtensor.py b/torch/namedtensor.py
index ff63e34..1ab41ed 100644
--- a/torch/namedtensor.py
+++ b/torch/namedtensor.py
@@ -14,6 +14,13 @@
'of our named tensors project.'.format(api_name))
+def _check_serializing_named_tensor(tensor):
+ if torch._C._BUILD_NAMEDTENSOR and tensor.has_names():
+ raise RuntimeError(
+ "NYI: Named tensors don't support serialization. Please drop "
+ "names before serialization and/or serialize them seperately.")
+
+
def _build_dim_map(tensor):
"""Returns a map of { dim: dim_name } where dim is a name if the dim is named
and the dim index otherwise."""
diff --git a/torch/tensor.py b/torch/tensor.py
index d4f90ff..a7724a4 100644
--- a/torch/tensor.py
+++ b/torch/tensor.py
@@ -1,7 +1,7 @@
import sys
import torch
import torch._C as _C
-from torch.namedtensor import _update_names
+from torch.namedtensor import _update_names, _check_serializing_named_tensor
from collections import OrderedDict
import torch.utils.hooks as hooks
import warnings
@@ -37,6 +37,7 @@
return new_tensor
def __reduce_ex__(self, proto):
+ _check_serializing_named_tensor(self)
# See Note [Don't serialize hooks]
torch.utils.hooks.warn_if_has_hooks(self)
if self.is_quantized: