Virtualize `<type>Storage` classes (#66970)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/66228
cc ezyang bhosmer smessmer ljk53 bdhirsh
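
A brief, illustrative sketch of the user-visible behavior after this change
(outputs are shown as comments; exact repr text may differ):

    import torch

    t = torch.tensor([1.0, 2.0, 3.0])
    s = t.storage()

    type(s)                            # torch._TypedStorage
    s.dtype                            # torch.float32
    isinstance(s, torch.FloatStorage)  # True, via _LegacyStorageMeta.__instancecheck__
    t.storage_type()                   # torch.FloatStorage (the legacy class)
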
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66970
Reviewed By: bdhirsh
Differential Revision: D33245612
Pulled By: ezyang
fbshipit-source-id: 4c61c2cb029e2b94b0e68927c377d3e1c358dd7c
(cherry picked from commit d29fcdfb4bc2cc17b1795d4349e4b56fa0d1cf12)
diff --git a/docs/source/storage.rst b/docs/source/storage.rst
index 3aeec08..747acf1 100644
--- a/docs/source/storage.rst
+++ b/docs/source/storage.rst
@@ -1,87 +1,96 @@
torch.Storage
===================================
-A :class:`torch.Storage` is a contiguous, one-dimensional array of a single
-data type.
+A :class:`torch._TypedStorage` is a contiguous, one-dimensional array of
+elements of a particular :class:`torch.dtype`. It can be given any
+:class:`torch.dtype`, and the internal data will be interpreted appropriately.
-Every :class:`torch.Tensor` has a corresponding storage of the same data type.
+Every strided :class:`torch.Tensor` contains a :class:`torch._TypedStorage`,
+which stores all of the data that the :class:`torch.Tensor` views.
+
+For backward compatibility, there are also :class:`torch.<type>Storage` classes
+(like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc.). Instances
+of these classes are never actually created; calling their constructors returns
+a :class:`torch._TypedStorage` with the appropriate :class:`torch.dtype`.
+:class:`torch.<type>Storage` classes have all of the same class methods that
+:class:`torch._TypedStorage` has.
+
+Also for backward compatibility, :class:`torch.Storage` is an alias for the
+storage class that corresponds with the default data type
+(:func:`torch.get_default_dtype()`). For instance, if the default data type is
+:attr:`torch.float`, :class:`torch.Storage` resolves to
+:class:`torch.FloatStorage`.
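+
+For example (an illustrative snippet; exact output formatting may vary)::
+
+    >>> s = torch._TypedStorage([1, 2, 3, 4], dtype=torch.int32)
+    >>> s.dtype
+    torch.int32
+    >>> s.device
+    device(type='cpu')
+    >>> isinstance(torch.IntStorage([1, 2, 3, 4]), torch._TypedStorage)
+    True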
+
+
+.. autoclass:: torch._TypedStorage
+ :members:
+ :undoc-members:
+ :inherited-members:
.. autoclass:: torch.DoubleStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.FloatStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.HalfStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.LongStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.IntStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.ShortStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.CharStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.ByteStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.BoolStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.BFloat16Storage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.ComplexDoubleStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.ComplexFloatStorage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.QUInt8Storage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.QInt8Storage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.QInt32Storage
:members:
:undoc-members:
- :inherited-members:
.. autoclass:: torch.QUInt4x2Storage
:members:
:undoc-members:
- :inherited-members:
+
+.. autoclass:: torch.QUInt2x4Storage
+ :members:
+ :undoc-members:
diff --git a/test/test_torch.py b/test/test_torch.py
index 1917c9a..2ab2b69 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -221,6 +221,16 @@
@onlyNativeDeviceTypes
@dtypes(*get_all_dtypes())
+ def test_tensor_storage_type(self, device, dtype):
+ a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
+
+ module = torch.cuda if (torch.device(device).type == 'cuda') else torch
+ expected_storage_type = getattr(module, torch.storage._dtype_to_storage_type_map()[dtype])
+
+ self.assertEqual(a.storage_type(), expected_storage_type)
+
+ @onlyNativeDeviceTypes
+ @dtypes(*get_all_dtypes())
def test_tensor_from_storage(self, device, dtype):
a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
a_s = a.storage()
@@ -6171,6 +6181,7 @@
self.assertEqual(bools.size(), 8)
self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True])
self.assertEqual(bools.type(), 'torch.BoolStorage')
+ self.assertTrue(isinstance(bools, torch.BoolStorage))
f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
bools = torch.BoolStorage.from_buffer(f, 'big')
@@ -6183,6 +6194,122 @@
bytes = torch.ByteStorage.from_buffer(a)
self.assertEqual(bytes.nbytes(), 4)
self.assertEqual(bytes.tolist(), [1, 2, 3, 4])
+ self.assertTrue(isinstance(bytes, torch.ByteStorage))
+
+ def test_storage_error(self):
+ quantized_storages = [
+ torch.QInt32Storage,
+ torch.QInt8Storage,
+ torch.QUInt2x4Storage,
+ torch.QUInt4x2Storage,
+ torch.QUInt8Storage,
+ ]
+
+ with self.assertRaisesRegex(RuntimeError, r"Only child classes of _LegacyStorage can be instantiated"):
+ torch.storage._LegacyStorage()
+
+ for storage_class in torch._storage_classes:
+ if storage_class in [torch._UntypedStorage, torch.cuda._UntypedStorage, torch._TypedStorage]:
+ continue
+
+ device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
+ dtype = storage_class.dtype
+
+ if device == 'cuda' and not torch.cuda.is_available():
+ continue
+
+ # Legacy <type>Storage constructor errors
+ with self.assertRaisesRegex(RuntimeError, r"'device' cannot be specified"):
+ storage_class(device='cpu')
+
+ with self.assertRaisesRegex(RuntimeError, r"'dtype' cannot be specified"):
+ storage_class(dtype=torch.float)
+
+ with self.assertRaisesRegex(TypeError, r"got an unexpected keyword"):
+ storage_class(sdlkjf=torch.float)
+
+ with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
+ storage_class(0, 0)
+
+ with self.assertRaisesRegex(TypeError, r"invalid data type"):
+ storage_class('string')
+
+ with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
+ storage_class(torch.tensor([]))
+
+ s = storage_class()
+
+ with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
+ storage_class(0, wrap_storage=s._untyped())
+
+ with self.assertRaisesRegex(TypeError, r"must be _UntypedStorage"):
+ storage_class(wrap_storage=s)
+
+ if torch.cuda.is_available():
+ if storage_class in quantized_storages:
+ with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
+ s.cuda()
+
+ else:
+
+ if s.is_cuda:
+ s_other_device = s.cpu()
+ else:
+ s_other_device = s.cuda()
+
+ with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
+ storage_class(wrap_storage=s_other_device._untyped())
+
+ # _TypedStorage constructor errors
+ with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
+ torch._TypedStorage(0, wrap_storage=s._untyped(), dtype=dtype)
+
+ with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
+ torch._TypedStorage(wrap_storage=s._untyped())
+
+ with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
+ torch._TypedStorage(wrap_storage=s._untyped(), dtype=0)
+
+ with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
+ torch._TypedStorage(wrap_storage=s._untyped(), dtype=dtype, device=device)
+
+ with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be _UntypedStorage"):
+ torch._TypedStorage(wrap_storage=s, dtype=dtype)
+
+ with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
+ torch._TypedStorage(dtype=dtype, device='xla')
+
+ if torch.cuda.is_available():
+ if storage_class in quantized_storages:
+ with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
+ torch._TypedStorage(dtype=dtype, device='cuda')
+
+ with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
+ torch._TypedStorage(torch.tensor([]), dtype=dtype, device=device)
+
+ with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
+ torch._TypedStorage(0, 0, dtype=dtype, device=device)
+
+ def test_storage_error_no_attribute(self):
+ storage_classes = [
+ torch.cuda.ByteStorage,
+ torch.cuda.FloatStorage,
+ torch.cuda._UntypedStorage,
+ ]
+ for storage_class in storage_classes:
+ with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
+ storage_class.from_buffer()
+
+ if storage_class == torch.cuda._UntypedStorage:
+ with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
+ storage_class._new_with_weak_ptr()
+
+ else:
+ with self.assertRaisesRegex(AttributeError, r'has no attribute'):
+ storage_class._new_with_weak_ptr()
+
+ with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
+ storage_class._new_shared_filename(0, 0, 0)
def test_storage_casts(self):
storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
diff --git a/torch/__init__.py b/torch/__init__.py
index ac47103..cf7f592 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -39,6 +39,7 @@
'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
+ '_TypedStorage',
'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
'lobpcg', 'use_deterministic_algorithms',
@@ -594,7 +595,7 @@
################################################################################
from ._tensor import Tensor
-from .storage import _StorageBase, _TypedStorage
+from .storage import _StorageBase, _TypedStorage, _LegacyStorage
# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage._TypedStorage directly.
@@ -602,87 +603,87 @@
class _UntypedStorage(_C.ByteStorageBase, _StorageBase):
pass
-class ByteStorage(_TypedStorage):
+class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.uint8
-class DoubleStorage(_TypedStorage):
+class DoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.double
-class FloatStorage(_TypedStorage):
+class FloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.float
-class HalfStorage(_TypedStorage):
+class HalfStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.half
-class LongStorage(_TypedStorage):
+class LongStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.long
-class IntStorage(_TypedStorage):
+class IntStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.int
-class ShortStorage(_TypedStorage):
+class ShortStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.short
-class CharStorage(_TypedStorage):
+class CharStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.int8
-class BoolStorage(_TypedStorage):
+class BoolStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.bool
-class BFloat16Storage(_TypedStorage):
+class BFloat16Storage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.bfloat16
-class ComplexDoubleStorage(_TypedStorage):
+class ComplexDoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.cdouble
-class ComplexFloatStorage(_TypedStorage):
+class ComplexFloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.cfloat
-class QUInt8Storage(_TypedStorage):
+class QUInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.quint8
-class QInt8Storage(_TypedStorage):
+class QInt8Storage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.qint8
-class QInt32Storage(_TypedStorage):
+class QInt32Storage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.qint32
-class QUInt4x2Storage(_TypedStorage):
+class QUInt4x2Storage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.quint4x2
-class QUInt2x4Storage(_TypedStorage):
+class QUInt2x4Storage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.quint2x4
@@ -692,6 +693,7 @@
ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
+ _TypedStorage
}
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 4d690c0..99a604a 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -202,11 +202,7 @@
if self.dtype not in torch.storage._dtype_to_storage_type_map():
raise RuntimeError(f'unsupported Storage type: {self.dtype}')
- storage = self._storage()
- storage_name = torch.storage._dtype_to_storage_type_map()[self.dtype]
- storage_class = eval(type(storage).__module__ + '.' + storage_name)
- storage = storage_class(wrap_storage=storage)
- return storage
+ return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
def _reduce_ex_internal(self, proto):
check_serializing_named_tensor(self)
@@ -866,10 +862,7 @@
Returns the type of the underlying storage.
"""
- # NB: this returns old fashioned _TypedStorage, e.g., FloatStorage, as it
- # would be pretty pointless otherwise (it would always return
- # _UntypedStorage)
- return type(self.storage())
+ return self.storage()._get_legacy_storage_class()
def refine_names(self, *names):
r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index ac7026e..ab7338f 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -674,67 +674,77 @@
__new__ = _lazy_new
-from torch.storage import _TypedStorage
+from torch.storage import _TypedStorage, _LegacyStorage
class _UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
- pass
+ @classmethod
+ def from_buffer(cls, *args, **kwargs):
+ raise RuntimeError('from_buffer: Not available for CUDA storage')
-class ByteStorage(_TypedStorage):
+ @classmethod
+ def _new_with_weak_ptr(cls, *args, **kwargs):
+ raise RuntimeError('_new_with_weak_ptr: Not available for CUDA storage')
+
+ @classmethod
+ def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
+ raise RuntimeError('_new_shared_filename: Not available for CUDA storage')
+
+class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.uint8
-class DoubleStorage(_TypedStorage):
+class DoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.double
-class FloatStorage(_TypedStorage):
+class FloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.float
-class HalfStorage(_TypedStorage):
+class HalfStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.half
-class LongStorage(_TypedStorage):
+class LongStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.long
-class IntStorage(_TypedStorage):
+class IntStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.int
-class ShortStorage(_TypedStorage):
+class ShortStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.short
-class CharStorage(_TypedStorage):
+class CharStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.int8
-class BoolStorage(_TypedStorage):
+class BoolStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.bool
-class BFloat16Storage(_TypedStorage):
+class BFloat16Storage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.bfloat16
-class ComplexDoubleStorage(_TypedStorage):
+class ComplexDoubleStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.cdouble
-class ComplexFloatStorage(_TypedStorage):
+class ComplexFloatStorage(_LegacyStorage):
@classproperty
def dtype(self):
return torch.cfloat
diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py
index 2da1ab8..4a5d725 100644
--- a/torch/multiprocessing/reductions.py
+++ b/torch/multiprocessing/reductions.py
@@ -6,6 +6,8 @@
import multiprocessing
from multiprocessing.util import register_after_fork
from multiprocessing.reduction import ForkingPickler
+from typing import Union
+
try:
# Early load resource_sharer to prevent a partially initialized instance
# from being inherited in a forked child process. The reduce_storage method
@@ -103,7 +105,7 @@
requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required):
# If storage_handle is None, storage points to nullptr.
if storage_handle is None or storage_size_bytes == 0:
- storage = storage_cls(0)
+ storage = storage_cls(0, dtype=dtype, device=storage_device)
else:
storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes))
if storage is None:
@@ -120,7 +122,7 @@
shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
else:
# We already ref counting this Storage, but producer needs new ref-counters to be released.
- storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
+ storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
t = torch._utils._rebuild_tensor(
torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
@@ -288,7 +290,7 @@
storage_ref = shared_cache.get(key)
if storage_ref is None:
return None
- return cls._new_with_weak_ptr(storage_ref.cdata)
+ return torch._UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
def rebuild_storage_fd(cls, df, size):
@@ -304,11 +306,18 @@
os.close(fd)
-def rebuild_storage_filename(cls, manager, handle, size):
- storage = storage_from_cache(cls, handle)
+def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
+ storage: Union[torch._TypedStorage, torch._UntypedStorage] = storage_from_cache(cls, handle)
if storage is not None:
return storage._shared_decref()
- storage = cls._new_shared_filename(manager, handle, size)
+ if dtype is None:
+ storage = torch._UntypedStorage._new_shared_filename(manager, handle, size)
+ else:
+ byte_size = size * torch._utils._element_size(dtype)
+ untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename(manager, handle, byte_size)
+ storage = torch._TypedStorage(
+ wrap_storage=untyped_storage,
+ dtype=dtype)
shared_cache[handle] = StorageWeakRef(storage)
return storage._shared_decref()
@@ -338,6 +347,8 @@
metadata = storage._share_filename_()
cache_key = metadata[1]
rebuild = rebuild_storage_filename
+ if isinstance(storage, torch._TypedStorage):
+ metadata += (storage.dtype,)
storage._shared_incref()
elif storage.size() == 0:
# This is special cased because Empty tensors
diff --git a/torch/storage.py b/torch/storage.py
index 54e8df5..2537d7a 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -7,6 +7,11 @@
import copy
import collections
from functools import lru_cache
+try:
+ import numpy as np
+ HAS_NUMPY = True
+except ModuleNotFoundError:
+ np = None # type: ignore[assignment]
T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
class _StorageBase(object):
@@ -38,6 +43,15 @@
def _new_using_filename(cls: Type[T], size: int) -> T: ... # noqa: E704
@classmethod
def _new_using_fd(cls: Type[T], size: int) -> T: ... # noqa: E704
+ @classmethod
+ def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
+ @classmethod
+ def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
+ @classmethod
+ def _release_ipc_counter(cls, *args, **kwargs) -> T: ... # noqa: E704
+ @classmethod
+ def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
+ def _shared_decref(self) -> T: ... # noqa: E704
def __str__(self):
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
@@ -83,6 +97,8 @@
return _type(self, getattr(torch, self.__class__.__name__))
def _to(self, dtype):
+ if not isinstance(dtype, torch.dtype):
+ raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage()
if storage.data_ptr() == self.data_ptr():
storage = storage.clone()
@@ -187,6 +203,11 @@
@lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
+ # NOTE: We should no longer add dtypes to this map. This map
+ # is only used for BC/FC with older PyTorch versions. Going forward,
+ # new dtypes of _TypedStorage should not translate to a legacy
+ # <type>Storage class. Instead, new dtypes of _TypedStorage should
+ # be serialized as an _UntypedStorage paired with a torch.dtype
return {
torch.double: 'DoubleStorage',
torch.float: 'FloatStorage',
@@ -213,104 +234,183 @@
val: key for key, val in _dtype_to_storage_type_map().items()}
return dtype_map
+def _get_storage_from_sequence(sequence, dtype, device):
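+    # Quantized dtypes cannot be constructed directly from a Python sequence
+    # via torch.tensor, so the data is interpreted through the corresponding
+    # integer dtype before the untyped storage is extracted.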
+ if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
+ interpret_dtypes = {
+ torch.quint8: torch.uint8,
+ torch.quint4x2: torch.uint8,
+ torch.quint2x4: torch.uint8,
+ torch.qint32: torch.int32,
+ torch.qint8: torch.int8
+ }
+ tmp_tensor = torch.tensor(
+ sequence,
+ dtype=interpret_dtypes[dtype],
+ device=device)
+
+ else:
+ tmp_tensor = torch.tensor(
+ sequence,
+ dtype=dtype,
+ device=device)
+
+ return tmp_tensor.storage()._untyped()
+
+def _isint(x):
+ if HAS_NUMPY:
+ return isinstance(x, (int, np.integer))
+ else:
+ return isinstance(x, int)
+
class _TypedStorage:
is_sparse = False
+ dtype: torch.dtype
+
def fill_(self, value):
self[0:len(self)] = value
return self
- def __init__(self, *args, **kwargs):
- arg_error_msg = (
- f'{type(self)} constructor received an invalid combination '
- f'of arguments - got args={tuple(type(arg) for arg in args)}, '
- f'kwargs={ {key: type(val) for key, val in kwargs.items()} }, but '
- 'expected one of:\n'
- ' * no arguments\n'
- ' * (int size)\n'
- ' * (Sequence data)\n')
- if type(self) == _TypedStorage:
- arg_error_msg += ' * (wrap_storage=<_UntypedStorage>, dtype=<torch.dtype>)'
+ def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
+ if cls == torch.storage._LegacyStorage:
+ raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
+
+ if cls == _TypedStorage:
+ return super().__new__(cls)
+
else:
- arg_error_msg += ' * (wrap_storage=<_UntypedStorage>)'
+ arg_error_msg = (
+ f'{cls}.__new__ received an invalid combination '
+ f'of arguments. Expected one of:\n'
+ ' * no arguments\n'
+ ' * (int size)\n'
+ ' * (Sequence data)\n'
+ ' * (*, _UntypedStorage wrap_storage)')
- if 'wrap_storage' in kwargs:
- assert len(args) == 0, (
- "No positional arguments should be given when using "
- "'wrap_storage'")
+ if device is not None:
+ raise RuntimeError(
+ arg_error_msg +
+ "\nKeyword argument 'device' cannot be specified")
- if type(self) == _TypedStorage:
- assert 'dtype' in kwargs, (
- "When using 'wrap_storage', 'dtype' also must be specified")
- assert len(kwargs) == 2, (
- "Only 'wrap_storage' and 'dtype' should be given, but got: "
- f"{kwargs}")
- dtype = kwargs['dtype']
- assert isinstance(dtype, torch.dtype)
- self.dtype = dtype
+ if dtype is not None:
+ raise RuntimeError(
+ arg_error_msg +
+ "\nKeyword argument 'dtype' cannot be specified")
+
+ if wrap_storage is None:
+ if len(args) > 1:
+ raise RuntimeError(
+ arg_error_msg +
+ "\nToo many positional arguments")
+
+ if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
+ raise TypeError(
+ arg_error_msg +
+ f"\nArgument type not recognized: {type(args[0])}")
+
+ return _TypedStorage(
+ *args,
+ dtype=cls.dtype,
+ device='cuda' if eval(cls.__module__) is torch.cuda else 'cpu')
else:
- assert hasattr(self, 'dtype')
- assert len(kwargs) == 1, (
- f"Only 'wrap_storage' should be given, but got: {kwargs.keys()}")
- dtype = self.dtype
+ if len(args) != 0:
+ raise RuntimeError(
+ arg_error_msg +
+ "\nNo positional arguments should be given when using "
+ "'wrap_storage'")
- storage = kwargs['wrap_storage']
+ if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
+ raise TypeError(
+ arg_error_msg +
+ f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
- if not isinstance(storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
- raise TypeError(arg_error_msg)
- if type(self) != _TypedStorage and storage.__module__ != self.__module__:
- raise TypeError((
+ cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+
+ if wrap_storage.device.type != cls_device:
+ raise RuntimeError(
+ arg_error_msg +
+ f"\nDevice of 'wrap_storage' must be {cls_device}"
+ f", but got {wrap_storage.device.type}")
+
+ return _TypedStorage(
+ *args,
+ wrap_storage=wrap_storage,
+ dtype=cls.dtype)
+
+ def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
+ arg_error_msg = (
+ '_TypedStorage.__init__ received an invalid combination '
+ 'of arguments. Expected one of:\n'
+ ' * (*, torch.device device, torch.dtype dtype)\n'
+ ' * (int size, *, torch.device device, torch.dtype dtype)\n'
+ ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
+ ' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)')
+
+ if wrap_storage is not None:
+ if len(args) != 0:
+ raise RuntimeError(
arg_error_msg +
- f'\n`storage` `module {storage.__module__}` does not match '
- f'module of {type(self)}'))
- self._storage = storage
+ "\nNo positional arguments should be given when using "
+ "'wrap_storage'")
+
+ if dtype is None:
+ raise RuntimeError(
+ arg_error_msg +
+ "\nArgument 'dtype' must be specified")
+
+ if not isinstance(dtype, torch.dtype):
+ raise TypeError(
+ arg_error_msg +
+ f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")
+
+ if device is not None:
+ raise RuntimeError(
+ arg_error_msg +
+ "\nArgument 'device' should not be specified when 'wrap_storage' is given")
+
+ self.dtype = dtype
+
+ if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
+ raise TypeError(
+ arg_error_msg +
+ f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
+
+ self._storage = wrap_storage
else:
- assert type(self) != _TypedStorage, (
- "Calling __init__ this way is only supported in _TypedStorage's "
- "child classes. _TypedStorage can only be directly instantiated "
- "when kwargs 'wrap_storage' and 'dtype' are given.")
+ self.dtype = torch.get_default_dtype() if dtype is None else dtype
+ device = torch.device('cpu' if device is None else device)
- assert len(kwargs) == 0, "invalid keyword arguments"
+ if device.type == 'cpu':
+ untyped_storage_class = torch._UntypedStorage
+ elif device.type == 'cuda':
+ untyped_storage_class = torch.cuda._UntypedStorage
+ else:
+ raise RuntimeError(f"Storage device not recognized: {device}")
- def isint(x):
- try:
- int(x)
- except TypeError:
- return False
- return True
+ if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
+ if device.type == 'cuda':
+ raise RuntimeError("Cannot create CUDA storage with quantized dtype")
if len(args) == 0:
- self._storage = eval(self.__module__)._UntypedStorage()
+ self._storage = untyped_storage_class()
- elif len(args) == 1 and isint(args[0]):
- self._storage = eval(self.__module__)._UntypedStorage(int(args[0]) * self.element_size())
-
- elif len(args) == 1 and isinstance(args[0], collections.abc.Sequence):
- if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
- interpret_dtypes = {
- torch.quint8: torch.uint8,
- torch.quint4x2: torch.uint8,
- torch.quint2x4: torch.uint8,
- torch.qint32: torch.int32,
- torch.qint8: torch.int8
- }
- tmp_tensor = torch.tensor(
- args[0],
- dtype=interpret_dtypes[self.dtype],
- device='cuda' if eval(self.__module__) is torch.cuda else 'cpu')
-
+ elif len(args) == 1:
+ if _isint(args[0]):
+ self._storage = untyped_storage_class(int(args[0]) * self.element_size())
+ elif isinstance(args[0], collections.abc.Sequence):
+ self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
else:
- tmp_tensor = torch.tensor(
- args[0],
- dtype=self.dtype,
- device='cuda' if eval(self.__module__) is torch.cuda else 'cpu')
-
- self._storage = tmp_tensor.storage()._untyped()
+ raise TypeError(
+ arg_error_msg +
+ f"\nArgument type not recognized: {type(args[0])}")
else:
- raise TypeError(arg_error_msg)
+ raise RuntimeError(
+ arg_error_msg +
+ "\nToo many positional arguments")
+
@property
def is_cuda(self):
@@ -414,11 +514,19 @@
def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
if dtype is None:
+ legacy_class = self._get_legacy_storage_class()
+
+ if legacy_class is not None:
+ return legacy_class.__module__ + '.' + legacy_class.__name__
+
return '.'.join([self.__module__, type(self).__name__])
+
else:
return self._storage.type(dtype, non_blocking)
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
+ if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
+ raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage = self._storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)
@@ -430,12 +538,9 @@
def __str__(self):
data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
- if type(self) == _TypedStorage:
- return data_str + (
- f'\n[{torch.typename(self)} with dtype {self.dtype} '
- f'of size {len(self)}]')
- else:
- return data_str + f'\n[{torch.typename(self)} of size {len(self)}]'
+ return data_str + (
+ f'\n[{torch.typename(self)}(dtype={self.dtype}, '
+ f'device={self.device}) of size {len(self)}]')
def __repr__(self):
return str(self)
@@ -480,12 +585,16 @@
self._storage.share_memory_()
return self
- @classmethod
- def _new_shared(cls, size):
+ def _new_shared(self, size):
"""Creates a new storage in shared memory with the same data type"""
- module = eval(cls.__module__)
- untyped_storage = module._UntypedStorage._new_shared(size * cls().element_size())
- return cls(wrap_storage=untyped_storage)
+ if self.is_cuda:
+ untyped_cls = torch.cuda._UntypedStorage
+ else:
+ untyped_cls = torch._UntypedStorage
+ untyped_storage = untyped_cls._new_shared(size * self.element_size())
+ return _TypedStorage(
+ wrap_storage=untyped_storage,
+ dtype=self.dtype)
@property
def _cdata(self):
@@ -523,22 +632,39 @@
return self._storage._weak_ref(*args, **kwargs)
@classmethod
- def from_buffer(cls, *args, **kwargs):
+ def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
if cls == _TypedStorage:
- raise RuntimeError(
- 'from_buffer: only supported for subclasses of _TypedStorage')
+ dtype = torch.get_default_dtype() if dtype is None else dtype
+ device = torch.device('cpu' if device is None else device)
- if 'dtype' in kwargs or len(args) == 5:
- raise RuntimeError((
- "from_buffer: 'dtype' can only be specified in "
- "_UntypedStorage.from_buffer"))
+ if device.type == 'cpu':
+ untyped_cls = torch._UntypedStorage
+ elif device.type == 'cuda':
+ untyped_cls = torch.cuda._UntypedStorage
+ else:
+ raise RuntimeError(
+ f"_TypedStorage.from_buffer: device '{device}' not recognized")
+ untyped_storage: Union[torch._UntypedStorage, torch.cuda._UntypedStorage]
+ untyped_storage = untyped_cls.from_buffer(*args, dtype=dtype, **kwargs)
- kwargs['dtype'] = cls().dtype
+ else:
+ if dtype is not None or len(args) == 5:
+ raise RuntimeError((
+ "from_buffer: 'dtype' can only be specified in "
+ "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
+ if device is not None:
+ raise RuntimeError((
+ "from_buffer: 'device' can only be specified in "
+ "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
- untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, **kwargs)
- return cls(wrap_storage=untyped_storage)
+ dtype = cls.dtype
+ untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
+
+ return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
def _to(self, dtype):
+ if not isinstance(dtype, torch.dtype):
+ raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
if storage.data_ptr() == self.data_ptr():
storage = storage.clone()
@@ -594,6 +720,23 @@
@classmethod
def from_file(cls, filename, shared, size):
+ """
+ from_file(filename, shared=False, size=0) -> Storage
+
+ If `shared` is `True`, then memory is shared between all processes.
+ All changes are written to the file. If `shared` is `False`, then changes to
+ the storage do not affect the file.
+
+ `size` is the number of elements in the storage. If `shared` is `False`,
+ then the file must contain at least `size * sizeof(Type)` bytes
+ (`Type` is the type of storage). If `shared` is `True`, the file will be
+ created if needed.
+
+ Args:
+ filename (str): file name to map
+ shared (bool): whether to share memory
+ size (int): number of elements in the storage
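+
+ Example (illustrative; assumes `existing_file.bin` already exists and holds
+ at least `size` elements of the storage's element type)::
+
+     >>> s = torch.IntStorage.from_file('existing_file.bin', shared=False, size=10)
+     >>> s.size()
+     10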
+ """
if cls == _TypedStorage:
raise RuntimeError('from_file can only be called on derived classes')
untyped_storage = eval(cls.__module__)._UntypedStorage.from_file(
@@ -627,28 +770,28 @@
@classmethod
def _new_shared_cuda(cls, *args, **kwargs):
- return eval(cls.__module__)._UntypedStorage._new_shared_cuda(*args, **kwargs)
-
- @classmethod
- def _new_with_weak_ptr(cls, *args, **kwargs):
- return eval(cls.__module__)._UntypedStorage._new_with_weak_ptr(*args, **kwargs)
+ return torch.cuda._UntypedStorage._new_shared_cuda(*args, **kwargs)
def _share_filename_(self, *args, **kwargs):
manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs)
return manager_handle, storage_handle, size // self.element_size()
- @classmethod
- def _new_shared_filename(cls, manager, obj, size):
- bytes_size = size * torch._utils._element_size(cls.dtype)
- return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
-
def _shared_decref(self):
self._storage._shared_decref()
return self
@classmethod
- def _release_ipc_counter(cls, *args, **kwargs):
- return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
+ def _release_ipc_counter(cls, *args, device=None, **kwargs):
+ device = torch.device('cpu' if device is None else device)
+
+ if device.type == 'cpu':
+ untyped_cls = torch._UntypedStorage
+ elif device.type == 'cuda':
+ untyped_cls = torch.cuda._UntypedStorage
+ else:
+ raise RuntimeError(f"device {device} not recognized")
+
+ return untyped_cls._release_ipc_counter(*args, **kwargs)
def _shared_incref(self, *args, **kwargs):
return self._storage._shared_incref(*args, **kwargs)
@@ -657,6 +800,51 @@
fd, size = self._storage._share_fd_(*args, **kwargs)
return fd, size // self.element_size()
+ def _get_legacy_storage_class(self):
+ if self.dtype not in _dtype_to_storage_type_map():
+ return None
+
+ storage_name = _dtype_to_storage_type_map()[self.dtype]
+
+ if self.device.type not in ['cpu', 'cuda']:
+ return None
+
+ module = 'torch.' if self.device.type == 'cpu' else 'torch.cuda.'
+
+ try:
+ return eval(module + storage_name)
+ except AttributeError:
+ return None
+
+_TypedStorage.type.__doc__ = _type.__doc__
+_TypedStorage.cuda.__doc__ = _cuda.__doc__
+
+class _LegacyStorageMeta(type):
+ dtype: torch.dtype
+
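+    # A _TypedStorage instance is treated as an instance of a legacy
+    # <type>Storage class when its device type and dtype match that class.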
+ def __instancecheck__(cls, instance):
+ if type(instance) == _TypedStorage:
+ cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+ return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
+ return False
+
+class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
+ @classmethod
+ def _new_shared(cls, size):
+ """Creates a new storage in shared memory with the same data type"""
+ module = eval(cls.__module__)
+ untyped_storage = module._UntypedStorage._new_shared(size * cls().element_size())
+ return cls(wrap_storage=untyped_storage)
+
+ @classmethod
+ def _release_ipc_counter(cls, *args, **kwargs):
+ return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
+
+ @classmethod
+ def _new_shared_filename(cls, manager, obj, size):
+ bytes_size = size * torch._utils._element_size(cls.dtype)
+ return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
+
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
try:
return _storage_type_to_dtype_map()[pickle_storage_type]