Add support for serializing real tensor data in the after-AOT minifier (#99834)

The new repro script looks like this (abridged; the generated file also defines Repro and imports make_fx):

```
import torch._dynamo.repro.after_aot
reader = torch._dynamo.repro.after_aot.InputReader(save_dir='/tmp/tmpcsngx39e')
buf0 = reader.storage('e2b39c716c0d4efb9fa57375a3902b9dab666893', 16)
t0 = reader.tensor(buf0, (4,))
args = [t0]
mod = make_fx(Repro(), tracing_mode='real')(*args)
```

The real tensor data is stored in the storages folder of the checkpoint dump directory. If you delete this folder, or it is otherwise missing, we transparently fall back to generating random data as before. The storages are serialized using the content store from #99809, which means each storage is content-addressed and equivalent data is automatically deduplicated (useful if you keep dumping out, e.g., your parameters). We don't use the content store's tensor serialization capability; instead, all of the tensor metadata is stored inline in the repro script, so everything besides the checkpointed storages stays in one file even if you lose the checkpointed tensors.
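
As a rough sketch of the round trip (mirroring the new test_dump_tensor test below; the temporary directory stands in for the checkpoint dump directory):

```
import tempfile

import torch
from torch._dynamo.repro.after_aot import InputWriter

with tempfile.TemporaryDirectory() as d:
    # Writing: checkpoints the storage into d/storages and records reader code
    writer = InputWriter(d)
    name = writer.tensor(torch.arange(4.0))
    # Reading: replaying the recorded lines constructs an InputReader
    # and reassembles the tensor from the checkpointed storage
    env = {}
    exec("\n".join(writer.lines), env)
    assert torch.equal(env[name], torch.arange(4.0))
```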

We also add a stable_hash option to the content store, which uses a slower SHA-1 sum over the data on the CPU side to compute a hash that is stable across systems with the same endianness.
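
A minimal sketch of opting in (the store location is illustrative):

```
import torch
from torch.utils._content_store import ContentStoreWriter

# stable_hash=True computes a SHA-1 digest on CPU instead of the fast default hash
writer = ContentStoreWriter("/tmp/my_store", stable_hash=True)  # illustrative path
h = writer.write_storage(torch.zeros(3, 4).untyped_storage())
# h is a 40-character hex digest, stable across little-endian systems
```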

Out of rage, I also added support for the torch.dtype.itemsize property.
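
It reports the size in bytes of a single element, mirroring numpy's dtype.itemsize:

```
import torch

assert torch.float32.itemsize == 4
assert torch.float16.itemsize == 2
assert torch.bool.itemsize == 1
```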

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99834
Approved by: https://github.com/voznesenskym
diff --git a/test/dynamo/test_after_aot.py b/test/dynamo/test_after_aot.py
index 0b0cc66..b229f2d 100644
--- a/test/dynamo/test_after_aot.py
+++ b/test/dynamo/test_after_aot.py
@@ -1,10 +1,15 @@
 # Owner(s): ["module: dynamo"]
 
 import io
+import os
+import shutil
+import sys
+import tempfile
+import unittest
 
 import torch._dynamo.test_case
 
-from torch._dynamo.repro.after_aot import save_graph_repro
+from torch._dynamo.repro.after_aot import InputWriter, save_graph_repro
 
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.utils._traceback import report_compile_source_on_error
@@ -16,7 +21,8 @@
 
 class TestAfterAot(torch._dynamo.test_case.TestCase):
     def test_save_graph_repro(self):
-        return
+        # TODO: This triggers CUDA context initialization, even though
+        # it is CPU only
         buf = io.StringIO()
         args = [torch.randn(4)]
 
@@ -24,53 +30,51 @@
             return (x * x,)
 
         gm = make_fx(f)(*args)
-        save_graph_repro(buf, gm, args, "inductor_accuracy", stable_output=True)
-        r = strip_trailing_whitespace(buf.getvalue())
-        self.assertExpectedInline(
-            r,
+        with tempfile.TemporaryDirectory() as d:
+            save_graph_repro(buf, gm, args, "inductor_accuracy", save_dir=d)
+            r = buf.getvalue()
+            with report_compile_source_on_error():
+                exec(r, {"__compile_source__": r})
+
+            shutil.rmtree(os.path.join(d, "storages"))
+
+            # Should still work even without the saved storages
+            with report_compile_source_on_error():
+                exec(r, {"__compile_source__": r})
+
+    @unittest.skipIf(sys.byteorder != "little", "checksum depends on endianness")
+    def test_dump_tensor(self):
+        def test(tensor, expected):
+            with tempfile.TemporaryDirectory() as d:
+                writer = InputWriter(d, stable_hash=True)
+                prefix = len(writer.lines)
+                x = writer.tensor(tensor)
+                self.assertExpectedInline(
+                    "\n".join(writer.lines[prefix:]), expected, skip=1
+                )
+                env = {}
+                # TODO: assert no logs
+                exec("\n".join(writer.lines), env)
+                self.assertEqual(env[x], tensor)
+
+        test(
+            torch.zeros(3, 4),
             """\
-import torch._inductor.overrides
-
-import torch
-from torch import tensor, device
-import torch.fx as fx
-from torch._dynamo.testing import rand_strided
-from math import inf
-from torch.fx.experimental.proxy_tensor import make_fx
-
-# config omitted due to stable_output=True
-
-# REPLACEABLE COMMENT FOR TESTING PURPOSES
-
-
-
-from torch.nn import *
-class Repro(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-
-
-    def forward(self, x_1):
-        mul = torch.ops.aten.mul.Tensor(x_1, x_1);  x_1 = None
-        return (mul,)
-
-args = []
-args.append(rand_strided((4,), (1,), torch.float32, 'cpu'))  # shape (4,), stride (1,)
-mod = make_fx(Repro(), tracing_mode='real')(*args)
-
-from torch._inductor.compile_fx import compile_fx_inner
-from torch._dynamo.debug_utils import same_two_models
-
-compiled = compile_fx_inner(mod, args)
-class AccuracyError(Exception):
-    pass
-if not same_two_models(mod, compiled, args, only_fwd=True):
-    raise AccuracyError("Bad accuracy detected")
-""",
+buf0 = reader.storage('c17fd92682ca5b304ac71074b558dda9e8eb4d66', 48)
+t0 = reader.tensor(buf0, (3, 4))""",
         )
-        with report_compile_source_on_error():
-            exec(r, {"__compile_source__": r})
+        test(
+            torch.ones(3, 4, dtype=torch.int32),
+            """\
+buf0 = reader.storage('7c221e2da0c58c700cc2996644dd13d042bd552e', 48, dtype_hint=torch.int32)
+t0 = reader.tensor(buf0, (3, 4), dtype=torch.int32)""",
+        )
+        test(
+            torch.empty((3, 4, 5, 6), memory_format=torch.channels_last).fill_(2),
+            """\
+buf0 = reader.storage('49ebab3961d6221e64c4c72b0aefd976bdd2afc4', 1440)
+t0 = reader.tensor(buf0, (3, 4, 5, 6), (120, 1, 24, 4))""",
+        )
 
 
 if __name__ == "__main__":
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 8171fc4..914729f 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -99,6 +99,7 @@
     is_floating_point: _bool
     is_complex: _bool
     is_signed: _bool
+    itemsize: _int
 
 # Defined in torch/csrc/TypeInfo.cpp
 class iinfo:
diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py
index 54c07f8..3b936e6 100644
--- a/torch/_dynamo/repro/after_aot.py
+++ b/torch/_dynamo/repro/after_aot.py
@@ -1,5 +1,6 @@
 import copy
 import functools
+import itertools
 import logging
 import os
 import shutil
@@ -8,10 +9,11 @@
 import uuid
 from importlib import import_module
 from tempfile import TemporaryFile
+from typing import Optional, Sequence
 
 import torch
+import torch._prims_common as utils
 import torch.fx as fx
-
 from torch._dynamo.debug_utils import (
     _cuda_system_info_comment,
     AccuracyError,
@@ -24,6 +26,10 @@
     NNModuleToString,
     TEST_REPLACEABLE_COMMENT,
 )
+from torch._dynamo.testing import rand_strided
+from torch.multiprocessing.reductions import StorageWeakRef
+
+from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
 
 from .. import config
 
@@ -66,6 +72,8 @@
             # with fake inputs
             inner_compiled_fn = compiler_fn(gm, example_inputs)
         except Exception as e:
+            # TODO: Failures here are troublesome because we have no real inputs,
+            # so we need a different serialization strategy
             if config.repro_after == "aot":
                 if config.repro_level == 1:
                     dump_compiler_graph_state(
@@ -82,7 +90,18 @@
                 log.error("CompilerError")
             raise
 
+        # We may run regular PyTorch compute that may trigger Dynamo, do NOT
+        # recursively attempt to accuracy minify in that case!
         def deferred_for_real_inputs(real_inputs):
+            # This is a bit obscure: if we recursively try to accuracy minify
+            # the SAME function, this would trigger.  But most of the time
+            # we should never hit this branch
+            if config.repro_after != "aot":
+                return inner_compiled_fn(real_inputs)
+            with config.patch(repro_after=None):
+                return inner_debug_fn(real_inputs)
+
+        def inner_debug_fn(real_inputs):
             """
             Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
             example_inputs can be fake tensors. We can call compiler_fn (which is
@@ -115,12 +134,12 @@
                     )
                     dump_compiler_graph_state(
                         fx.GraphModule(gm, orig_graph),
-                        copy_tensor_attrs,
+                        real_inputs,
                         f"{compiler_name}_accuracy",
                     )
                     dump_to_minify(
                         fx.GraphModule(gm, orig_graph),
-                        copy_tensor_attrs,
+                        real_inputs,
                         f"{compiler_name}_accuracy",
                     )
                     raise AccuracyError("Bad accuracy detected")
@@ -163,6 +182,184 @@
 
 
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#                       REPRO SUPPORT CODE
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# Helper functions for computing the default values of tensor
+# properties.  These defaults all coincide with the factory functions, e.g., torch.empty
+
+
+def _stride_or_default(
+    stride: Optional[Sequence[int]], *, shape: Sequence[int]
+) -> Sequence[int]:
+    return stride if stride is not None else utils.make_contiguous_strides_for(shape)
+
+
+def _dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
+    return dtype if dtype is not None else torch.float32
+
+
+def _device_or_default(device: Optional[torch.device]) -> torch.device:
+    return device if device is not None else torch.device("cpu")
+
+
+def _storage_offset_or_default(storage_offset: Optional[int]) -> int:
+    return storage_offset if storage_offset is not None else 0
+
+
+# TODO: Support bundling the entire repro into a zip file for ease of
+# transferring around
+class InputReader:
+    def __init__(self, save_dir=None):
+        # If None, we will generate random data instead.  It's important
+        # to natively support this use case as it will allow people to
+        # share repros without including the real data, if the problem
+        # reproduces even on random data.
+        if save_dir is None:
+            log.warning("no save_dir specified, will generate random data")
+        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
+
+    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
+        device = _device_or_default(device)
+        dtype_hint = _dtype_or_default(dtype_hint)
+        if self.store is not None and storage_hash is not None:
+            try:
+                storage = self.store.read_storage(storage_hash)
+            except FileNotFoundError:
+                pass
+            else:
+                if device != storage.device:
+                    log.warning("device mismatch: %s != %s", device, storage.device)
+                    # TODO: transfer it to the right device?  But failing this
+                    # way would be very mysterious!  Would have been better
+                    # not to store device in the serialized format...
+                return storage
+        log.warning("could not load %s, generating random data instead", storage_hash)
+        shape = (nbytes // dtype_hint.itemsize,)
+        stride = _stride_or_default(None, shape=shape)
+        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()
+
+    def tensor(
+        self,
+        storage,
+        shape,
+        stride=None,
+        *,
+        storage_offset=None,
+        dtype=None,
+        **metadata,
+    ):
+        stride = _stride_or_default(stride, shape=shape)
+        storage_offset = _storage_offset_or_default(storage_offset)
+        dtype = _dtype_or_default(dtype)
+        t = torch.tensor([], dtype=dtype, device=storage.device)
+        t.set_(storage, storage_offset, shape, stride)
+        torch._utils.set_tensor_metadata(t, metadata)
+        return t
+
+    def symint(self, val):
+        return val
+
+
+# Here is our writer strategy:
+#  1. We will stream all of the inputs to disk
+#  2. You can now deterministically randomize the inputs, or reload
+#     the inputs from disk
+#  3. You can YOLO run the script without the inputs, in which case
+#     we'll fill the inputs with random data and pray.  This is the
+#     legacy behavior, but it's also useful if you want to find out
+#     if we're so broken that even random inputs trigger it
+#  4. We could offer an in-process "check if the randomized thing
+#     works too" but this is delicate so we don't do it
+
+
+class InputWriter:
+    def __init__(self, save_dir, *, stable_hash=False):
+        self.lines = [
+            "import torch._dynamo.repro.after_aot",
+            f"reader = torch._dynamo.repro.after_aot.InputReader(save_dir={save_dir!r})",
+        ]
+        # TODO: consider ensuring tensor and storage counters line up?
+        self.tensor_counter = itertools.count()
+        self.symint_counter = itertools.count()
+        self.storage_counter = itertools.count()
+        self.store = (
+            ContentStoreWriter(save_dir, stable_hash=stable_hash)
+            if save_dir is not None
+            else None
+        )
+        self.seen_storages = {}
+
+    # Storages are untyped, but we need to initialize them with data if
+    # we don't have the real data, so we give a hint saying what kind
+    # of initialization may be appropriate
+    #
+    # If we had a FakeTensor, device_hint tells us what device it should be
+    def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
+        ws = StorageWeakRef(untyped_storage)
+        v = self.seen_storages.get(ws)
+        if v is not None:
+            return v
+        v = f"buf{next(self.storage_counter)}"
+        maybe_dtype_hint = ""
+        if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
+            maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
+        # TODO: being optional on device is kind of pointless as the default
+        # is CPU but most repros we care about are CUDA
+        maybe_device = ""
+        device = untyped_storage.device
+        if device.type == "meta":
+            assert device_hint is not None
+            device = device_hint
+        if _device_or_default(None) != device:
+            maybe_device = f", device={device!r}"
+        nbytes = untyped_storage.nbytes()
+        storage_hash = None
+        if self.store is not None and untyped_storage.device.type != "meta":
+            storage_hash = self.store.write_storage(untyped_storage)
+        self.lines.append(
+            f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
+        )
+        self.seen_storages[ws] = v
+        return v
+
+    def tensor(self, t) -> str:
+        storage = self.storage(
+            t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
+        )
+        maybe_stride = ""
+        if _stride_or_default(None, shape=t.shape) != t.stride():
+            maybe_stride = f", {tuple(t.stride())}"
+        maybe_dtype = ""
+        if _dtype_or_default(None) != t.dtype:
+            maybe_dtype = f", dtype={t.dtype!r}"
+        maybe_storage_offset = ""
+        if _storage_offset_or_default(None) != t.storage_offset():
+            maybe_storage_offset = f", storage_offset={t.storage_offset()!r}"
+        maybe_tensor_metadata = ""
+        tensor_metadata = torch._utils.get_tensor_metadata(t)
+        if tensor_metadata:
+            maybe_tensor_metadata = ", " + ", ".join(
+                f"{k}={v!r}" for k, v in tensor_metadata.items()
+            )
+        v = f"t{next(self.tensor_counter)}"
+        self.lines.append(
+            f"{v} = reader.tensor({storage}, {tuple(t.shape)}"
+            f"{maybe_stride}{maybe_storage_offset}{maybe_dtype}{maybe_tensor_metadata})"
+        )
+        return v
+
+    # TODO: this doesn't actually produce a SymInt atm, only its integer hint
+    def symint(self, val) -> str:
+        if isinstance(val, torch.SymInt):
+            val = val.node.hint
+        v = f"s{next(self.symint_counter)}"
+        self.lines.append(f"{v} = reader.symint({val!r})")
+        return v
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #                           DUMP REPROS
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
@@ -183,7 +380,7 @@
 }
 
 
-def generate_compiler_repro_string(gm, args, *, stable_output=False):
+def generate_compiler_repro_string(gm, args, *, stable_output=False, save_dir=None):
     model_str = textwrap.dedent(
         f"""
 import torch
@@ -210,26 +407,24 @@
 
     model_str += NNModuleToString.convert(gm)
 
-    model_str += "args = []\n"
-
     # get hint shape/stride when dynamic shape enabled
     def hint_if_symint(x):
         return tuple(i.node.hint if isinstance(i, torch.SymInt) else i for i in x)
 
-    for arg in args:
-        if isinstance(arg, int):
-            model_str += f"args.append({arg})\n"
-        elif isinstance(arg, torch.SymInt):
-            model_str += f"args.append({arg.node.hint})  # {arg}\n"
+    writer = InputWriter(save_dir)
+    wargs = []
+    for i, arg in enumerate(args):
+        if isinstance(arg, (int, torch.SymInt)):
+            wargs.append(writer.symint(arg))
         elif isinstance(arg, torch.Tensor):
-            model_str += (
-                "args.append(rand_strided"
-                + f"{hint_if_symint(arg.shape), hint_if_symint(arg.stride()), arg.dtype, arg.device.type})"
-                + f"  # shape {tuple(arg.shape)}, stride {arg.stride()}\n"
-            )
+            # TODO: improve these names with FQN
+            wargs.append(writer.tensor(arg))
         else:
             raise TypeError(f"arg is neither SymInt/int nor torch.Tensor, {arg}")
 
+    model_str += "\n".join(writer.lines) + "\n"
+    model_str += f"args = [{', '.join(wargs)}]\n"
+
     # TODO: fake may be better for performance here
     tracing_mode = "real"
     if config.dynamic_shapes:
@@ -238,7 +433,9 @@
     return model_str
 
 
-def save_graph_repro(fd, gm, args, compiler_name, *, stable_output=False):
+def save_graph_repro(
+    fd, gm, args, compiler_name, *, stable_output=False, save_dir=None
+):
     sync_line = ""
     for arg in args:
         if isinstance(arg, torch.Tensor) and arg.is_cuda:
@@ -247,7 +444,11 @@
 
     if "inductor" in compiler_name:
         fd.write("import torch._inductor.overrides\n")
-    fd.write(generate_compiler_repro_string(gm, args, stable_output=stable_output))
+    fd.write(
+        generate_compiler_repro_string(
+            gm, args, stable_output=stable_output, save_dir=save_dir
+        )
+    )
     fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
     if "_accuracy" in compiler_name:
         fd.write(
@@ -282,7 +483,7 @@
         "Writing checkpoint with %s nodes to %s", len(gm.graph.nodes), file_name
     )
     with open(file_name, "w") as fd:
-        save_graph_repro(fd, gm, args, compiler_name)
+        save_graph_repro(fd, gm, args, compiler_name, save_dir=subdir)
     curdir = os.getcwd()
     repro_path = os.path.join(curdir, "repro.py")
     try:
@@ -303,11 +504,16 @@
 def dump_to_minify(gm, args, compiler_name: str):
     favored_device = 1 if torch.cuda.device_count() >= 2 else 0
 
+    # TODO: factor this out
+    subdir = os.path.join(minifier_dir(), "checkpoints")
+    if not os.path.exists(subdir):
+        os.makedirs(subdir, exist_ok=True)
+
     contents = textwrap.dedent(
         f"""
 isolate_fails_code_str = None
 
-{generate_compiler_repro_string(gm, args)}
+{generate_compiler_repro_string(gm, args, save_dir=subdir)}
 
 from functools import partial
 from torch._dynamo.repro.after_aot import (
@@ -321,7 +527,10 @@
 minifier(
     mod,
     args,
-    module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str),
+    module_fails=partial(
+        isolate_fails, env=env_variables, compiler_name="{compiler_name}",
+        patch_code=isolate_fails_code_str, save_dir={subdir!r}
+    ),
     dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"),
 )
         """
@@ -329,7 +538,9 @@
     return helper_for_dump_minify(contents)
 
 
-def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None):
+def isolate_fails(
+    fx_g, args, compiler_name: str, env=None, patch_code=None, save_dir=None
+):
     if env is None:
         env = {}
     subdir = os.path.join(os.getcwd(), "isolate")
@@ -337,7 +548,7 @@
         os.makedirs(subdir, exist_ok=True)
     file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
     with open(file_name, "w") as fd:
-        repro_code = generate_compiler_repro_string(fx_g, args)
+        repro_code = generate_compiler_repro_string(fx_g, args, save_dir=save_dir)
         if patch_code is not None:
             repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code)
         fd.write(repro_code)
diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py
index 9c7d37b..28fd442 100644
--- a/torch/_dynamo/test_minifier_common.py
+++ b/torch/_dynamo/test_minifier_common.py
@@ -79,6 +79,8 @@
             capture_output=True,
             cwd=repro_dir,
         )
+        print("minifier stdout:", launch_proc.stdout.decode("utf-8"))
+        print("minifier stderr:", launch_proc.stderr.decode("utf-8"))
 
         return launch_proc, launch_code
 
diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp
index 24b6737..3da67fc 100644
--- a/torch/csrc/Dtype.cpp
+++ b/torch/csrc/Dtype.cpp
@@ -3,6 +3,7 @@
 #include <structmember.h>
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/utils/object_ptr.h>
+#include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/tensor_dtypes.h>
 #include <torch/csrc/utils/tensor_types.h>
@@ -11,6 +12,7 @@
 #include <torch/csrc/Exceptions.h>
 
 PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) {
+  HANDLE_TH_ERRORS
   AT_ASSERT(name.length() < DTYPE_NAME_LEN);
   auto type = (PyTypeObject*)&THPDtypeType;
   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
@@ -20,22 +22,33 @@
   self_->scalar_type = scalar_type;
   std::strncpy(self_->name, name.c_str(), DTYPE_NAME_LEN);
   return self.release();
+  END_HANDLE_TH_ERRORS
 }
 
 PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
   if (at::isFloatingType(self->scalar_type)) {
     Py_RETURN_TRUE;
   } else {
     Py_RETURN_FALSE;
   }
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject* THPDtype_itemsize(THPDtype* self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  return THPUtils_packInt64(scalarTypeToTypeMeta(self->scalar_type).itemsize());
+  END_HANDLE_TH_ERRORS
 }
 
 PyObject* THPDtype_is_complex(THPDtype* self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
   if (at::isComplexType(self->scalar_type)) {
     Py_RETURN_TRUE;
   } else {
     Py_RETURN_FALSE;
   }
+  END_HANDLE_TH_ERRORS
 }
 
 PyObject* THPDtype_is_signed(THPDtype* self, PyObject* noargs) {
@@ -49,12 +62,14 @@
 }
 
 PyObject* THPDtype_reduce(PyObject* _self, PyObject* noargs) {
+  HANDLE_TH_ERRORS
   /*
    * For singletons, a string is returned. The string should be interpreted
    * as the name of a global variable.
    */
   auto self = (THPDtype*)_self;
   return THPUtils_packString(self->name);
+  END_HANDLE_TH_ERRORS
 }
 
 typedef PyObject* (*getter)(PyObject*, void*);
@@ -68,6 +83,7 @@
      nullptr},
     {"is_complex", (getter)THPDtype_is_complex, nullptr, nullptr, nullptr},
     {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
+    {"itemsize", (getter)THPDtype_itemsize, nullptr, nullptr, nullptr},
     {nullptr}};
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
diff --git a/torch/utils/_content_store.py b/torch/utils/_content_store.py
index 12463cc..fbc6285 100644
--- a/torch/utils/_content_store.py
+++ b/torch/utils/_content_store.py
@@ -137,12 +137,13 @@
     #       0000..00
     #   tensors/
     #     name
-    def __init__(self, loc: str) -> None:
+    def __init__(self, loc: str, stable_hash: bool = False) -> None:
         self.loc: str = loc
         self.seen_storage_hashes: Set[str] = set()
+        self.stable_hash = stable_hash
 
     def write_storage(self, storage: torch.UntypedStorage) -> str:
-        h = hash_storage(storage)
+        h = hash_storage(storage, stable_hash=self.stable_hash)
         if h in self.seen_storage_hashes:
             return h
         # TODO: consider not using torch.save for this; we don't actually