import copy
import functools
import getpass
import itertools
import logging
import os
import subprocess
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Callable, Optional, Sequence, TypeVar

import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils

from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter

from . import config
from .utils import clone_inputs, get_debug_dir

log = logging.getLogger(__name__)

T = TypeVar("T")


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

if use_buck:
    import libfb.py.build_info  # type: ignore[import]


extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])


BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]


class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f'        "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
        "{cur_target}",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning("Wrote isolation TARGETS file at %s", target_file)
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd_split),
            )
        return cmd_split


def minifier_dir():
    debug_dir = get_debug_dir()
    if debug_dir is not None:
        path = os.path.join(debug_dir, "minifier")
    else:
        # Fall back to a per-user temp directory when no debug dir is configured.
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


MAX_CONSTANT_NUMEL_INLINE = 4


class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning("We have not tested reprs of some modules - %s", cant_convert)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            # Serialize full data for small buffers
            if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
                from torch._tensor_str import PRINT_OPTS

                assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
                tensor_str = repr(buffer)
            elif torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            maybe_device = ""
            if param.is_cuda:
                maybe_device = ', device="cuda"'
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f"    {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str


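# As an illustration only (the layer name and sizes below are made up), for a
# traced module whose single child is a Linear layer, NNModuleToString.convert(gm)
# produces a string along these lines, with the forward body taken verbatim
# from gm.code:
#
#   from torch.nn import *
#   class Repro(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fc = Linear(in_features=3, out_features=2, bias=True)
#
#       def forward(self, x):
#           fc = self.fc(x)
#           return (fc,)

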
@functools.lru_cache(None)  # subprocess is expensive
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
        cuda_version_lines = cuda_version_out.stdout.decode().split("\n")
        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
        model_str += f"{comment}\n"
    except FileNotFoundError:
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str


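# The block emitted by _cuda_system_info_comment looks roughly like the
# following (contents are machine dependent; the versions and GPU name here
# are illustrative):
#
#   # CUDA Info:
#   # nvcc: NVIDIA (R) Cuda compiler driver
#   # Cuda compilation tools, release 12.1, V12.1.105
#   # GPU Hardware Info:
#   # NVIDIA A100-SXM4-40GB : 8

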
def generate_config_string(*, stable_output=False):
    import torch._functorch.config
    import torch._inductor.config

    if stable_output:
        return "# config omitted due to stable_output=True"

    return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
"""


def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")


def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning("Writing minified repro to:\n%s", minified_repro_path)

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)
    except OSError as e:
        log.exception(e)
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e


class AccuracyError(Exception):
    pass


def clone_inputs_retaining_gradness(example_inputs):
    """
    This clone is different from utils.clone_input. In the case of the minifier,
    all the tensors are leaf tensors while creating a new graph. So, we set the
    requires_grad field without checking whether the tensor is a leaf.
    """
    cloned_inputs = clone_inputs(example_inputs)
    for idx in range(len(example_inputs)):
        if isinstance(cloned_inputs[idx], torch.Tensor):
            cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
    return cloned_inputs


def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.

    When disable_clone is True, we will use args as-is without cloning.
    This is higher fidelity but we may destroy the args in the process.
    """
    from torch._functorch.aot_autograd import make_boxed_func

    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    if not disable_clone:
        args = clone_inputs_retaining_gradness(args)

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # The callable returned by TorchInductor expects a list of inputs, so box the call.
    orig_named_parameters = getattr(gm, "named_parameters", None)
    orig_named_buffers = getattr(gm, "named_buffers", None)
    if not hasattr(gm, "_boxed_call") and (
        orig_named_parameters is not None or orig_named_buffers is not None
    ):
        gm = make_boxed_func(gm)
        if orig_named_parameters is not None:
            gm.named_parameters = orig_named_parameters
        if orig_named_buffers is not None:
            gm.named_buffers = orig_named_buffers

    out = gm(args)
    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)


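# A minimal usage sketch (compiled_gm here is illustrative, e.g. the output of
# some backend compiler): run the eager and the compiled graph on the same
# inputs and compare the collected results, which is essentially what
# same_two_models below does.
#
#   ref = run_fwd_maybe_bwd(gm, example_inputs)
#   res = run_fwd_maybe_bwd(compiled_gm, example_inputs)

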
def same_two_models(
    gm,
    opt_gm,
    example_inputs,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    """
    Check whether two models have the same accuracy.

    require_fp64: if True, raise an error if we are unable to calculate the fp64 reference
    ignore_non_fp: if True, do not compare outputs which are not floating point.  This
        is mostly useful for the minifier (which wants to avoid quantizing floating point
        error into integer/boolean error)
    """
    from .eval_frame import OptimizedModule
    from .testing import (
        named_buffers_for_optimized_module,
        named_parameters_for_optimized_module,
    )
    from .utils import same

    if isinstance(gm, OptimizedModule):
        gm.named_parameters = named_parameters_for_optimized_module(gm)
        gm.named_buffers = named_buffers_for_optimized_module(gm)

    if isinstance(opt_gm, OptimizedModule):
        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
        opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    fp64_ref = None
    if config.same_two_models_use_fp64:
        try:
            fp64_model, fp64_examples = cast_to_fp64(
                copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
            )
            fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
        except Exception:
            if require_fp64:
                raise RuntimeError("Could not generate fp64 outputs")
            log.warning("Could not generate fp64 outputs")

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
        return True

    passing = same(
        ref,
        res,
        fp64_ref,
        tol=config.repro_tolerance,
        equal_nan=True,
        ignore_non_fp=ignore_non_fp,
    )
    return passing


def cast_convert_element_type_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
    model.graph.lint()
    model.recompile()
    return model


def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # take care of convert_element_type explicitly
        model = cast_convert_element_type_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)


def backend_accuracy_fails(
    gm,
    example_inputs,
    compiler_fn,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    try:
        compiled_gm = compiler_fn(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
        return not same_two_models(
            gm,
            compiled_gm,
            example_inputs,
            only_fwd,
            require_fp64=require_fp64,
            ignore_non_fp=ignore_non_fp,
        )
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph"
        )
        return False


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                       REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Helper functions for computing what the default values of tensor
# values should be.  These all coincide with factory functions, e.g., torch.empty


def _stride_or_default(
    stride: Optional[Sequence[int]], *, shape: Sequence[int]
) -> Sequence[int]:
    return stride if stride is not None else utils.make_contiguous_strides_for(shape)


def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
    return lambda x: x if x is not None else d


_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
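# For example, _dtype_or_default(None) is torch.float32, _device_or_default(None)
# is torch.device("cpu"), and _stride_or_default(None, shape=(2, 3)) gives the
# contiguous strides (3, 1); passing an explicit value returns it unchanged.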


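# A no-op stand-in for InputReader: it loads nothing and only counts the
# storage() calls, so callers can cheaply pre-compute how much work a
# load_args function will do (e.g. to size a progress bar) before doing a
# real load with InputReader.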
class NopInputReader:
    def __init__(self):
        self.total = 0

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        self.total += 1

    def tensor(self, *args, **kwargs):
        pass

    def symint(self, *args, **kwargs):
        pass


# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    def __init__(self, save_dir=None, *, pbar=None):
        # If None, we will generate random data instead.  It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        self.args = []
        self.pbar = pbar

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device?  But failing this
                    # way would be very mysterious!  Would have been better
                    # not to store device in the serialized format...
                return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage,
        shape,
        stride=None,
        *,
        storage_offset=None,
        dtype=None,
        requires_grad=None,
        is_leaf=None,
        **metadata,
    ):
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val):
        self.args.append(val)
        return val  # for BC


# Here is our writer strategy:
#  1. We will stream all of the inputs to disk
#  2. You can now deterministically randomize the inputs, or reload
#     the inputs from disk
#  3. You can YOLO run the script without the inputs, in which case
#     we'll fill the inputs with random data and pray.  This is the
#     legacy behavior, but it's also useful if you want to find out
#     if we're so broken even random inputs trigger it
#  4. We could offer an in process "check if the randomized thing
#     works too" but this is delicate so we don't do it


class InputWriter:
    def __init__(self, save_dir, *, stable_hash=False):
        self._lines = []
        # TODO: consider ensuring tensor and storage counters line up?
        self.storage_counter = itertools.count()
        self.save_dir = save_dir
        self.store = (
            ContentStoreWriter(save_dir, stable_hash=stable_hash)
            if save_dir is not None
            else None
        )
        self.seen_storages = {}

    def lines(self):
        r = [
            "def load_args(reader):",
        ]
        r.extend(f"    {l}" for l in self._lines)
        # In case we need to change the internal format of load_args
        # in an FC-breaking way
        r.append("load_args._version = 0")
        return r

    # Storages are untyped, but we need to initialize them with data if
    # we don't have the real data, so we give a hint saying what kind
    # of initialization may be appropriate
    #
    # If we had a FakeTensor, device_hint tells us what device should be
    def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
        ws = StorageWeakRef(untyped_storage)
        v = self.seen_storages.get(ws)
        if v is not None:
            return v
        v = f"buf{next(self.storage_counter)}"
        maybe_dtype_hint = ""
        if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
            maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
        # TODO: being optional on device is kind of pointless as the default
        # is CPU but most repros we care about are CUDA
        maybe_device = ""
        device = untyped_storage.device
        if device.type == "meta":
            assert device_hint is not None
            device = device_hint
        if _device_or_default(None) != device:
            maybe_device = f", device={device!r}"
        nbytes = untyped_storage.nbytes()
        storage_hash = None
        if self.store is not None and untyped_storage.device.type != "meta":
            storage_hash = self.store.write_storage(untyped_storage)
        self._lines.append(
            f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
        )
        self.seen_storages[ws] = v
        return v

    def tensor(self, name, t) -> None:
        storage = self.storage(
            t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
        )
        args = []
        # NB: this is positional, must come first
        if _stride_or_default(None, shape=t.shape) != t.stride():
            args.append(str(tuple(t.stride())))
        if _dtype_or_default(None) != t.dtype:
            args.append(f"dtype={t.dtype!r}")
        if _storage_offset_or_default(None) != t.storage_offset():
            args.append(f"storage_offset={t.storage_offset()!r}")
        tensor_metadata = torch._utils.get_tensor_metadata(t)
        if tensor_metadata:
            args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
        if _requires_grad_or_default(None) != t.requires_grad:
            args.append(f"requires_grad={t.requires_grad!r}")
        is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
        if _is_leaf_or_default(None) != is_leaf:
            args.append(f"is_leaf={is_leaf!r}")
        self._lines.append(
            "reader.tensor("
            + ", ".join([storage, str(tuple(t.shape)), *args])
            + f")  # {name}"
        )

    # TODO: this doesn't actually symint atm
    def symint(self, name, val) -> None:
        if isinstance(val, torch.SymInt):
            val = val.node.hint
        self._lines.append(f"reader.symint({val!r})  # {name}")