Virtualize `<type>Storage` classes (#66970)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/66228
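
This change virtualizes the per-dtype storage classes: `torch.<type>Storage` and
`torch.cuda.<type>Storage` become thin `_LegacyStorage` shims whose constructors
return a `torch._TypedStorage` parameterized by dtype and device, backed by an
`_UntypedStorage`. `Tensor.storage()` now returns a `_TypedStorage` directly,
`Tensor.storage_type()` maps it back to the matching legacy class, and
`isinstance` checks against the legacy classes keep working through
`_LegacyStorageMeta.__instancecheck__`.

A rough sketch of the resulting behavior (outputs in the comments are
representative, not verbatim):

```python
import torch

# Legacy constructors still work, but now build a torch._TypedStorage that
# carries the matching dtype (see _LegacyStorage / _TypedStorage.__new__).
s = torch.FloatStorage(10)
print(type(s))                             # <class 'torch.storage._TypedStorage'>
print(s.dtype)                             # torch.float32
print(isinstance(s, torch.FloatStorage))   # True, via _LegacyStorageMeta.__instancecheck__

# _TypedStorage can also be constructed directly for any dtype/device.
t = torch._TypedStorage([1, 2, 3], dtype=torch.int16)
print(t.dtype)                             # torch.int16

# Tensor.storage() now returns a _TypedStorage; storage_type() maps it back
# to the corresponding legacy class.
a = torch.tensor([1.0, 2.0])
print(a.storage_type())                    # <class 'torch.FloatStorage'>
```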

cc ezyang bhosmer smessmer ljk53 bdhirsh

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66970

Reviewed By: bdhirsh

Differential Revision: D33245612

Pulled By: ezyang

fbshipit-source-id: 4c61c2cb029e2b94b0e68927c377d3e1c358dd7c
(cherry picked from commit d29fcdfb4bc2cc17b1795d4349e4b56fa0d1cf12)
diff --git a/docs/source/storage.rst b/docs/source/storage.rst
index 3aeec08..747acf1 100644
--- a/docs/source/storage.rst
+++ b/docs/source/storage.rst
@@ -1,87 +1,96 @@
 torch.Storage
 ===================================
 
-A :class:`torch.Storage` is a contiguous, one-dimensional array of a single
-data type.
+A :class:`torch._TypedStorage` is a contiguous, one-dimensional array of
+elements of a particular :class:`torch.dtype`. It can be given any
+:class:`torch.dtype`, and the internal data will be interpreted appropriately.
 
-Every :class:`torch.Tensor` has a corresponding storage of the same data type.
+Every strided :class:`torch.Tensor` contains a :class:`torch._TypedStorage`,
+which stores all of the data that the :class:`torch.Tensor` views.
+
+For backward compatibility, there are also :class:`torch.<type>Storage` classes
+(like :class:`torch.FloatStorage`, :class:`torch.IntStorage`, etc.). These
+classes are never actually instantiated; calling their constructors creates
+a :class:`torch._TypedStorage` with the appropriate :class:`torch.dtype`.
+:class:`torch.<type>Storage` classes have all of the same class methods that
+:class:`torch._TypedStorage` has.
+
+Also for backward compatibility, :class:`torch.Storage` is an alias for the
+storage class that corresponds to the default data type
+(:func:`torch.get_default_dtype()`). For instance, if the default data type is
+:attr:`torch.float`, :class:`torch.Storage` resolves to
+:class:`torch.FloatStorage`.
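+
+A short example of how these classes relate (the exact outputs shown are
+illustrative)::
+
+    >>> s = torch.IntStorage([1, 2, 3])
+    >>> s.dtype
+    torch.int32
+    >>> isinstance(s, torch._TypedStorage)
+    True
+    >>> isinstance(s, torch.IntStorage)
+    True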
+
+
+.. autoclass:: torch._TypedStorage
+   :members:
+   :undoc-members:
+   :inherited-members:
 
 .. autoclass:: torch.DoubleStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.FloatStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.HalfStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.LongStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.IntStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.ShortStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.CharStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.ByteStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.BoolStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.BFloat16Storage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.ComplexDoubleStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.ComplexFloatStorage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.QUInt8Storage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.QInt8Storage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.QInt32Storage
    :members:
    :undoc-members:
-   :inherited-members:
 
 .. autoclass:: torch.QUInt4x2Storage
    :members:
    :undoc-members:
-   :inherited-members:
+
+.. autoclass:: torch.QUInt2x4Storage
+   :members:
+   :undoc-members:
diff --git a/test/test_torch.py b/test/test_torch.py
index 1917c9a..2ab2b69 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -221,6 +221,16 @@
 
     @onlyNativeDeviceTypes
     @dtypes(*get_all_dtypes())
+    def test_tensor_storage_type(self, device, dtype):
+        a = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9)
+
+        module = torch.cuda if (torch.device(device).type == 'cuda') else torch
+        expected_storage_type = getattr(module, torch.storage._dtype_to_storage_type_map()[dtype])
+
+        self.assertEqual(a.storage_type(), expected_storage_type)
+
+    @onlyNativeDeviceTypes
+    @dtypes(*get_all_dtypes())
     def test_tensor_from_storage(self, device, dtype):
         a = make_tensor((4, 5, 3), dtype=dtype, device=device, low=-9, high=9)
         a_s = a.storage()
@@ -6171,6 +6181,7 @@
         self.assertEqual(bools.size(), 8)
         self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True])
         self.assertEqual(bools.type(), 'torch.BoolStorage')
+        self.assertTrue(isinstance(bools, torch.BoolStorage))
 
         f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
         bools = torch.BoolStorage.from_buffer(f, 'big')
@@ -6183,6 +6194,122 @@
         bytes = torch.ByteStorage.from_buffer(a)
         self.assertEqual(bytes.nbytes(), 4)
         self.assertEqual(bytes.tolist(), [1, 2, 3, 4])
+        self.assertTrue(isinstance(bytes, torch.ByteStorage))
+
+    def test_storage_error(self):
+        quantized_storages = [
+            torch.QInt32Storage,
+            torch.QInt8Storage,
+            torch.QUInt2x4Storage,
+            torch.QUInt4x2Storage,
+            torch.QUInt8Storage,
+        ]
+
+        with self.assertRaisesRegex(RuntimeError, r"Only child classes of _LegacyStorage can be instantiated"):
+            torch.storage._LegacyStorage()
+
+        for storage_class in torch._storage_classes:
+            if storage_class in [torch._UntypedStorage, torch.cuda._UntypedStorage, torch._TypedStorage]:
+                continue
+
+            device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
+            dtype = storage_class.dtype
+
+            if device == 'cuda' and not torch.cuda.is_available():
+                continue
+
+            # Legacy <type>Storage constructor errors
+            with self.assertRaisesRegex(RuntimeError, r"'device' cannot be specified"):
+                storage_class(device='cpu')
+
+            with self.assertRaisesRegex(RuntimeError, r"'dtype' cannot be specified"):
+                storage_class(dtype=torch.float)
+
+            with self.assertRaisesRegex(TypeError, r"got an unexpected keyword"):
+                storage_class(sdlkjf=torch.float)
+
+            with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
+                storage_class(0, 0)
+
+            with self.assertRaisesRegex(TypeError, r"invalid data type"):
+                storage_class('string')
+
+            with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
+                storage_class(torch.tensor([]))
+
+            s = storage_class()
+
+            with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
+                storage_class(0, wrap_storage=s._untyped())
+
+            with self.assertRaisesRegex(TypeError, r"must be _UntypedStorage"):
+                storage_class(wrap_storage=s)
+
+            if torch.cuda.is_available():
+                if storage_class in quantized_storages:
+                    with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
+                        s.cuda()
+
+                else:
+
+                    if s.is_cuda:
+                        s_other_device = s.cpu()
+                    else:
+                        s_other_device = s.cuda()
+
+                    with self.assertRaisesRegex(RuntimeError, r"Device of 'wrap_storage' must be"):
+                        storage_class(wrap_storage=s_other_device._untyped())
+
+            # _TypedStorage constructor errors
+            with self.assertRaisesRegex(RuntimeError, r"No positional arguments"):
+                torch._TypedStorage(0, wrap_storage=s._untyped(), dtype=dtype)
+
+            with self.assertRaisesRegex(RuntimeError, r"Argument 'dtype' must be specified"):
+                torch._TypedStorage(wrap_storage=s._untyped())
+
+            with self.assertRaisesRegex(TypeError, r"Argument 'dtype' must be torch.dtype"):
+                torch._TypedStorage(wrap_storage=s._untyped(), dtype=0)
+
+            with self.assertRaisesRegex(RuntimeError, r"Argument 'device' should not be specified"):
+                torch._TypedStorage(wrap_storage=s._untyped(), dtype=dtype, device=device)
+
+            with self.assertRaisesRegex(TypeError, r"Argument 'wrap_storage' must be _UntypedStorage"):
+                torch._TypedStorage(wrap_storage=s, dtype=dtype)
+
+            with self.assertRaisesRegex(RuntimeError, r"Storage device not recognized"):
+                torch._TypedStorage(dtype=dtype, device='xla')
+
+            if torch.cuda.is_available():
+                if storage_class in quantized_storages:
+                    with self.assertRaisesRegex(RuntimeError, r"Cannot create CUDA storage with quantized dtype"):
+                        torch._TypedStorage(dtype=dtype, device='cuda')
+
+            with self.assertRaisesRegex(TypeError, r"Argument type not recognized"):
+                torch._TypedStorage(torch.tensor([]), dtype=dtype, device=device)
+
+            with self.assertRaisesRegex(RuntimeError, r"Too many positional arguments"):
+                torch._TypedStorage(0, 0, dtype=dtype, device=device)
+
+    def test_storage_error_no_attribute(self):
+        storage_classes = [
+            torch.cuda.ByteStorage,
+            torch.cuda.FloatStorage,
+            torch.cuda._UntypedStorage,
+        ]
+        for storage_class in storage_classes:
+            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
+                storage_class.from_buffer()
+
+            if storage_class == torch.cuda._UntypedStorage:
+                with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
+                    storage_class._new_with_weak_ptr()
+
+            else:
+                with self.assertRaisesRegex(AttributeError, r'has no attribute'):
+                    storage_class._new_with_weak_ptr()
+
+            with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
+                storage_class._new_shared_filename(0, 0, 0)
 
     def test_storage_casts(self):
         storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
diff --git a/torch/__init__.py b/torch/__init__.py
index ac47103..cf7f592 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -39,6 +39,7 @@
     'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
     'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
     'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
+    '_TypedStorage',
     'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
     'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
     'lobpcg', 'use_deterministic_algorithms',
@@ -594,7 +595,7 @@
 ################################################################################
 
 from ._tensor import Tensor
-from .storage import _StorageBase, _TypedStorage
+from .storage import _StorageBase, _TypedStorage, _LegacyStorage
 
 # NOTE: New <type>Storage classes should never be added. When adding a new
 # dtype, use torch.storage._TypedStorage directly.
@@ -602,87 +603,87 @@
 class _UntypedStorage(_C.ByteStorageBase, _StorageBase):
     pass
 
-class ByteStorage(_TypedStorage):
+class ByteStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.uint8
 
-class DoubleStorage(_TypedStorage):
+class DoubleStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.double
 
-class FloatStorage(_TypedStorage):
+class FloatStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.float
 
-class HalfStorage(_TypedStorage):
+class HalfStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.half
 
-class LongStorage(_TypedStorage):
+class LongStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.long
 
-class IntStorage(_TypedStorage):
+class IntStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.int
 
-class ShortStorage(_TypedStorage):
+class ShortStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.short
 
-class CharStorage(_TypedStorage):
+class CharStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.int8
 
-class BoolStorage(_TypedStorage):
+class BoolStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.bool
 
-class BFloat16Storage(_TypedStorage):
+class BFloat16Storage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.bfloat16
 
-class ComplexDoubleStorage(_TypedStorage):
+class ComplexDoubleStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.cdouble
 
-class ComplexFloatStorage(_TypedStorage):
+class ComplexFloatStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.cfloat
 
-class QUInt8Storage(_TypedStorage):
+class QUInt8Storage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.quint8
 
-class QInt8Storage(_TypedStorage):
+class QInt8Storage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.qint8
 
-class QInt32Storage(_TypedStorage):
+class QInt32Storage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.qint32
 
-class QUInt4x2Storage(_TypedStorage):
+class QUInt4x2Storage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.quint4x2
 
-class QUInt2x4Storage(_TypedStorage):
+class QUInt2x4Storage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.quint2x4
@@ -692,6 +693,7 @@
     ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
     QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
     ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
+    _TypedStorage
 }
 
 # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 4d690c0..99a604a 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -202,11 +202,7 @@
         if self.dtype not in torch.storage._dtype_to_storage_type_map():
             raise RuntimeError(f'unsupported Storage type: {self.dtype}')
 
-        storage = self._storage()
-        storage_name = torch.storage._dtype_to_storage_type_map()[self.dtype]
-        storage_class = eval(type(storage).__module__ + '.' + storage_name)
-        storage = storage_class(wrap_storage=storage)
-        return storage
+        return torch._TypedStorage(wrap_storage=self._storage(), dtype=self.dtype)
 
     def _reduce_ex_internal(self, proto):
         check_serializing_named_tensor(self)
@@ -866,10 +862,7 @@
         Returns the type of the underlying storage.
 
         """
-        # NB: this returns old fashioned _TypedStorage, e.g., FloatStorage, as it
-        # would be pretty pointless otherwise (it would always return
-        # _UntypedStorage)
-        return type(self.storage())
+        return self.storage()._get_legacy_storage_class()
 
     def refine_names(self, *names):
         r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index ac7026e..ab7338f 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -674,67 +674,77 @@
 
     __new__ = _lazy_new
 
-from torch.storage import _TypedStorage
+from torch.storage import _TypedStorage, _LegacyStorage
 
 class _UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
-    pass
+    @classmethod
+    def from_buffer(cls, *args, **kwargs):
+        raise RuntimeError('from_buffer: Not available for CUDA storage')
 
-class ByteStorage(_TypedStorage):
+    @classmethod
+    def _new_with_weak_ptr(cls, *args, **kwargs):
+        raise RuntimeError('_new_with_weak_ptr: Not available for CUDA storage')
+
+    @classmethod
+    def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
+        raise RuntimeError('_new_shared_filename: Not available for CUDA storage')
+
+class ByteStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.uint8
 
-class DoubleStorage(_TypedStorage):
+class DoubleStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.double
 
-class FloatStorage(_TypedStorage):
+class FloatStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.float
 
-class HalfStorage(_TypedStorage):
+class HalfStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.half
 
-class LongStorage(_TypedStorage):
+class LongStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.long
 
-class IntStorage(_TypedStorage):
+class IntStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.int
 
-class ShortStorage(_TypedStorage):
+class ShortStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.short
 
-class CharStorage(_TypedStorage):
+class CharStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.int8
 
-class BoolStorage(_TypedStorage):
+class BoolStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.bool
 
-class BFloat16Storage(_TypedStorage):
+class BFloat16Storage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.bfloat16
 
-class ComplexDoubleStorage(_TypedStorage):
+class ComplexDoubleStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.cdouble
 
-class ComplexFloatStorage(_TypedStorage):
+class ComplexFloatStorage(_LegacyStorage):
     @classproperty
     def dtype(self):
         return torch.cfloat
diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py
index 2da1ab8..4a5d725 100644
--- a/torch/multiprocessing/reductions.py
+++ b/torch/multiprocessing/reductions.py
@@ -6,6 +6,8 @@
 import multiprocessing
 from multiprocessing.util import register_after_fork
 from multiprocessing.reduction import ForkingPickler
+from typing import Union
+
 try:
     # Early load resource_sharer to prevent a partially initialized instance
     # from being inherited in a forked child process. The reduce_storage method
@@ -103,7 +105,7 @@
                         requires_grad, ref_counter_handle, ref_counter_offset, event_handle, event_sync_required):
     # If storage_handle is None, storage points to nullptr.
     if storage_handle is None or storage_size_bytes == 0:
-        storage = storage_cls(0)
+        storage = storage_cls(0, dtype=dtype, device=storage_device)
     else:
         storage = storage_from_cache(storage_cls, (storage_handle, storage_offset_bytes))
         if storage is None:
@@ -120,7 +122,7 @@
             shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
         else:
             # We already ref counting this Storage, but producer needs new ref-counters to be released.
-            storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset)
+            storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
 
     t = torch._utils._rebuild_tensor(
         torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
@@ -288,7 +290,7 @@
     storage_ref = shared_cache.get(key)
     if storage_ref is None:
         return None
-    return cls._new_with_weak_ptr(storage_ref.cdata)
+    return torch._UntypedStorage._new_with_weak_ptr(storage_ref.cdata)
 
 
 def rebuild_storage_fd(cls, df, size):
@@ -304,11 +306,18 @@
         os.close(fd)
 
 
-def rebuild_storage_filename(cls, manager, handle, size):
-    storage = storage_from_cache(cls, handle)
+def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
+    storage: Union[torch._TypedStorage, torch._UntypedStorage] = storage_from_cache(cls, handle)
     if storage is not None:
         return storage._shared_decref()
-    storage = cls._new_shared_filename(manager, handle, size)
+    if dtype is None:
+        storage = torch._UntypedStorage._new_shared_filename(manager, handle, size)
+    else:
+        byte_size = size * torch._utils._element_size(dtype)
+        untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename(manager, handle, byte_size)
+        storage = torch._TypedStorage(
+            wrap_storage=untyped_storage,
+            dtype=dtype)
     shared_cache[handle] = StorageWeakRef(storage)
     return storage._shared_decref()
 
@@ -338,6 +347,8 @@
         metadata = storage._share_filename_()
         cache_key = metadata[1]
         rebuild = rebuild_storage_filename
+        if isinstance(storage, torch._TypedStorage):
+            metadata += (storage.dtype,)
         storage._shared_incref()
     elif storage.size() == 0:
         # This is special cased because Empty tensors
diff --git a/torch/storage.py b/torch/storage.py
index 54e8df5..2537d7a 100644
--- a/torch/storage.py
+++ b/torch/storage.py
@@ -7,6 +7,11 @@
 import copy
 import collections
 from functools import lru_cache
+try:
+    import numpy as np
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    np, HAS_NUMPY = None, False  # type: ignore[assignment]
 
 T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
 class _StorageBase(object):
@@ -38,6 +43,15 @@
     def _new_using_filename(cls: Type[T], size: int) -> T: ...  # noqa: E704
     @classmethod
     def _new_using_fd(cls: Type[T], size: int) -> T: ...  # noqa: E704
+    @classmethod
+    def from_buffer(cls, *args, **kwargs) -> T: ...  # noqa: E704
+    @classmethod
+    def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None) -> T: ...  # noqa: E704
+    @classmethod
+    def _release_ipc_counter(cls, *args, **kwargs) -> T: ...  # noqa: E704
+    @classmethod
+    def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ...  # noqa: E704
+    def _shared_decref(self) -> T: ...  # noqa: E704
 
     def __str__(self):
         content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
@@ -83,6 +97,8 @@
         return _type(self, getattr(torch, self.__class__.__name__))
 
     def _to(self, dtype):
+        if not isinstance(dtype, torch.dtype):
+            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
         storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage()
         if storage.data_ptr() == self.data_ptr():
             storage = storage.clone()
@@ -187,6 +203,11 @@
 
 @lru_cache(maxsize=None)
 def _dtype_to_storage_type_map():
+    # NOTE: We should no longer add dtypes to this map. This map
+    # is only used for BC/FC with older PyTorch versions. Going forward,
+    # new dtypes of _TypedStorage should not translate to a legacy
+    # <type>Storage class. Instead, new dtypes of _TypedStorage should
+    # be serialized as an _UntypedStorage paired with a torch.dtype
     return {
         torch.double: 'DoubleStorage',
         torch.float: 'FloatStorage',
@@ -213,104 +234,183 @@
         val: key for key, val in _dtype_to_storage_type_map().items()}
     return dtype_map
 
+def _get_storage_from_sequence(sequence, dtype, device):
+    if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
+        interpret_dtypes = {
+            torch.quint8: torch.uint8,
+            torch.quint4x2: torch.uint8,
+            torch.quint2x4: torch.uint8,
+            torch.qint32: torch.int32,
+            torch.qint8: torch.int8
+        }
+        tmp_tensor = torch.tensor(
+            sequence,
+            dtype=interpret_dtypes[dtype],
+            device=device)
+
+    else:
+        tmp_tensor = torch.tensor(
+            sequence,
+            dtype=dtype,
+            device=device)
+
+    return tmp_tensor.storage()._untyped()
+
+def _isint(x):
+    if HAS_NUMPY:
+        return isinstance(x, (int, np.integer))
+    else:
+        return isinstance(x, int)
+
 class _TypedStorage:
     is_sparse = False
 
+    dtype: torch.dtype
+
     def fill_(self, value):
         self[0:len(self)] = value
         return self
 
-    def __init__(self, *args, **kwargs):
-        arg_error_msg = (
-            f'{type(self)} constructor received an invalid combination '
-            f'of arguments - got args={tuple(type(arg) for arg in args)}, '
-            f'kwargs={ {key: type(val) for key, val in kwargs.items()} }, but '
-            'expected one of:\n'
-            ' * no arguments\n'
-            ' * (int size)\n'
-            ' * (Sequence data)\n')
-        if type(self) == _TypedStorage:
-            arg_error_msg += ' * (wrap_storage=<_UntypedStorage>, dtype=<torch.dtype>)'
+    def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
+        if cls == torch.storage._LegacyStorage:
+            raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
+
+        if cls == _TypedStorage:
+            return super().__new__(cls)
+
         else:
-            arg_error_msg += ' * (wrap_storage=<_UntypedStorage>)'
+            arg_error_msg = (
+                f'{cls}.__new__ received an invalid combination '
+                f'of arguments. Expected one of:\n'
+                ' * no arguments\n'
+                ' * (int size)\n'
+                ' * (Sequence data)\n'
+                ' * (*, _UntypedStorage wrap_storage)')
 
-        if 'wrap_storage' in kwargs:
-            assert len(args) == 0, (
-                "No positional arguments should be given when using "
-                "'wrap_storage'")
+            if device is not None:
+                raise RuntimeError(
+                    arg_error_msg +
+                    "\nKeyword argument 'device' cannot be specified")
 
-            if type(self) == _TypedStorage:
-                assert 'dtype' in kwargs, (
-                    "When using 'wrap_storage', 'dtype' also must be specified")
-                assert len(kwargs) == 2, (
-                    "Only 'wrap_storage' and 'dtype' should be given, but got: "
-                    f"{kwargs}")
-                dtype = kwargs['dtype']
-                assert isinstance(dtype, torch.dtype)
-                self.dtype = dtype
+            if dtype is not None:
+                raise RuntimeError(
+                    arg_error_msg +
+                    "\nKeyword argument 'dtype' cannot be specified")
+
+            if wrap_storage is None:
+                if len(args) > 1:
+                    raise RuntimeError(
+                        arg_error_msg +
+                        "\nToo many positional arguments")
+
+                if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
+                    raise TypeError(
+                        arg_error_msg +
+                        f"\nArgument type not recognized: {type(args[0])}")
+
+                return _TypedStorage(
+                    *args,
+                    dtype=cls.dtype,
+                    device='cuda' if eval(cls.__module__) is torch.cuda else 'cpu')
 
             else:
-                assert hasattr(self, 'dtype')
-                assert len(kwargs) == 1, (
-                    f"Only 'wrap_storage' should be given, but got: {kwargs.keys()}")
-                dtype = self.dtype
+                if len(args) != 0:
+                    raise RuntimeError(
+                        arg_error_msg +
+                        "\nNo positional arguments should be given when using "
+                        "'wrap_storage'")
 
-            storage = kwargs['wrap_storage']
+                if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
+                    raise TypeError(
+                        arg_error_msg +
+                        f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
 
-            if not isinstance(storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
-                raise TypeError(arg_error_msg)
-            if type(self) != _TypedStorage and storage.__module__ != self.__module__:
-                raise TypeError((
+                cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+
+                if wrap_storage.device.type != cls_device:
+                    raise RuntimeError(
+                        arg_error_msg +
+                        f"\nDevice of 'wrap_storage' must be {cls_device}"
+                        f", but got {wrap_storage.device.type}")
+
+                return _TypedStorage(
+                    *args,
+                    wrap_storage=wrap_storage,
+                    dtype=cls.dtype)
+
+    def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
+        arg_error_msg = (
+            '_TypedStorage.__init__ received an invalid combination '
+            'of arguments. Expected one of:\n'
+            ' * (*, torch.device device, torch.dtype dtype)\n'
+            ' * (int size, *, torch.device device, torch.dtype dtype)\n'
+            ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
+            ' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)')
+
+        if wrap_storage is not None:
+            if len(args) != 0:
+                raise RuntimeError(
                     arg_error_msg +
-                    f'\n`storage` `module {storage.__module__}` does not match '
-                    f'module of {type(self)}'))
-            self._storage = storage
+                    "\nNo positional arguments should be given when using "
+                    "'wrap_storage'")
+
+            if dtype is None:
+                raise RuntimeError(
+                    arg_error_msg +
+                    "\nArgument 'dtype' must be specified")
+
+            if not isinstance(dtype, torch.dtype):
+                raise TypeError(
+                    arg_error_msg +
+                    f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")
+
+            if device is not None:
+                raise RuntimeError(
+                    arg_error_msg +
+                    "\nArgument 'device' should not be specified when 'wrap_storage' is given")
+
+            self.dtype = dtype
+
+            if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
+                raise TypeError(
+                    arg_error_msg +
+                    f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
+
+            self._storage = wrap_storage
 
         else:
-            assert type(self) != _TypedStorage, (
-                "Calling __init__ this way is only supported in _TypedStorage's "
-                "child classes. _TypedStorage can only be directly instantiated "
-                "when kwargs 'wrap_storage' and 'dtype' are given.")
+            self.dtype = torch.get_default_dtype() if dtype is None else dtype
+            device = torch.device('cpu' if device is None else device)
 
-            assert len(kwargs) == 0, "invalid keyword arguments"
+            if device.type == 'cpu':
+                untyped_storage_class = torch._UntypedStorage
+            elif device.type == 'cuda':
+                untyped_storage_class = torch.cuda._UntypedStorage
+            else:
+                raise RuntimeError(f"Storage device not recognized: {device}")
 
-            def isint(x):
-                try:
-                    int(x)
-                except TypeError:
-                    return False
-                return True
+            if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
+                if device.type == 'cuda':
+                    raise RuntimeError("Cannot create CUDA storage with quantized dtype")
 
             if len(args) == 0:
-                self._storage = eval(self.__module__)._UntypedStorage()
+                self._storage = untyped_storage_class()
 
-            elif len(args) == 1 and isint(args[0]):
-                self._storage = eval(self.__module__)._UntypedStorage(int(args[0]) * self.element_size())
-
-            elif len(args) == 1 and isinstance(args[0], collections.abc.Sequence):
-                if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
-                    interpret_dtypes = {
-                        torch.quint8: torch.uint8,
-                        torch.quint4x2: torch.uint8,
-                        torch.quint2x4: torch.uint8,
-                        torch.qint32: torch.int32,
-                        torch.qint8: torch.int8
-                    }
-                    tmp_tensor = torch.tensor(
-                        args[0],
-                        dtype=interpret_dtypes[self.dtype],
-                        device='cuda' if eval(self.__module__) is torch.cuda else 'cpu')
-
+            elif len(args) == 1:
+                if _isint(args[0]):
+                    self._storage = untyped_storage_class(int(args[0]) * self.element_size())
+                elif isinstance(args[0], collections.abc.Sequence):
+                    self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
                 else:
-                    tmp_tensor = torch.tensor(
-                        args[0],
-                        dtype=self.dtype,
-                        device='cuda' if eval(self.__module__) is torch.cuda else 'cpu')
-
-                self._storage = tmp_tensor.storage()._untyped()
+                    raise TypeError(
+                        arg_error_msg +
+                        f"\nArgument type not recognized: {type(args[0])}")
 
             else:
-                raise TypeError(arg_error_msg)
+                raise RuntimeError(
+                    arg_error_msg +
+                    "\nToo many positional arguments")
+
 
     @property
     def is_cuda(self):
@@ -414,11 +514,19 @@
 
     def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
         if dtype is None:
+            legacy_class = self._get_legacy_storage_class()
+
+            if legacy_class is not None:
+                return legacy_class.__module__ + '.' + legacy_class.__name__
+
             return '.'.join([self.__module__, type(self).__name__])
+
         else:
             return self._storage.type(dtype, non_blocking)
 
     def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
+        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
+            raise RuntimeError("Cannot create CUDA storage with quantized dtype")
         cuda_storage = self._storage.cuda(device, non_blocking, **kwargs)
         return self._new_wrapped_storage(cuda_storage)
 
@@ -430,12 +538,9 @@
 
     def __str__(self):
         data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
-        if type(self) == _TypedStorage:
-            return data_str + (
-                f'\n[{torch.typename(self)} with dtype {self.dtype} '
-                f'of size {len(self)}]')
-        else:
-            return data_str + f'\n[{torch.typename(self)} of size {len(self)}]'
+        return data_str + (
+            f'\n[{torch.typename(self)}(dtype={self.dtype}, '
+            f'device={self.device}) of size {len(self)}]')
 
     def __repr__(self):
         return str(self)
@@ -480,12 +585,16 @@
         self._storage.share_memory_()
         return self
 
-    @classmethod
-    def _new_shared(cls, size):
+    def _new_shared(self, size):
         """Creates a new storage in shared memory with the same data type"""
-        module = eval(cls.__module__)
-        untyped_storage = module._UntypedStorage._new_shared(size * cls().element_size())
-        return cls(wrap_storage=untyped_storage)
+        if self.is_cuda:
+            untyped_cls = torch.cuda._UntypedStorage
+        else:
+            untyped_cls = torch._UntypedStorage
+        untyped_storage = untyped_cls._new_shared(size * self.element_size())
+        return _TypedStorage(
+            wrap_storage=untyped_storage,
+            dtype=self.dtype)
 
     @property
     def _cdata(self):
@@ -523,22 +632,39 @@
         return self._storage._weak_ref(*args, **kwargs)
 
     @classmethod
-    def from_buffer(cls, *args, **kwargs):
+    def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
         if cls == _TypedStorage:
-            raise RuntimeError(
-                'from_buffer: only supported for subclasses of _TypedStorage')
+            dtype = torch.get_default_dtype() if dtype is None else dtype
+            device = torch.device('cpu' if device is None else device)
 
-        if 'dtype' in kwargs or len(args) == 5:
-            raise RuntimeError((
-                "from_buffer: 'dtype' can only be specified in "
-                "_UntypedStorage.from_buffer"))
+            if device.type == 'cpu':
+                untyped_cls = torch._UntypedStorage
+            elif device.type == 'cuda':
+                untyped_cls = torch.cuda._UntypedStorage
+            else:
+                raise RuntimeError(
+                    f"_TypedStorage.from_buffer: device '{device}' not recognized")
+            untyped_storage: Union[torch._UntypedStorage, torch.cuda._UntypedStorage]
+            untyped_storage = untyped_cls.from_buffer(*args, dtype=dtype, **kwargs)
 
-        kwargs['dtype'] = cls().dtype
+        else:
+            if dtype is not None or len(args) == 5:
+                raise RuntimeError((
+                    "from_buffer: 'dtype' can only be specified in "
+                    "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
+            if device is not None:
+                raise RuntimeError((
+                    "from_buffer: 'device' can only be specified in "
+                    "_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
 
-        untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, **kwargs)
-        return cls(wrap_storage=untyped_storage)
+            dtype = cls.dtype
+            untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
+
+        return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
 
     def _to(self, dtype):
+        if not isinstance(dtype, torch.dtype):
+            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
         storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
         if storage.data_ptr() == self.data_ptr():
             storage = storage.clone()
@@ -594,6 +720,23 @@
 
     @classmethod
     def from_file(cls, filename, shared, size):
+        """
+        from_file(filename, shared=False, size=0) -> Storage
+
+        If `shared` is `True`, then memory is shared between all processes.
+        All changes are written to the file. If `shared` is `False`, then the changes on
+        the storage do not affect the file.
+
+        `size` is the number of elements in the storage. If `shared` is `False`,
+        then the file must contain at least `size * sizeof(Type)` bytes
+        (`Type` is the type of storage). If `shared` is `True` the file will be
+        created if needed.
+
+        Args:
+            filename (str): file name to map
+            shared (bool): whether to share memory
+            size (int): number of elements in the storage
+        """
         if cls == _TypedStorage:
             raise RuntimeError('from_file can only be called on derived classes')
         untyped_storage = eval(cls.__module__)._UntypedStorage.from_file(
@@ -627,28 +770,28 @@
 
     @classmethod
     def _new_shared_cuda(cls, *args, **kwargs):
-        return eval(cls.__module__)._UntypedStorage._new_shared_cuda(*args, **kwargs)
-
-    @classmethod
-    def _new_with_weak_ptr(cls, *args, **kwargs):
-        return eval(cls.__module__)._UntypedStorage._new_with_weak_ptr(*args, **kwargs)
+        return torch.cuda._UntypedStorage._new_shared_cuda(*args, **kwargs)
 
     def _share_filename_(self, *args, **kwargs):
         manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs)
         return manager_handle, storage_handle, size // self.element_size()
 
-    @classmethod
-    def _new_shared_filename(cls, manager, obj, size):
-        bytes_size = size * torch._utils._element_size(cls.dtype)
-        return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
-
     def _shared_decref(self):
         self._storage._shared_decref()
         return self
 
     @classmethod
-    def _release_ipc_counter(cls, *args, **kwargs):
-        return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
+    def _release_ipc_counter(cls, *args, device=None, **kwargs):
+        device = torch.device('cpu' if device is None else device)
+
+        if device.type == 'cpu':
+            untyped_cls = torch._UntypedStorage
+        elif device.type == 'cuda':
+            untyped_cls = torch.cuda._UntypedStorage
+        else:
+            raise RuntimeError(f"device {device} not recognized")
+
+        return untyped_cls._release_ipc_counter(*args, **kwargs)
 
     def _shared_incref(self, *args, **kwargs):
         return self._storage._shared_incref(*args, **kwargs)
@@ -657,6 +800,51 @@
         fd, size = self._storage._share_fd_(*args, **kwargs)
         return fd, size // self.element_size()
 
+    def _get_legacy_storage_class(self):
+        if self.dtype not in _dtype_to_storage_type_map():
+            return None
+
+        storage_name = _dtype_to_storage_type_map()[self.dtype]
+
+        if self.device.type not in ['cpu', 'cuda']:
+            return None
+
+        module = 'torch.' if self.device.type == 'cpu' else 'torch.cuda.'
+
+        try:
+            return eval(module + storage_name)
+        except AttributeError:
+            return None
+
+_TypedStorage.type.__doc__ = _type.__doc__
+_TypedStorage.cuda.__doc__ = _cuda.__doc__
+
+class _LegacyStorageMeta(type):
+    dtype: torch.dtype
+
+    def __instancecheck__(cls, instance):
+        if type(instance) == _TypedStorage:
+            cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+            return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
+        return False
+
+class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
+    @classmethod
+    def _new_shared(cls, size):
+        """Creates a new storage in shared memory with the same data type"""
+        module = eval(cls.__module__)
+        untyped_storage = module._UntypedStorage._new_shared(size * cls().element_size())
+        return cls(wrap_storage=untyped_storage)
+
+    @classmethod
+    def _release_ipc_counter(cls, *args, **kwargs):
+        return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
+
+    @classmethod
+    def _new_shared_filename(cls, manager, obj, size):
+        bytes_size = size * torch._utils._element_size(cls.dtype)
+        return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
+
 def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
     try:
         return _storage_type_to_dtype_map()[pickle_storage_type]