| # mypy: disable-error-code="method-assign" |
| |
| import copy |
| import functools |
| import getpass |
| import itertools |
| import logging |
| import os |
| import subprocess |
| import tempfile |
| import textwrap |
| from collections import Counter |
| from importlib import import_module |
| from typing import Callable, Optional, TypeVar |
| |
| import torch |
| import torch._prims_common as utils |
| import torch._subclasses.meta_utils |
| |
| from torch._dynamo.testing import rand_strided |
| from torch._prims_common import is_float_dtype |
| from torch.multiprocessing.reductions import StorageWeakRef |
| from torch.utils._content_store import ContentStoreReader, ContentStoreWriter |
| |
| from . import config |
| from .utils import clone_inputs, get_debug_dir |
| |
| log = logging.getLogger(__name__) |
| |
| T = TypeVar("T") |
| |
| |
| inductor_config = import_module("torch._inductor.config") |
| use_buck = inductor_config.is_fbcode() |
| |
| if use_buck: |
| import libfb.py.build_info |
| |
| |
| extra_deps = [] |
| extra_imports = "" |
| if use_buck: |
| extra_deps = [ |
| "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu", |
| "//caffe2/torch/fb/sparsenn:sparsenn_operators", |
| "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu", |
| "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops", |
| ] |
| cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") |
| extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps]) |
| |
| |
| BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"] |
| |
| |
| class BuckTargetWriter: |
| def __init__(self, filename): |
| self.subdir, self.py_file = os.path.split(os.path.abspath(filename)) |
| self.target = self.py_file.replace(".py", "") |
| |
| # Get main_module path from fbcode |
| self.path = f'{self.subdir.replace("/", ".")}.{self.target}' |
| self.path = self.path[self.path.find("fbcode.") :] |
| self.path = self.path[7:] |
| |
| # Get cmd line path |
| tmp = self.subdir |
| tmp = tmp[tmp.find("fbcode/") :][7:] |
| self.cmd_line_path = f"//{tmp}:{self.target}" |
| |
| def build(self): |
| extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps]) |
| return textwrap.dedent( |
| f""" |
| load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary") |
| |
| python_binary( |
| name="{self.target}", |
| srcs = ["{self.py_file}"], |
| compile = False, |
| deps = [ |
| "//caffe2:torch", |
| "//caffe2/functorch:functorch", |
| "//triton:triton", |
| "{cur_target}", |
| ], |
| cpp_deps = [ |
| {extra_cpp_deps} |
| ], |
| main_module = "{self.path}", |
| ) |
| """ |
| ) |
| |
| def write(self, print_msg=True): |
| target_file = os.path.join(self.subdir, "TARGETS") |
| with open(target_file, "w") as fd: |
| fd.write(self.build()) |
| # log.warning("Wrote isolation TARGETS file at %s", target_file) |
| cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path] |
| if print_msg: |
| log.warning( |
| "Found an example that reproduces the error. Run this cmd to repro - %s", |
| " ".join(cmd_split), |
| ) |
| return cmd_split |
| |
| |
def minifier_dir():
    # NB: os.path.join never returns None, so check get_debug_dir() itself
    # before joining, falling back to a per-user temp directory
    debug_dir = get_debug_dir()
    if debug_dir is not None:
        path = os.path.join(debug_dir, "minifier")
    else:
        path = os.path.join(tempfile.gettempdir(), f"minifier_{getpass.getuser()}")
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path
| |
| |
| MAX_CONSTANT_NUMEL_INLINE = 4 |
| |
| |
| class NNModuleToString: |
| safe_reprs = [ |
| torch.nn.Linear, |
| torch.nn.Conv1d, |
| torch.nn.Conv2d, |
| torch.nn.Conv3d, |
| torch.nn.BatchNorm1d, |
| torch.nn.BatchNorm2d, |
| torch.nn.BatchNorm3d, |
| torch.nn.LayerNorm, |
| torch.nn.Dropout, |
| torch.nn.Softmax, |
| torch.nn.ReLU, |
| torch.nn.GELU, |
| torch.nn.Identity, |
| torch.nn.MaxPool2d, |
| torch.nn.Embedding, |
| torch.nn.Tanh, |
| torch.nn.ConvTranspose1d, |
| torch.nn.GLU, |
| torch.nn.LSTM, |
| torch.nn.Flatten, |
| torch.nn.AdaptiveAvgPool2d, |
| ] |
| |
| @staticmethod |
| def can_convert_to_string(gm): |
| cant_convert = set() |
| for _, module in gm.named_children(): |
| if type(module) not in NNModuleToString.safe_reprs: |
| cant_convert.add(module) |
| |
| if len(cant_convert) > 0: |
| log.warning("We have not tested reprs of some modules - %s", cant_convert) |
| # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct. |
| return True |
| |
| @staticmethod |
| def convert(gm): |
| from torch.nn.modules.module import _addindent |
| |
| tab = " " * 4 |
| |
| model_str = textwrap.dedent( |
| """ |
| from torch.nn import * |
| class Repro(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| """ |
| ) |
| |
| for module_name, module in gm.named_children(): |
| module_str = f"{module.__repr__()}" |
| # module should be a core torch.nn.Module, so all parameters |
| # should be on the same device. |
| example_param = next(module.parameters(), None) |
| if example_param is not None and example_param.is_cuda: |
| module_str = f"{module_str}.cuda()" |
| model_str += f"{tab*2}self.{module_name} = {module_str}\n" |
| |
| for buffer_name, buffer in gm._buffers.items(): |
| if buffer is None: |
| continue |
| # Serialize full data for small buffers |
| if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE: |
| from torch._tensor_str import PRINT_OPTS |
| |
| assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE |
| tensor_str = repr(buffer) |
| elif torch.is_floating_point(buffer): |
| tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})" |
| else: |
| tensor_str = ( |
| f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" |
| ) |
| if buffer.is_cuda: |
| tensor_str = f"{tensor_str}.cuda()" |
| model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" |
| |
| for param_name, param in gm._parameters.items(): |
| if param is None: |
| continue |
| maybe_device = "" |
| if param.is_cuda: |
| maybe_device = ', device="cuda"' |
| tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))" |
| model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" |
| |
| # TODO - Keep this code for now. But, I don't think we will need this. |
| # attrs = dir(gm) |
| # for attr in attrs: |
| # if "_tensor_constant" in attr: |
| # val = getattr(gm, attr) |
| # model_str += f" {attr} = {val!r}\n" |
| |
| model_str += f"{_addindent(gm.code, 4)}\n" |
| return model_str |
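
# For reference, convert() emits source along these lines (illustrative
# sketch; the submodule name and its arguments are made up):
#
#   from torch.nn import *
#   class Repro(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fc = Linear(in_features=4, out_features=4, bias=True)
#
# followed by gm.code, i.e. the graph module's generated forward().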
| |
| |
| @functools.lru_cache(None) # subprocess is expensive |
| def _cuda_system_info_comment(): |
| if not torch.cuda.is_available(): |
| return "# torch.cuda.is_available()==False, no GPU info collected\n" |
| |
| model_str = "# CUDA Info: \n" |
| try: |
| cuda_version_out = subprocess.check_output(["nvcc", "--version"]) |
| cuda_version_lines = cuda_version_out.decode().split("\n") |
        comment = "".join(f"# {s} \n" for s in cuda_version_lines if s)
| model_str += f"{comment}\n" |
| except FileNotFoundError: |
| model_str += "# nvcc not found\n" |
| |
| gpu_names = Counter( |
| torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count()) |
| ) |
| |
| model_str += "# GPU Hardware Info: \n" |
| for name, count in gpu_names.items(): |
| model_str += f"# {name} : {count} \n" |
| model_str += "\n" |
| return model_str |
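
# The emitted comment block looks roughly like this (illustrative; the actual
# compiler version and device names depend on the machine):
#
#   # CUDA Info:
#   # nvcc: NVIDIA (R) Cuda compiler driver
#   # Cuda compilation tools, release 12.1, V12.1.105
#   # GPU Hardware Info:
#   # NVIDIA A100-SXM4-40GB : 8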
| |
| |
| def generate_config_string(*, stable_output=False): |
    import torch._functorch.config
    import torch._inductor.config
    import torch.fx.experimental._config
| |
| if stable_output: |
| return "# config omitted due to stable_output=True" |
| |
| return f"""\ |
| import torch._dynamo.config |
| import torch._inductor.config |
| import torch._functorch.config |
| import torch.fx.experimental._config |
| {torch._dynamo.config.codegen_config()} |
| {torch._inductor.config.codegen_config()} |
| {torch._functorch.config.codegen_config()} |
| {torch.fx.experimental._config.codegen_config()} |
| """ |
| |
| |
| def get_minifier_repro_path(): |
| return os.path.join(minifier_dir(), "minifier_launcher.py") |
| |
| |
| def helper_for_dump_minify(contents): |
| minified_repro_path = get_minifier_repro_path() |
| log.warning("Writing minified repro to:\n%s", minified_repro_path) |
| |
| if use_buck: |
| BuckTargetWriter(minified_repro_path).write() |
| try: |
| with open(minified_repro_path, "w") as fd: |
| fd.write(contents) |
| |
| except OSError as e: |
| log.exception(e) |
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
| |
| |
| class AccuracyError(Exception): |
| pass |
| |
| |
| def clone_inputs_retaining_gradness(example_inputs): |
| """ |
| This clone inputs is different from utils clone_input. In case of minifier, |
| all the tensors are leaf tensors while creating a new graph. So, we set the |
| requires_grad field w/o checking the leafness of the tensor. |
| """ |
| cloned_inputs = clone_inputs(example_inputs) |
| for idx in range(len(example_inputs)): |
| if isinstance(cloned_inputs[idx], torch.Tensor): |
| cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad) |
| return cloned_inputs |
| |
| |
| def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False): |
| """ |
| Runs a forward and possibly backward iteration for a given mod and args. |
| |
| When disable_clone is True, we will use args as-is without cloning. |
| This is higher fidelity but we may destroy the args in the process. |
| """ |
| from torch._functorch.aot_autograd import make_boxed_func |
| |
| from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass |
| |
| gm = copy.deepcopy(gm) |
| if not disable_clone: |
| args = clone_inputs_retaining_gradness(args) |
| |
| if hasattr(gm, "zero_grad"): |
| gm.zero_grad(True) |
| |
    # The callable returned by TorchInductor expects a list of inputs, so box the call.
| orig_named_parameters = getattr(gm, "named_parameters", None) |
| orig_named_buffers = getattr(gm, "named_buffers", None) |
| if not hasattr(gm, "_boxed_call") and ( |
| orig_named_parameters is not None or orig_named_buffers is not None |
| ): |
| gm = make_boxed_func(gm) |
| if orig_named_parameters is not None: |
| gm.named_parameters = orig_named_parameters |
| if orig_named_buffers is not None: |
| gm.named_buffers = orig_named_buffers |
| |
| out = gm(args) |
| if only_fwd: |
| return out |
| if requires_bwd_pass(out): |
| loss = reduce_to_scalar_loss(out) |
| loss.backward() |
| return collect_results(gm, out, None, args) |
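
# Typical usage (illustrative sketch, mirroring what same_two_models below does):
#   ref = run_fwd_maybe_bwd(gm, example_inputs)
#   res = run_fwd_maybe_bwd(opt_gm, example_inputs)
# after which the two result pytrees are compared numerically.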
| |
| |
| def same_two_models( |
| gm, |
| opt_gm, |
| example_inputs, |
| only_fwd=False, |
| *, |
| require_fp64=False, |
| ignore_non_fp=False, |
| ): |
| """ |
| Check two models have same accuracy. |
| |
| require_fp64: if True, raise an error if we unable to calculate the fp64 reference |
| ignore_non_fp: if True, do not compare outputs which are not floating point. This |
| is mostly useful for the minifier (which wants to avoid quantizing floating point |
| error into integer/boolean error) |
| """ |
| from .eval_frame import OptimizedModule |
| from .testing import ( |
| named_buffers_for_optimized_module, |
| named_parameters_for_optimized_module, |
| ) |
| from .utils import same |
| |
| if isinstance(gm, OptimizedModule): |
| gm.named_parameters = named_parameters_for_optimized_module(gm) |
| gm.named_buffers = named_buffers_for_optimized_module(gm) |
| |
| if isinstance(opt_gm, OptimizedModule): |
| opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) |
| opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm) |
| |
| ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) |
| |
| fp64_ref = None |
| if config.same_two_models_use_fp64: |
| try: |
| fp64_model, fp64_examples = cast_to_fp64( |
| copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) |
| ) |
| fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd) |
| except Exception: |
| if require_fp64: |
| raise RuntimeError("Could not generate fp64 outputs") # noqa: TRY200 |
| log.warning("Could not generate fp64 outputs") |
| |
| try: |
| res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd) |
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
| log.exception( |
| "While minifying the program in accuracy minification mode, " |
| "ran into a runtime exception which is likely an unrelated issue." |
| " Skipping this graph." |
| ) |
| return True |
| |
| passing = same( |
| ref, |
| res, |
| fp64_ref, |
| tol=config.repro_tolerance, |
| equal_nan=True, |
| ignore_non_fp=ignore_non_fp, |
| ) |
| return passing |
| |
| |
| def cast_dtype_args_to_fp64(model): |
| for node in model.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.prims.convert_element_type.default |
| ): |
| assert len(node.args) == 2 |
| if is_float_dtype(node.args[1]) and node.args[1] != torch.float64: |
| node.args = (node.args[0], torch.float64) |
| if node.op == "call_function": |
| dtype = node.kwargs.get("dtype") |
| if dtype is not None and is_float_dtype(dtype): |
| new_kwargs = dict(node.kwargs) |
| new_kwargs["dtype"] = torch.float64 |
| node.kwargs = new_kwargs |
| |
| model.graph.lint() |
| model.recompile() |
| return model |
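
# For example (sketch), a graph node such as
#   prims.convert_element_type(x, torch.float16)
# is rewritten to use torch.float64, while non-float dtypes are left untouched.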
| |
| |
| def cast_to(dtype, model, inputs): |
| from torch.utils._pytree import tree_map |
| |
| model = model.to(dtype) |
| if dtype == torch.float64: |
| # If casting to fp64 for accuracy comparison, we need to |
| # replace dtype arguments embedded in the graph with fp64 |
| model = cast_dtype_args_to_fp64(model) |
| |
| inputs = tree_map( |
| lambda x: x.to(dtype) |
| if isinstance(x, torch.Tensor) and x.is_floating_point() |
| else x, |
| inputs, |
| ) |
| return model, inputs |
| |
| |
| def cast_to_fp64(model, inputs): |
| return cast_to(torch.float64, model, inputs) |
| |
| |
| def backend_accuracy_fails( |
| gm, |
| example_inputs, |
| compiler_fn, |
| only_fwd=False, |
| *, |
| require_fp64=False, |
| ignore_non_fp=False, |
| ): |
| try: |
| compiled_gm = compiler_fn( |
| copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs) |
| ) |
| return not same_two_models( |
| gm, |
| compiled_gm, |
| example_inputs, |
| only_fwd, |
| require_fp64=require_fp64, |
| ignore_non_fp=ignore_non_fp, |
| ) |
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
| log.exception( |
| "While minifying the program in accuracy minification mode, " |
| "ran into a runtime exception which is likely an unrelated issue." |
| " Skipping this graph" |
| ) |
| return False |
| |
| |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # |
| # REPRO SUPPORT CODE |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # |
| |
| |
| # Helper functions for computing what the default values of tensor |
| # values should be. These all coincide with factory functions, e.g., torch.empty |
| |
| |
| def _stride_or_default( |
| stride: Optional["torch._prims_common.StrideType"], |
| *, |
| shape: "torch._prims_common.ShapeType", |
| ) -> "torch._prims_common.StrideType": |
| return stride if stride is not None else utils.make_contiguous_strides_for(shape) |
| |
| |
| def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]: |
| return lambda x: x if x is not None else d |
| |
| |
| _dtype_or_default = _mk_defaulter(torch.float32) |
| _device_or_default = _mk_defaulter(torch.device("cpu")) |
| _storage_offset_or_default = _mk_defaulter(0) |
| _requires_grad_or_default = _mk_defaulter(False) |
| _is_leaf_or_default = _mk_defaulter(False) |
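
# Illustrative behavior of the defaulters above (example values assumed):
#   _stride_or_default(None, shape=(2, 3)) -> (3, 1)  (contiguous strides)
#   _dtype_or_default(None)                -> torch.float32
#   _dtype_or_default(torch.int64)         -> torch.int64
#   _requires_grad_or_default(None)        -> False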
| |
| |
| class NopInputReader: |
| def __init__(self): |
| self.total = 0 |
| |
| def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): |
| self.total += 1 |
| |
| def tensor(self, *args, **kwargs): |
| pass |
| |
| def symint(self, *args, **kwargs): |
| pass |
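
# NopInputReader runs through load_args without materializing anything; repro
# scripts can use it to count storages up front, e.g. to size a progress bar
# (illustrative sketch; the tqdm usage is assumed):
#
#   nop_reader = NopInputReader()
#   load_args(nop_reader)
#   with tqdm(desc="Loading inputs", total=nop_reader.total) as pbar:
#       input_reader = InputReader(save_dir, pbar=pbar)
#       load_args(input_reader)
#       args = input_reader.args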
| |
| |
| # TODO: Support bundling the entire repro into a zip file for ease of |
| # transferring around |
| class InputReader: |
| def __init__(self, save_dir=None, *, pbar=None): |
| # If None, we will generate random data instead. It's important |
| # to natively support this use case as it will allow people to |
| # share repros without including the real data, if the problem |
| # reproduces even on random data. |
| if save_dir is None: |
| log.warning("no save_dir specified, will generate random data") |
| self.store = ContentStoreReader(save_dir) if save_dir is not None else None |
| self.args = [] |
| self.pbar = pbar |
| |
| def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None): |
| if self.pbar is not None: |
| self.pbar.update(1) |
| device = _device_or_default(device) |
| dtype_hint = _dtype_or_default(dtype_hint) |
| if self.store is not None and storage_hash is not None: |
| try: |
| storage = self.store.read_storage(storage_hash) |
| except FileNotFoundError: |
| pass |
| else: |
| if device != storage.device: |
| log.warning("device mismatch: %s != %s", device, storage.device) |
| # TODO: transfer it to the right device? But failing this |
| # way would be very mysterious! Would have been better |
| # not to store device in the serialized format... |
| return storage |
| log.warning("could not load %s, generating random data instead", storage_hash) |
| shape = (nbytes // dtype_hint.itemsize,) |
| stride = _stride_or_default(None, shape=shape) |
| return rand_strided(shape, stride, dtype_hint, device).untyped_storage() |
| |
| def tensor( |
| self, |
| storage, |
| shape, |
| stride=None, |
| *, |
| storage_offset=None, |
| dtype=None, |
| requires_grad=None, |
| is_leaf=None, |
| **metadata, |
| ): |
| stride = _stride_or_default(stride, shape=shape) |
| storage_offset = _storage_offset_or_default(storage_offset) |
| dtype = _dtype_or_default(dtype) |
| is_leaf = _is_leaf_or_default(is_leaf) |
| requires_grad = _requires_grad_or_default(requires_grad) |
| t = torch.tensor( |
| [], dtype=dtype, device=storage.device, requires_grad=requires_grad |
| ) |
| with torch.no_grad(): |
| t.set_(storage, storage_offset, shape, stride) |
| if not is_leaf: |
| # Fake up some autograd history in a very naughty way |
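            # (cloning under enable_grad records a grad_fn, so t becomes a
            # non-leaf; set_ then re-points it at the original storage)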
| with torch.enable_grad(): |
| t = t.clone(memory_format=torch.preserve_format) |
| with torch.no_grad(): |
| t.set_(storage, storage_offset, shape, stride) |
| assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf |
| torch._utils.set_tensor_metadata(t, metadata) |
| self.args.append(t) |
| return t # for BC |
| |
| def symint(self, val): |
| self.args.append(val) |
| return val # for BC |
| |
| |
| # Here is our writer strategy: |
| # 1. We will stream all of the inputs to disk |
| # 2. You can now deterministically randomize the inputs, or reload |
| # the inputs from disk |
| # 3. You can YOLO run the script without the inputs, in which case |
| # we'll fill the inputs with random data and pray. This is the |
| # legacy behavior, but it's also useful if you want to find out |
| # if we're so broken even random inputs trigger it |
# 4. We could offer an in-process "check if the randomized thing
| # works too" but this is delicate so we don't do it |
| |
| |
| class InputWriter: |
| def __init__(self, save_dir, *, stable_hash=False): |
| self._lines = [] |
| # TODO: consider ensuring tensor and storage counters line up? |
| self.storage_counter = itertools.count() |
| self.save_dir = save_dir |
| self.store = ( |
| ContentStoreWriter(save_dir, stable_hash=stable_hash) |
| if save_dir is not None |
| else None |
| ) |
| self.seen_storages = {} |
| |
| def lines(self): |
| r = [ |
| "def load_args(reader):", |
| ] |
        r.extend(f"    {line}" for line in self._lines)
| # In case we need to change the internal format of load_args |
| # in an FC-breaking way |
| r.append("load_args._version = 0") |
| return r |
| |
| # Storages are untyped, but we need to initialize them with data if |
| # we don't have the real data, so we give a hint saying what kind |
| # of initialization may be appropriate |
| # |
| # If we had a FakeTensor, device_hint tells us what device should be |
| def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str: |
| ws = StorageWeakRef(untyped_storage) |
| v = self.seen_storages.get(ws) |
| if v is not None: |
| return v |
| v = f"buf{next(self.storage_counter)}" |
| maybe_dtype_hint = "" |
| if _dtype_or_default(None) != _dtype_or_default(dtype_hint): |
| maybe_dtype_hint = f", dtype_hint={dtype_hint!r}" |
| # TODO: being optional on device is kind of pointless as the default |
| # is CPU but most repros we care about are CUDA |
| maybe_device = "" |
| device = untyped_storage.device |
| if device.type == "meta": |
| assert device_hint is not None |
| device = device_hint |
| if _device_or_default(None) != device: |
| maybe_device = f", device={device!r}" |
| nbytes = untyped_storage.nbytes() |
| storage_hash = None |
| if self.store is not None and untyped_storage.device.type != "meta": |
| storage_hash = self.store.write_storage(untyped_storage) |
| self._lines.append( |
| f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})" |
| ) |
| self.seen_storages[ws] = v |
| return v |
| |
| def tensor(self, name, t) -> None: |
| storage = self.storage( |
| t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device |
| ) |
| args = [] |
| # NB: this is positional, must come first |
| if _stride_or_default(None, shape=t.shape) != t.stride(): |
| args.append(str(tuple(t.stride()))) |
| if _dtype_or_default(None) != t.dtype: |
| args.append(f"dtype={t.dtype!r}") |
| if _storage_offset_or_default(None) != t.storage_offset(): |
| args.append(f"storage_offset={t.storage_offset()!r}") |
| tensor_metadata = torch._utils.get_tensor_metadata(t) |
| if tensor_metadata: |
| args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items()) |
| if _requires_grad_or_default(None) != t.requires_grad: |
| args.append(f"requires_grad={t.requires_grad!r}") |
| is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t) |
| if _is_leaf_or_default(None) != is_leaf: |
| args.append(f"is_leaf={is_leaf!r}") |
| self._lines.append( |
| "reader.tensor(" |
| + ", ".join([storage, str(tuple(t.shape)), *args]) |
| + f") # {name}" |
| ) |
| |
    # TODO: this doesn't actually round-trip a SymInt at the moment; only its integer hint is recorded
| def symint(self, name, val) -> None: |
| if isinstance(val, torch.SymInt): |
| val = val.node.hint |
| self._lines.append(f"reader.symint({val!r}) # {name}") |