| # Owner(s): ["module: ProxyTensor"] | 
 |  | 
 | from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests | 
 | import torch | 
 | import unittest | 
 | import warnings | 
 | import operator | 
 | from collections.abc import Iterable | 
 | from torch.testing._internal.common_device_type import instantiate_device_type_tests | 
 | from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed, skip, xfail, skipOps | 
 | from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException | 
 |  | 
 | from torch._decomp import decomposition_table | 
 | from torch.fx.experimental.symbolic_shapes import ( | 
 |     sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets, | 
 |     constrain_range, guard_int, GuardOnDataDependentSymNode | 
 | ) | 
 | from torch.testing._internal.custom_op_db import custom_op_db | 
 | from torch.testing._internal.common_device_type import ops | 
 | from torch._C import _disabled_torch_function_impl | 
 | from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule | 
 | from torch.utils._pytree import tree_map | 
 | from torch import nn | 
 | import re | 
 |  | 
 | import functools | 
 | import itertools | 
 |  | 
 | aten = torch.ops.aten | 
 |  | 
 | HAS_CUDA = torch.cuda.is_available() | 
 |  | 
 |  | 
 | def strip_end(s, suffix): | 
 |     if suffix and s.endswith(suffix): | 
 |         return s[:-len(suffix)] | 
 |     else: | 
 |         return s | 
 |  | 
 |  | 
 | def show_guards(gm): | 
 |     names = [strip_end(n, "_1") for n in fx_placeholder_targets(gm)] | 
 |     return "\n".join( | 
 |         gm.shape_env.produce_guards(fx_placeholder_vals(gm), names, _simplified=True, constraint_inputs=None) | 
 |     ) | 
 |  | 
 |  | 
 | def process_failures(): | 
 |     """ | 
    Takes a file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and processes them into a list of OpInfo xfails
    """
    with open('pytest_failures') as f:
        failures = f.readlines()
    failures = [i.strip() for i in failures]
 |  | 
 |     def process_failure_string(s, matcher): | 
 |         out = re.search(matcher, s) | 
 |         return out.groups() | 
 |  | 
 |     SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)' | 
 |     failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures] | 
 |  | 
 |     def create_normalized_name(op): | 
 |         if op.variant_test_name == '': | 
 |             s = op.name | 
 |         else: | 
 |             s = f"{op.name}.{op.variant_test_name}" | 
 |         return s.replace('.', '_') | 
 |  | 
 |     remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db} | 
 |  | 
 |     print("symbolic_tensor_failures = {") | 
 |     for failure, reason in failures: | 
 |         print(f"    xfail{remap_opinfo[failure]},  # {reason}") | 
 |     print("}") | 
 |  | 
 |  | 
 | USE_TORCHVISION = False | 
 | try: | 
 |     import torchvision | 
 |     USE_TORCHVISION = True | 
 | except ImportError: | 
 |     warnings.warn("Couldn't import torchvision. Some of our tests use it, try " | 
 |                   "to install it with commands from pytorch.org, post-fixed with " | 
 |                   "`--no-deps` to avoid overwriting the pytorch installation", | 
 |                   UserWarning) | 
 |  | 
 |  | 
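# Build a fresh input with the same metadata as `x`, so a traced graph can be re-run
# on new values: non-tensors pass through, non-float tensors are shifted by 1, and
# float tensors are replaced by torch.rand_like (preserving requires_grad on leaves).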
 | def _create_new_input(x): | 
 |     if not isinstance(x, torch.Tensor): | 
 |         return x | 
 |     if x.dtype != torch.float: | 
 |         return x + 1 | 
 |     if x.is_leaf: | 
 |         return torch.rand_like(x, requires_grad=x.requires_grad) | 
 |     else: | 
 |         return torch.rand_like(x) | 
 |  | 
 | """ | 
Wrapper tensor subclass that delays applying ``cos`` to the wrapped tensor until it
is actually used in an op. Simulates how a CommTensor is used.
 | """ | 
 | class UnwrapTensor(torch.Tensor): | 
 |     @staticmethod | 
 |     def __new__(cls, tensor: torch.Tensor): | 
 |         r = torch.Tensor._make_wrapper_subclass( | 
 |             cls, | 
 |             tensor.size(), | 
 |             dtype=tensor.dtype, | 
 |             device=tensor.device, | 
 |             layout=tensor.layout, | 
 |             requires_grad=tensor.requires_grad, | 
 |         ) | 
 |         r._tensor = tensor | 
 |         return r | 
 |  | 
 |     def __repr__(self): | 
 |         # TODO: consider all_gather the local tensors for better debugging | 
 |         return f"UnwrapTensor({self._tensor})" | 
 |  | 
 |     __torch_function__ = _disabled_torch_function_impl | 
 |  | 
 |     @classmethod | 
 |     def __torch_dispatch__(cls, func, types, args=(), kwargs=None): | 
 |         def unwrap(e): | 
 |             ret = e | 
 |             if isinstance(e, UnwrapTensor): | 
 |                 ret = e._tensor.cos() | 
 |  | 
 |             return ret | 
 |  | 
 |         args = tree_map(unwrap, args) | 
 |         kwargs = tree_map(unwrap, kwargs) | 
 |         return func(*args, **kwargs) | 
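
# Illustrative sketch (not executed by the tests): when an UnwrapTensor participates
# in an op, dispatch first replaces it with cos() of the wrapped tensor, e.g.
#
#   t = torch.ones(3)
#   w = UnwrapTensor(t)
#   assert torch.allclose(w * 2, t.cos() * 2)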
 |  | 
 | class TestGenericProxyTensor(TestCase): | 
 |     # WARNING: if any of your inputs are index tensors, DO NOT use this | 
 |     # function | 
 |     def _test(self, f, inps): | 
 |         fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps) | 
 |         new_inps = tree_map(_create_new_input, inps) | 
 |         r1 = fx_f(*new_inps) | 
 |         r2 = f(*new_inps) | 
 |         self.assertEqual(r1, r2) | 
 |  | 
 |     def test_pre_autograd_mode_stack(self): | 
 |         def f(a): | 
 |             b = torch.ones(4, 4) | 
 |             return torch.matmul(a, b) | 
 |         # We expect to see matmul in the trace - it should NOT be decomposed into mm. | 
        # Note that the torch.ones() factory call also shows up in the trace below,
        # even though ones() never dispatches to the Autograd dispatch key - it goes
        # directly to the BackendSelect key.
 |         fx_g = make_fx(f, pre_autograd=True)(torch.ones(4, 4)) | 
 |         self.assertExpectedInline(fx_g.code.strip(), """\ | 
 | def forward(self, a_1): | 
 |     ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False) | 
 |     matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None | 
 |     return matmul""") | 
 |  | 
 |  | 
 |     def test_make_fx_simple(self): | 
 |         def f(x): | 
 |             return torch.sin(x) | 
 |         self._test(f, (torch.randn(3),)) | 
 |  | 
 |     def test_scalar_device(self, device='cpu'): | 
 |         def f(a, b): | 
 |             return a + b | 
 |         self._test(f, [torch.randn(3, device=device), torch.tensor(5)]) | 
 |  | 
 |     def test_isolated_graphmodule(self): | 
 |         def is_any_sum(gm): | 
 |             return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes) | 
 |  | 
 |         def is_any_digamma(gm): | 
 |             return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes) | 
 |  | 
 |         def is_any_sigmoid(gm): | 
 |             return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes) | 
 |  | 
 |         def inner(x): | 
 |             return torch.sum(x) | 
 |  | 
 |         def f(x): | 
 |             gm = get_isolated_graphmodule(inner, (x,), {}) | 
 |             self.assertTrue(is_any_sum(gm)) | 
 |             return x + torch.randn(x.shape) | 
 |  | 
 |         # get_isolated_graphmodule uses make_fx internally that shouldn't be traced | 
 |         # by the outer make_fx call | 
 |         traced = make_fx(f)(torch.randn(3)) | 
 |         self.assertFalse(is_any_sum(traced)) | 
 |  | 
 |         # When factory functions are used, they should not be traced | 
 |         # by the outer make_fx call | 
 |         def inner_with_factory(): | 
 |             val = torch.tensor(float(1)) | 
 |             val.add_(2) | 
 |             return torch.full((10, 10), val).sum() | 
 |  | 
 |         def f1(x): | 
 |             gm = get_isolated_graphmodule(inner_with_factory, (), {}) | 
 |             self.assertTrue(is_any_sum(gm)) | 
 |             return torch.sigmoid(x) | 
 |  | 
 |         def f2(x): | 
 |             gm = get_isolated_graphmodule(f1, (x,), {}) | 
 |             self.assertFalse(is_any_sum(gm)) | 
 |             self.assertTrue(is_any_sigmoid(gm)) | 
 |             return torch.digamma(x) | 
 |  | 
 |         traced = make_fx(f2)(torch.randn(3)) | 
 |         self.assertFalse(is_any_sum(traced)) | 
 |         self.assertFalse(is_any_sigmoid(traced)) | 
 |         self.assertTrue(is_any_digamma(traced)) | 
 |  | 
        # Verify that nested make_fx calls don't cause factory functions to leak
        # into the outer graph. Verify that `make_fx` itself does not leak its execution.
 |         def f2(x): | 
 |             gm = make_fx(f1)(x) | 
 |             self.assertFalse(is_any_sum(gm)) | 
 |             self.assertTrue(is_any_sigmoid(gm)) | 
 |             return torch.digamma(x) | 
 |  | 
 |         traced = make_fx(f2)(torch.randn(3)) | 
 |         self.assertFalse(is_any_sum(traced)) | 
 |         self.assertFalse(is_any_sigmoid(traced)) | 
 |         self.assertTrue(is_any_digamma(traced)) | 
 |  | 
        # Verify that the `forward` function of a graph module produced as a
 |         # side effect of an interior `make_fx` is still traced | 
 |         def f3(x): | 
 |             gm = make_fx(f1)(x) | 
 |             self.assertFalse(is_any_sum(gm)) | 
 |             self.assertTrue(is_any_sigmoid(gm)) | 
            # `gm.forward` is still traced
 |             return torch.digamma(gm(x)) | 
 |  | 
 |         traced = make_fx(f3)(torch.randn(3)) | 
 |         self.assertFalse(is_any_sum(traced)) | 
 |         self.assertTrue(is_any_sigmoid(traced)) | 
 |         self.assertTrue(is_any_digamma(traced)) | 
 |  | 
 |         # Verify interaction with non-ProxyTensor modes | 
 |         from torch.testing._internal.logging_tensor import LoggingTensorMode | 
 |  | 
 |         def f1_logging(x): | 
 |             with LoggingTensorMode(): | 
 |                 gm = get_isolated_graphmodule(inner_with_factory, (), {}) | 
 |             self.assertTrue(is_any_sum(gm)) | 
 |             return torch.sigmoid(x) | 
 |  | 
 |         def f2_logging(x): | 
 |             with LoggingTensorMode(), LoggingTensorMode(): | 
 |                 gm = get_isolated_graphmodule(f1_logging, (x,), {}) | 
 |             self.assertFalse(is_any_sum(gm)) | 
 |             self.assertTrue(is_any_sigmoid(gm)) | 
 |             return torch.digamma(x) | 
 |  | 
 |         traced = make_fx(f2_logging)(torch.randn(3)) | 
 |         self.assertFalse(is_any_sum(traced)) | 
 |         self.assertFalse(is_any_sigmoid(traced)) | 
 |         self.assertTrue(is_any_digamma(traced)) | 
 |  | 
 |         # Verify interaction with another tensor subclass | 
 |         # This case currently doesn't work and should raise an error | 
 |         # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068 | 
 |         from torch.testing._internal.logging_tensor import LoggingTensor | 
 |  | 
 |         def f1_logging_tensor(x): | 
 |             gm = get_isolated_graphmodule(inner_with_factory, (), {}) | 
 |             self.assertTrue(is_any_sum(gm)) | 
 |             return torch.sigmoid(x) | 
 |  | 
 |         def f2_logging_tensor(x): | 
 |             x = LoggingTensor(x) | 
 |             gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {}) | 
 |             self.assertFalse(is_any_sum(gm)) | 
 |             self.assertTrue(is_any_sigmoid(gm)) | 
 |             return torch.digamma(x) | 
 |  | 
 |         traced = make_fx(f2_logging_tensor)(torch.randn(3)) | 
 |         self.assertFalse(is_any_sum(traced)) | 
 |         self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor | 
 |         self.assertTrue(is_any_digamma(traced)) | 
 |  | 
 |     # See https://github.com/pytorch/pytorch/issues/97541 | 
 |     def test_empty_like_doesnt_burn_in_defaults(self): | 
 |         def f(x): | 
 |             return torch.empty_like(x) | 
 |         out = make_fx(f)(torch.randn(3)) | 
 |         self.assertExpectedInline(out.code.strip(), """\ | 
 | def forward(self, x_1): | 
 |     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None | 
 |     return empty_like""") | 
 |  | 
 |     def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self): | 
 |         def f(x): | 
 |             y = x.new_zeros(x.size()) | 
 |             y.copy_(x) | 
 |             return y | 
 |  | 
 |         def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None): | 
 |             return torch.zeros(size, dtype=inp.dtype, device=inp.device) | 
 |  | 
 |         factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp} | 
 |  | 
        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
 |         # returns a ProxyTensor. | 
 |         out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2)) | 
 |         self.assertExpectedInline(out.code, """\ | 
 |  | 
 |  | 
 |  | 
 | def forward(self, x_1): | 
 |     zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False) | 
 |     copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None | 
 |     return copy_ | 
 |     """) | 
 |  | 
 |     def test_make_fx_reentrant_dispatch(self): | 
 |         def f(x): | 
 |             return torch.ops.aten.norm.Scalar(x, 2.0) | 
 |  | 
 |         def norm_decomp(x, p=2.0): | 
 |             if p != 2.0: | 
 |                 raise RuntimeError("can't handle with p != 2") | 
 |             return torch.sqrt(torch.sum(torch.square(x))) | 
 |  | 
 |         decomp = {torch.ops.aten.norm.Scalar: norm_decomp} | 
 |  | 
 |         traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3)) | 
 |  | 
 |         for n in traced.graph.nodes: | 
 |             self.assertTrue("square" not in str(n.target)) | 
 |             self.assertTrue("norm" not in str(n.target)) | 
 |  | 
 |     @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") | 
 |     def test_resnet18_backward_trace(self): | 
 |         mod = torchvision.models.resnet18() | 
 |  | 
 |         # An old version of this test called the module directly.  This works | 
 |         # for tracing_mode == "real", but for fake tensors, we also have to | 
 |         # ensure that the parameters and buffers get wrapped in fake tensors | 
 |         # because free fake tensors are not supported.  Fortunately functional_call | 
 |         # does precisely this for us. | 
 |         def f(x, params, buffers): | 
 |             for p in params.values(): | 
 |                 p.grad = None | 
 |             loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum() | 
            # I could have done this with the functional API, but that path is already
            # exercised plenty elsewhere; I want to show that the mutating API still
            # works
 |             loss.backward() | 
 |             return [p.grad for p in params.values()] | 
 |  | 
 |         inp = torch.randn(3, 3, 250, 250) | 
 |         self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())]) | 
 |  | 
 |     def test_varargs(self): | 
 |         def f(*args): | 
 |             return sum(args) | 
 |  | 
 |         self._test(f, [torch.randn(2), torch.randn(2)]) | 
 |  | 
 |     def test_proxy_tensor(self): | 
 |         def f_grad(x): | 
 |             val = x.cos().cos().sum() | 
 |             return torch.autograd.grad(val, x) | 
 |  | 
 |         def f_backward(x): | 
 |             val = x.cos().cos().sum() | 
 |             val.backward() | 
 |             return x.grad | 
 |  | 
 |         for f in [f_grad, f_backward]: | 
 |             self._test(f, [torch.randn(3, requires_grad=True)]) | 
 |  | 
 |     def test_pickle_issue89626(self): | 
 |         import pickle | 
 |         x = torch.randn(2) | 
 |         make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x) | 
 |         pickle.dumps(x) | 
 |  | 
 |     def test_inplace_metadata(self): | 
 |         def f(x): | 
 |             x = x.clone() | 
 |             x.unsqueeze_(-1) | 
 |             assert x.shape[-1] == 1 | 
 |             return x | 
 |  | 
 |         self._test(f, [torch.randn(5)]) | 
 |  | 
 |     def test_mode_tracing_factory_function(self): | 
 |         def f(x): | 
 |             return x + torch.randn(x.shape) | 
 |  | 
 |         # default behavior should trace factory functions | 
 |         traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3)) | 
 |         self.assertTrue( | 
 |             any( | 
 |                 node.target == aten.randn.default | 
 |                 for node in traced.graph.nodes | 
 |             ) | 
 |         ) | 
 |  | 
 |     def test_val_metadata_mutation(self): | 
 |         def f(x): | 
 |             y = x.clone() | 
 |             y.unsqueeze_(0) | 
 |             return y | 
 |  | 
 |         traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True)) | 
 |         self.assertEqual([ | 
 |             tuple(node.meta['val'].shape) | 
 |             for node in traced.graph.nodes | 
 |             if 'val' in node.meta | 
 |         ], [(3,), (3,), (1, 3)]) | 
 |  | 
 |     def test_make_fx_overloads(self): | 
 |         def f(x): | 
 |             return x.cos() + torch.randn(x.shape) | 
 |  | 
 |         traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3)) | 
 |  | 
 |         self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload) | 
 |                              for node in traced.graph.nodes if node.op == 'call_function'])) | 
 |  | 
 |     def test_tensor_constants(self): | 
 |         def f(): | 
 |             val = torch.tensor(float('inf')) | 
 |             return torch.full((100, 100), val) | 
 |  | 
 |         self._test(f, []) | 
 |  | 
 |     def test_allclose(self): | 
 |         def f(a, b): | 
 |             return torch.allclose(a, b) | 
 |  | 
 |         def test_f(): | 
 |             make_fx(f, tracing_mode=self.tracing_mode)( | 
 |                 torch.zeros(3), torch.zeros(3) | 
 |             ) | 
 |  | 
 |         if self.tracing_mode != "real": | 
 |             self.assertRaises(DataDependentOutputException, test_f) | 
 |         else: | 
 |             self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) | 
 |  | 
 |     def test_constant_proxy_tensor_mut(self): | 
 |         def f(): | 
 |             val = torch.tensor(float(1)) | 
 |             val.add_(2) | 
 |             return torch.full((100, 100), val) | 
 |  | 
 |         g = make_fx(f, tracing_mode=self.tracing_mode)() | 
 |         self.assertEqual(g(), f()) | 
 |         # In case we mutated shared state in the g graph! | 
 |         self.assertEqual(g(), f()) | 
 |  | 
 |     def test_constant_unbind(self): | 
 |         def f(): | 
 |             val = torch.tensor([2]) | 
 |             r, = torch.unbind(val, 0) | 
 |             return r.item() | 
 |  | 
 |         g = make_fx(f, tracing_mode=self.tracing_mode)() | 
 |         self.assertEqual(g(), f()) | 
 |  | 
 |     def test_constant_blowup(self): | 
 |         def f(): | 
 |             val = torch.tensor([2]) | 
 |             blowup = val.repeat(1000) | 
 |             return bool(blowup.sum().item() == 2) | 
 |  | 
 |         def test_f(): | 
 |             make_fx(f, tracing_mode=self.tracing_mode)() | 
 |  | 
 |         if self.tracing_mode == "fake": | 
 |             self.assertRaises(DataDependentOutputException, test_f) | 
 |         else: | 
 |             self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) | 
 |  | 
 |     def test_constant_random(self): | 
 |         def f(): | 
 |             val = torch.tensor([2.0]) | 
 |             val.normal_() | 
 |             return bool(val.item() == 2.1) | 
 |  | 
 |         def test_f(): | 
 |             make_fx(f, tracing_mode=self.tracing_mode)() | 
 |  | 
 |         if self.tracing_mode == "fake": | 
 |             self.assertRaises(DataDependentOutputException, test_f) | 
 |         else: | 
 |             self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) | 
 |  | 
 |     def test_decomposition_interpreter(self): | 
 |         def fn(x): | 
 |             return torch.nn.functional.silu(x) | 
 |  | 
 |         x = torch.rand((4, 4)) | 
 |         fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x) | 
 |  | 
 |         found_silu = False | 
 |         for n in fx_module.graph.nodes: | 
 |             if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default: | 
 |                 found_silu = True | 
 |  | 
 |         self.assertTrue(found_silu) | 
 |  | 
 |         new_graph = torch.fx.Graph() | 
 |         silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]} | 
 |         DecompositionInterpreter( | 
 |             fx_module, | 
 |             new_graph=new_graph, | 
 |             decomposition_table=silu_decomp_table, | 
 |         ).run(x) | 
 |  | 
 |         decomposed_module = torch.fx.GraphModule(fx_module, new_graph) | 
 |  | 
 |         for n in decomposed_module.graph.nodes: | 
 |             self.assertTrue(n.target != torch.ops.aten.silu) | 
 |             self.assertTrue(n.target != torch.ops.aten.silu.default) | 
 |  | 
 |         self.assertEqual(fx_module(x), decomposed_module(x)) | 
 |  | 
 |     def test_make_fx_model_fwd_bwd(self): | 
 |         class Foo(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.linear = torch.nn.Linear(5, 5) | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.linear(x).relu() | 
 |  | 
 |         model = Foo() | 
 |  | 
 |         def f(x, params): | 
 |             out = torch.func.functional_call(model, params, x).sum() | 
 |             out.backward() | 
 |             return list(params.values()) | 
 |         input = torch.randn(3, 5, requires_grad=True) | 
 |         params = dict(model.named_parameters()) | 
 |         fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params) | 
        # fx may change the order of parameters in the returned list, so compare each
        # output against both reference outputs
 |         self.assertTrue( | 
 |             torch.allclose(fx_f(input, params)[0], f(input, params)[0]) | 
 |             or | 
 |             torch.allclose(fx_f(input, params)[0], f(input, params)[1]) | 
 |         ) | 
 |         self.assertTrue( | 
 |             torch.allclose(fx_f(input, params)[1], f(input, params)[0]) | 
 |             or | 
 |             torch.allclose(fx_f(input, params)[1], f(input, params)[1]) | 
 |         ) | 
 |  | 
 |     def test_make_fx_model_double_param(self): | 
 |         class Emformer(torch.nn.Module): | 
 |             def __init__( | 
 |                 self, | 
 |                 input_dim: int = 256, | 
 |             ) -> None: | 
 |                 super().__init__() | 
 |  | 
 |                 self.layer_norm = torch.nn.LayerNorm(input_dim) | 
 |  | 
 |             def forward(mod_self, x):  # noqa: B902 | 
 |                 self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) | 
 |                 y = mod_self.layer_norm(x) | 
 |                 self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) | 
 |                 z = mod_self.layer_norm(y) | 
 |                 return z | 
 |  | 
 |  | 
 |         gm = make_fx(Emformer())(torch.randn(16, 1, 256)) | 
 |         ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'} | 
 |         self.assertEqual(len(ops), 2) | 
 |  | 
 |  | 
 |     def test_make_fx_model_fwd_bwd_wgtupdate(self): | 
 |         class Foo(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.linear = torch.nn.Linear(5, 5) | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.linear(x).relu() | 
 |  | 
 |         model = Foo() | 
 |  | 
 |         def f(args, params, buffers): | 
 |             for p in params.values(): | 
 |                 p.grad = None | 
 |             if not isinstance(args, Iterable): | 
 |                 args = [args] | 
 |             params_and_buffers = {**params, **buffers} | 
 |             out = torch.func.functional_call(model, params_and_buffers, args) | 
 |             out.sum().backward() | 
 |             return [p - 1e-4 * p.grad for p in params.values()] | 
 |  | 
 |         input = torch.randn(3, 5, requires_grad=True) | 
 |         params = dict(model.named_parameters()) | 
 |         buffers = dict(model.named_buffers()) | 
 |         fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers) | 
        # fx may change the order of parameters in the returned list, so compare each
        # output against both reference outputs.
        # There is also a numerical difference in the results, so use atol=1e-03
        # instead of the default 1e-08.
 |         self.assertTrue( | 
 |             torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03) | 
 |             or | 
 |             torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03) | 
 |         ) | 
 |         self.assertTrue( | 
 |             torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03) | 
 |             or | 
 |             torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03) | 
 |         ) | 
 |  | 
 |     def test_trace_subclasses(self): | 
 |         def f1(x): | 
 |             x = UnwrapTensor(x) | 
 |             y = x * 2 | 
 |             return y | 
 |  | 
 |         def f2(x): | 
 |             wrapped = UnwrapTensor(x) | 
 |             y = x * wrapped | 
 |             return y | 
 |  | 
 |         inp = [torch.randn(5)] | 
 |         self._test(f1, inp) | 
 |         self._test(f2, inp) | 
 |  | 
 |     def test_partial_decomp(self): | 
 |         def f(a, b, c): | 
 |             x = torch.addmm(a, b, c) | 
 |             y = torch.addmm(a, b, c, beta=2, alpha=1) | 
 |             return x + y | 
 |         inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)] | 
 |         fx_g = make_fx(f)(*inps) | 
 |  | 
 |         def addmm(a, b, c, beta=1, alpha=1): | 
 |             if beta == 1 and alpha == 1: | 
 |                 return NotImplemented | 
 |             return beta * a + alpha * (b @ c) | 
 |  | 
 |         decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps) | 
 |  | 
 |         self.assertEqual(fx_g(*inps), decomposed_fx(*inps)) | 
 |         self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2) | 
 |         self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1) | 
 |  | 
 |     def test_decomp_of_capture(self): | 
 |         val = torch.randn(5) | 
 |  | 
 |         def f(x): | 
 |             return x.t() + val.t() | 
 |  | 
 |         def nop(x): | 
 |             return x.cos() | 
 |  | 
 |         traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5)) | 
 |         self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0) | 
 |  | 
 |  | 
 |     @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') | 
 |     def test_amp_cache(self): | 
 |         layer = torch.nn.Conv2d(3, 3, 3).cuda() | 
 |  | 
 |         def f(x, w): | 
 |             return torch.nn.functional.conv2d(x, w, stride=layer.stride) | 
 |  | 
 |         inp = torch.randn(4, 3, 10, 10, device='cuda') | 
 |         with torch.autocast('cuda'): | 
 |             out_graph = make_fx(f)(inp, layer.weight).graph | 
 |             out_graph2 = make_fx(f)(inp, layer.weight).graph | 
 |  | 
 |         self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes)) | 
 |         for a, b in zip(out_graph.nodes, out_graph2.nodes): | 
 |             self.assertEqual(a.op, b.op) | 
 |  | 
 |     def test_strides(self): | 
 |         def f(x): | 
 |             self.assertTrue(x.is_contiguous()) | 
 |             self.assertFalse(x.is_contiguous(memory_format=torch.channels_last)) | 
 |             x = x.permute(0, 3, 1, 2) | 
 |             self.assertFalse(x.is_contiguous()) | 
 |             self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) | 
 |             return x | 
 |         make_fx(f)(torch.randn(2, 3, 4, 5)) | 
 |  | 
 |         def f(x): | 
 |             self.assertTrue(x.is_contiguous()) | 
 |             y = x[:, 1] | 
 |             self.assertFalse(y.is_contiguous()) | 
 |             y = x[:, ::2] | 
 |             self.assertFalse(y.is_contiguous()) | 
 |             return x.cos() | 
 |  | 
 |         make_fx(f)(torch.randn(2, 3, 4, 5)) | 
 |  | 
 |     def test_pr_86917(self): | 
 |         # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344 | 
 |         def f(a, b): | 
 |             return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10) | 
 |  | 
 |         self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)]) | 
 |  | 
 | class TestGenericProxyTensorReal(TestGenericProxyTensor): | 
 |     tracing_mode = "real" | 
 |  | 
 |  | 
 | class TestGenericProxyTensorFake(TestGenericProxyTensor): | 
 |     tracing_mode = "fake" | 
 |  | 
 |  | 
 | @xfail_inherited_tests([ | 
 |     "test_make_fx_overloads", | 
 | ]) | 
 | class TestGenericProxyTensorSymbolic(TestGenericProxyTensor): | 
 |     tracing_mode = "symbolic" | 
 |  | 
 |  | 
 | del TestGenericProxyTensor | 
 |  | 
 |  | 
 | class TestRealProxyTensor(TestCase): | 
 |     pass | 
 |  | 
 | class TestFakeProxyTensor(TestCase): | 
 |     def test_issue82547(self): | 
 |         x = nn.Parameter(torch.randn(3, 3)) | 
 |  | 
 |         def f(): | 
 |             return torch.ops.aten.t.default(x) | 
 |         self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")()) | 
 |  | 
 |         class A(torch.Tensor): | 
 |             pass | 
 |  | 
 |         x = A(torch.randn(3, 3)) | 
 |         self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")()) | 
 |  | 
 |     def test_use_fake_and_tensor(self): | 
 |         def f(x, y): | 
 |             z = torch.tensor([2.0, 3.0]) | 
 |             return x + y + z | 
 |  | 
 |         g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2)) | 
 |         x, y = torch.randn(2), torch.randn(2) | 
 |         self.assertEqual(g(x, y), f(x, y)) | 
 |  | 
 |     def test_fused_adam(self): | 
 |         # See https://github.com/pytorch/pytorch/issues/99356 | 
 |         params = [torch.randn(10, 10, requires_grad=True) for _ in range(10)] | 
 |         grads = [torch.randn(10, 10) for _ in range(10)] | 
 |         exp_avgs = [torch.randn(10, 10) for _ in range(10)] | 
 |         exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)] | 
 |         max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)] | 
 |         state_steps = [torch.tensor(0) for _ in range(10)] | 
 |  | 
 |         def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps): | 
 |             return aten._fused_adam.default( | 
 |                 params, | 
 |                 grads, | 
 |                 exp_avgs, | 
 |                 exp_avg_sqs, | 
 |                 max_exp_avg_sqs, | 
 |                 state_steps, | 
 |                 lr=0.1, | 
 |                 beta1=0.9, | 
 |                 beta2=0.999, | 
 |                 weight_decay=0.01, | 
 |                 eps=1e-8, | 
 |                 amsgrad=False, | 
 |                 maximize=False, | 
 |             ) | 
 |  | 
 |         gm = make_fx(fused_adam, tracing_mode='fake')( | 
 |             params, | 
 |             grads, | 
 |             exp_avgs, | 
 |             exp_avg_sqs, | 
 |             max_exp_avg_sqs, | 
 |             state_steps, | 
 |         ) | 
 |         for n in gm.graph.nodes: | 
 |             if n.op == "call_function" and n.target == aten._fused_adam.default: | 
 |                 self.assertIn('val', n.meta) | 
 |  | 
 |     def test_alias(self): | 
 |         def f(x): | 
 |             return torch.ops.aten.alias(x) | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip() | 
 |         # NB: this should not have a detach call | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, x_1): | 
 |     alias = torch.ops.aten.alias.default(x_1);  x_1 = None | 
 |     return alias""") | 
 |  | 
 |     def test_meta(self): | 
 |         def f(x): | 
 |             a = x.cos() | 
 |             b = torch.var_mean(a, dim=0) | 
 |             c = b * 2 | 
 |             return c | 
 |  | 
 |         out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5)) | 
 |         for n in out.graph.nodes: | 
 |             if n.op == 'output': | 
 |                 continue | 
 |             self.assertTrue('val' in n.meta) | 
 |  | 
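# Return the first node in fx_g.graph satisfying `cond`; raise if no node matches.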
 | def _get_node(fx_g, cond): | 
 |     for n in fx_g.graph.nodes: | 
 |         if cond(n): | 
 |             return n | 
 |     raise AssertionError | 
 |  | 
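# Count the input symbols that are still "free": they appear in var_to_val but have
# not been eliminated via the ShapeEnv's replacements.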
 | def _get_free_symbols(shape_env): | 
 |     vars = tuple(shape_env.var_to_val.keys()) | 
 |     return len([var for var in vars if var not in shape_env.replacements]) | 
 |  | 
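# Trace f in symbolic mode over fresh random inputs; each arg is a shape tuple
# (or a single int) passed to torch.randn.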
 | def _trace(f, *args): | 
 |     inps = [torch.randn(arg) for arg in args] | 
 |     return make_fx(f, tracing_mode="symbolic")(*inps) | 
 |  | 
 | # TODO: Need to test the guards themselves specifically as well | 
 | class TestSymbolicTracing(TestCase): | 
 |     def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True): | 
 |         """ | 
        Traces fn symbolically with trace_inputs and checks the traced graph against
        fn on each entry of test_inputs.
        Returns the traced GraphModule (whose shape_env callers can inspect).
 |         """ | 
 |         trace_inputs = [torch.randn(shape) for shape in trace_inputs] | 
 |         traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs) | 
 |         for input in test_inputs: | 
 |             input = [torch.randn(shape) for shape in input] | 
 |             rx, ry = traced_f(*input), fn(*input) | 
 |             if assert_eq: | 
 |                 self.assertEqual(rx, ry) | 
 |         return traced_f | 
 |  | 
 |  | 
 |     def test_debug_interpreter(self): | 
 |         import torch.library | 
 |         from torch.library import Library | 
 |  | 
 |         foo = Library("foo", "DEF") | 
 |         foo.define("foo(Tensor self) -> Tensor") | 
 |  | 
 |         # Operator where meta and cpu disagree on strides | 
 |         @torch.library.impl(foo, "foo", "CPU") | 
 |         def foo_cpu(x): | 
 |             return x.clone().T | 
 |  | 
 |         @torch.library.impl(foo, "foo", "Meta") | 
 |         def foo_meta(x): | 
 |             return x.clone() | 
 |  | 
 |         def f(x): | 
 |             return torch.ops.foo.foo.default(x) | 
 |  | 
 |         gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2)) | 
 |         from torch._functorch.compilers import DebugInterpreter | 
 |  | 
 |         interp = DebugInterpreter(gm) | 
 |  | 
 |         # input mismatch is caught (indicates guard problem) | 
 |         self.assertRaisesRegex( | 
 |             AssertionError, r"3 != 1", | 
 |             lambda: interp.run(torch.randn(3, 3).T), | 
 |         ) | 
 |  | 
 |         # Catch the incorrect meta | 
 |         self.assertRaisesRegex( | 
 |             AssertionError, r"\(3, 1\) != \(1, 3\)", | 
 |             lambda: interp.run(torch.randn(3, 3)) | 
 |         ) | 
 |  | 
 |     def test_resize_from_zero(self): | 
 |         def f(x, y): | 
 |             x.resize_(y.size(0)) | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, x_1, y_1): | 
 |     sym_size = torch.ops.aten.sym_size(y_1, 0);  y_1 = None | 
 |     resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]);  x_1 = sym_size = None | 
 |     return None""") | 
 |  | 
 |  | 
 |     def test_unary(self): | 
 |         def f(x): | 
 |             assert x.shape[0] < 20 | 
 |             return x.cos() | 
 |         test_inputs = [] | 
 |         test_inputs.append([(2, 5)]) | 
 |         test_inputs.append([(6, 8)]) | 
 |         gm = self._test_dynamic(f, [(3, 4)], test_inputs) | 
 |         self.assertTrue(eval_guards(gm, torch.randn(4, 5))) | 
 |         self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}") | 
 |         self.assertFalse(eval_guards(gm, torch.randn(25, 5))) | 
 |         self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] < 20""") | 
 |  | 
 |     @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') | 
 |     def test_cpu_scalar_cuda(self): | 
        # Extracted from wav2vec2
 |         def f(a, b): | 
 |             return (a * b) @ b | 
 |  | 
 |         r = str( | 
 |             make_fx(f, tracing_mode="symbolic")( | 
 |                 torch.tensor(1.0), torch.randn(2, 2, device='cuda') | 
 |             ).code | 
 |         ).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, a_1, b_1): | 
 |     mul = torch.ops.aten.mul.Tensor(a_1, b_1);  a_1 = None | 
 |     mm = torch.ops.aten.mm.default(mul, b_1);  mul = b_1 = None | 
 |     return mm""") | 
 |  | 
 |     def test_binary_broadcast(self): | 
 |         def f(a, b): | 
 |             c = a * b | 
 |             return c | 
 |  | 
 |         test_inputs = [] | 
 |         test_inputs.append([(1, 5), (3, 1)]) | 
 |         test_inputs.append([(1, 4), (4, 1)]) | 
 |         shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env | 
 |         assert len(shape_env.guards) == 0 | 
 |  | 
 |     def test_multiply_shape(self): | 
 |         def f(a): | 
 |             return torch.empty(a.shape[0] * 2) | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, a_1): | 
 |     sym_size = torch.ops.aten.sym_size(a_1, 0);  a_1 = None | 
 |     mul = sym_size * 2;  sym_size = None | 
 |     empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False);  mul = None | 
 |     return empty""") | 
 |  | 
 |     def test_item(self): | 
 |         def f(a): | 
 |             r = a.item() | 
 |             return r * a | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, a_1): | 
 |     _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1) | 
 |     mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense);  a_1 = _local_scalar_dense = None | 
 |     return mul""") | 
 |  | 
 |     def test_item_to_constructor(self): | 
 |         def f(a): | 
 |             r = a.item() | 
 |             constrain_range(r, min=2) | 
 |             return torch.empty(r) | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip() | 
 |         self.assertExpectedInline( | 
 |             r, """\ | 
 | def forward(self, a_1): | 
 |     _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None | 
 |     empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None | 
 |     return empty"""  # noqa: B950 | 
 |         ) | 
 |  | 
 |     def test_dynamic_pointwise_scalar(self): | 
 |         def f(gravity, mask): | 
 |             gravity[mask, 0] = gravity[mask, 0] * -1 | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")( | 
 |             torch.randn((12, 4)), | 
 |             torch.randint(0, 2, (12,), dtype=torch.bool) | 
 |         ).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, gravity_1, mask_1): | 
 |     select = torch.ops.aten.select.int(gravity_1, 1, 0) | 
 |     index = torch.ops.aten.index.Tensor(select, [mask_1]);  select = None | 
 |     mul = torch.ops.aten.mul.Tensor(index, -1);  index = None | 
 |     select_1 = torch.ops.aten.select.int(gravity_1, 1, 0);  gravity_1 = None | 
 |     index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul);  select_1 = mask_1 = mul = None | 
 |     return None""") | 
 |  | 
 |     def test_reflect_r_over_x(self): | 
 |         def reflect_R_over_x(R): | 
 |             reflect = torch.eye(3, device=R.device) | 
 |             reflect[0, 0] = -1 | 
 |             return reflect @ R @ reflect | 
 |  | 
 |         def f(crop_camera, mask): | 
 |             crop_camera[mask] = reflect_R_over_x(crop_camera[mask]) | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")( | 
 |             torch.randn((12, 3, 3)), | 
 |             torch.randint(0, 2, (12,), dtype=torch.bool) | 
 |         ).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, crop_camera_1, mask_1): | 
 |     index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1]) | 
 |     eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False) | 
 |     _tensor_constant0 = self._tensor_constant0 | 
 |     lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None | 
 |     select = torch.ops.aten.select.int(eye, 0, 0) | 
 |     select_1 = torch.ops.aten.select.int(select, 0, 0);  select = None | 
 |     copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy);  select_1 = lift_fresh_copy = None | 
 |     transpose = torch.ops.aten.transpose.int(index, -2, -1) | 
 |     t = torch.ops.aten.t.default(eye) | 
 |     clone = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format);  transpose = None | 
 |     sym_size = torch.ops.aten.sym_size(index, 0);  index = None | 
 |     sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 2) | 
 |     mul = sym_size * sym_size_1 | 
 |     sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 1) | 
 |     _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [mul, sym_size_2]);  clone = mul = sym_size_2 = None | 
 |     mm = torch.ops.aten.mm.default(_unsafe_view, t);  _unsafe_view = t = None | 
 |     view = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 3]);  mm = sym_size_1 = None | 
 |     transpose_1 = torch.ops.aten.transpose.int(view, -2, -1) | 
 |     clone_1 = torch.ops.aten.clone.default(transpose_1, memory_format = torch.contiguous_format);  transpose_1 = None | 
 |     mul_1 = sym_size * 3 | 
 |     sym_size_3 = torch.ops.aten.sym_size(view, 1);  view = None | 
 |     view_1 = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_3]);  clone_1 = mul_1 = sym_size_3 = None | 
 |     mm_1 = torch.ops.aten.mm.default(view_1, eye);  view_1 = eye = None | 
 |     view_2 = torch.ops.aten.view.default(mm_1, [sym_size, 3, 3]);  mm_1 = sym_size = None | 
 |     index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2);  crop_camera_1 = mask_1 = view_2 = None | 
 |     return None""") | 
 |  | 
 |     def test_unbacked_slice(self): | 
 |         def f(x, m): | 
 |             x = x[m] | 
 |             return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)] | 
 |  | 
 |         make_fx(f, tracing_mode="symbolic")( | 
 |             torch.randn((12, 3, 3)), | 
 |             torch.randint(0, 2, (12,), dtype=torch.bool) | 
 |         ) | 
 |  | 
 |     @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") | 
 |     def test_unbacked_batch_resnet(self): | 
 |         mod = torchvision.models.resnet18() | 
 |  | 
 |         def f(x, mask, params, buffers): | 
 |             for p in itertools.chain([x, mask], params.values(), buffers.values()): | 
 |                 for s in p.shape: | 
 |                     guard_int(s) | 
 |             x = x[mask] | 
 |             constrain_range(x.shape[0], min=1) | 
 |             for p in params.values(): | 
 |                 p.grad = None | 
 |             return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum() | 
 |  | 
 |         make_fx(f, tracing_mode="symbolic")( | 
 |             torch.randn(3, 3, 250, 250), | 
 |             torch.randint(0, 2, (3,), dtype=torch.bool), | 
 |             dict(mod.named_parameters()), | 
 |             dict(mod.named_buffers()), | 
 |         ) | 
 |  | 
 |     def test_boolean_index(self): | 
 |         def f(images, handedness, valid): | 
 |             images = images[valid] | 
 |             handedness = handedness[valid] | 
 |             right_hand_mask = handedness == 1 | 
 |             images[right_hand_mask] = images[right_hand_mask].flip(-1) | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")( | 
 |             torch.randint(0, 256, (512, 1, 96, 96)), | 
 |             torch.randint(0, 1, (512,)), | 
 |             torch.randint(0, 2, (512,), dtype=torch.bool) | 
 |         ).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, images_1, handedness_1, valid_1): | 
 |     index = torch.ops.aten.index.Tensor(images_1, [valid_1]);  images_1 = None | 
 |     index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]);  handedness_1 = valid_1 = None | 
 |     eq = torch.ops.aten.eq.Scalar(index_1, 1);  index_1 = None | 
 |     index_2 = torch.ops.aten.index.Tensor(index, [eq]) | 
 |     flip = torch.ops.aten.flip.default(index_2, [-1]);  index_2 = None | 
 |     index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip);  index = eq = flip = None | 
 |     return None""") | 
 |  | 
 |     def test_neg_shape(self): | 
 |         def f(a): | 
 |             return torch.empty(-a.shape[0] + 10) | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, a_1): | 
 |     sym_size = torch.ops.aten.sym_size(a_1, 0);  a_1 = None | 
 |     neg = -sym_size;  sym_size = None | 
 |     add = neg + 10;  neg = None | 
 |     empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False);  add = None | 
 |     return empty""") | 
 |  | 
 |     def test_invalidate_nonzero(self): | 
 |         ok = False | 
 |  | 
 |         def f(a): | 
 |             nonlocal ok | 
 |             b = a.clone() | 
 |             x = b.nonzero() | 
 |             x1 = b.nonzero() | 
 |             x2 = b.nonzero() | 
 |             assert x1.shape[0] == x2.shape[0] | 
 |             ok = True | 
 |             b.normal_() | 
 |             y = b.nonzero() | 
 |             try: | 
 |                 bool(x1.shape[0] == y.shape[0]) | 
 |                 self.fail("didn't raise exception") | 
 |             except GuardOnDataDependentSymNode: | 
 |                 pass | 
 |  | 
 |         make_fx(f, tracing_mode="symbolic")(torch.randn(4)) | 
 |  | 
 |     def test_sqrt_size(self): | 
 |         def f(a): | 
 |             return a / a.size(-1) ** 0.5 | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, a_1): | 
 |     sym_size = torch.ops.aten.sym_size(a_1, 0) | 
 |     pow_1 = sym_size ** 0.5;  sym_size = None | 
 |     div = torch.ops.aten.div.Tensor(a_1, pow_1);  a_1 = pow_1 = None | 
 |     return div""") | 
 |  | 
 |  | 
 |     def test_symint_to_tensor(self): | 
 |         def f(a): | 
 |             return a / a.shape[0] | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, a_1): | 
 |     sym_size = torch.ops.aten.sym_size(a_1, 0) | 
 |     div = torch.ops.aten.div.Tensor(a_1, sym_size);  a_1 = sym_size = None | 
 |     return div""") | 
 |  | 
 |         r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip() | 
 |         self.assertExpectedInline(r, """\ | 
 | def forward(self, a_1): | 
 |     sym_size = torch.ops.aten.sym_size(a_1, 0) | 
 |     sym_float = torch.sym_float(sym_size);  sym_size = None | 
 |     div = torch.ops.prims.div.default(a_1, sym_float);  a_1 = sym_float = None | 
 |     return div""") | 
 |  | 
 |     def test_cat(self): | 
 |         def f(a, b): | 
 |             val = torch.mul(a, b) | 
 |             out = torch.cat([val, val]) | 
 |             if out.shape[0] * out.shape[1] > 20: | 
 |                 out = out.cos() | 
 |             return out | 
 |  | 
 |         test_inputs = [] | 
 |         test_inputs.append([(1, 5), (6, 1)]) | 
 |         test_inputs.append([(1, 4), (3, 1)]) | 
 |         gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs) | 
 |         self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1))) | 
 |         self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1))) | 
 |         self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""") | 
 |  | 
 |     def test_new_empty(self): | 
 |         def f(a, b): | 
 |             return a.new_empty(b.shape[0], b.shape[1] * 2) | 
 |  | 
 |         self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env | 
 |  | 
 |     def test_size_with_tensor(self): | 
 |         def f(tensor): | 
 |             max_size = torch.tensor([800, 1216], dtype=torch.int64) | 
 |             batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size) | 
 |             return tensor.new_empty(batch_shape) | 
 |  | 
 |         a = torch.randn(3, 800, 1199) | 
 |         self.assertRaisesRegex( | 
 |             RuntimeError, "data-dependent", lambda: make_fx(f, tracing_mode="symbolic")(a) | 
 |         ) | 
 |  | 
 |     def test_expand(self): | 
 |         def f(a): | 
 |             b = torch.mul(a, a) | 
 |             c = b.expand(a.shape) | 
 |             return c | 
 |  | 
 |         self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]]) | 
 |         self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]]) | 
 |  | 
 |     def test_metadata(self): | 
 |         def f(a, b): | 
 |             d = a.new_empty(a.shape[0] + b.shape[0]) | 
 |             return d | 
 |         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4)) | 
 |         meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default) | 
 |         meta_d = _get_node(fx_g, lambda x: x.target == operator.add) | 
 |         self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr) | 
 |  | 
 |     def test_metadata_fresh(self): | 
 |         def f(x): | 
 |             assert x.shape[0] == 3 | 
 |             return x.cos() | 
 |  | 
 |         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3)) | 
 |         meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default) | 
 |         meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder') | 
 |         self.assertTrue(meta_cos.meta['val'].shape[0].node.expr == 3) | 
        # Check that the input's expr has also been updated (specialized to 3), even
        # though the constraint was asserted after the placeholder was created
 |         self.assertTrue(meta_inp.meta['val'].shape[0].node.expr == 3) | 
 |  | 
 |     def test_elementwise_meta_with_sym_numbers(self): | 
 |         def f(x, offset, as_sym_float=False): | 
 |             x0 = x.size()[0] | 
 |             if as_sym_float: | 
 |                 x0 = sym_float(x0) | 
 |             return torch.add(x0, offset) | 
 |  | 
 |         fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False) | 
 |         meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) | 
 |         self.assertEqual(meta_add.meta['val'].shape, ()) | 
 |         self.assertEqual(meta_add.meta['val'].dtype, torch.float32) | 
 |  | 
 |         fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False) | 
 |         meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) | 
 |         self.assertEqual(meta_add.meta['val'].shape, ()) | 
 |         self.assertEqual(meta_add.meta['val'].dtype, torch.int64) | 
 |  | 
 |         fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True) | 
 |         meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) | 
 |         self.assertEqual(meta_add.meta['val'].shape, ()) | 
 |         self.assertEqual(meta_add.meta['val'].dtype, torch.float32) | 
 |  | 
 |     def test_return_symint(self): | 
 |         def f(x): | 
 |             return x.shape[0], x.cos(), x.shape[0] / 5 | 
 |         self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]]) | 
 |  | 
 |         def f(x): | 
 |             return x.shape | 
 |         self._test_dynamic(f, [(5, 3)], [[(4, 6)]]) | 
 |  | 
 |     def test_rmethod(self): | 
 |         def f(x): | 
 |             return x.size(0) + x | 
 |         self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]]) | 
 |  | 
 |     def test_mega_guard(self): | 
 |         def f(a, b): | 
 |             assert a.shape[0] == b.shape[0] * 2 | 
 |             return a.cos() | 
 |         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8)) | 
 |         from torch._dynamo.source import LocalSource | 
 |         self.assertExpectedInline( | 
 |             str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")])), | 
 |             """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950 | 
 |         ) | 
 |  | 
 |     def test_sym_storage_offset(self): | 
 |         def f(x, y): | 
 |             return x + y | 
 |  | 
 |         inp = (torch.randn(8)[3:], torch.randn(5)) | 
 |         fx_g = make_fx(f, tracing_mode="symbolic")(*inp) | 
 |         inp = (torch.randn(8)[3:], torch.randn(5)) | 
 |         self.assertEqual(fx_g(*inp), f(*inp)) | 
 |  | 
 |     def _assert_no_guards(self, fx_g, free_symbols): | 
 |         assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val | 
 |         assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards() | 
 |  | 
 |     def test_guards_equal(self): | 
 |         def f(a, b): | 
 |             return a * b | 
 |  | 
        # NB: Numbers are carefully chosen so that duck shaping does not apply
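        # (With duck shaping, input dimensions that share the same concrete value at
        # trace time are assigned the same symbol, which would change the free-symbol
        # counts asserted below.)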
 |  | 
 |         fx_g = _trace(f, (5, 6), (5, 6)) | 
 |         self._assert_no_guards(fx_g, 2) | 
 |  | 
 |         fx_g = _trace(f, (5, 6, 7), (5, 6, 7)) | 
 |         self._assert_no_guards(fx_g, 3) | 
 |  | 
 |         fx_g = _trace(f, (5, 1), (1, 6)) | 
 |         self._assert_no_guards(fx_g, 2) | 
 |  | 
 |         def f(a, b, c, d): | 
 |             a = a + b | 
 |             cat = torch.cat([c, d]) | 
 |             return a + cat | 
 |  | 
 |         fx_g = _trace(f, 7, 7, 4, 3) | 
 |         self._assert_no_guards(fx_g, 2) | 
 |  | 
 |         def f(a, b, c, d, e): | 
 |             vals = [a, b, c, d, e] | 
 |             x = a | 
 |             for idx in range(len(vals) - 1): | 
 |                 x = torch.cat([x, vals[idx]]) + vals[idx + 1] | 
 |             return x | 
 |  | 
 |         fx_g = _trace(f, 2, 4, 8, 16, 32) | 
 |         self._assert_no_guards(fx_g, 1) | 
 |  | 
 |         def f(a, b): | 
 |             a = a.view(b.shape[0]) | 
 |             return a + b.sum() | 
 |  | 
 |         fx_g = _trace(f, (4, 2), 8) | 
 |         self._assert_no_guards(fx_g, 2) | 
 |  | 
 |         fx_g = _trace(f, (4, 2), (8, 5)) | 
 |         self._assert_no_guards(fx_g, 3) | 
 |  | 
 |         fx_g = _trace(f, (2, 3, 4), 24) | 
 |         self._assert_no_guards(fx_g, 3) | 
 |  | 
 |     def test_nonidentity_transitive_guards(self): | 
 |         def f(a, b, c, d, e): | 
 |             vals = [a, b, c, d, e] | 
 |             cat_vals = [] | 
 |             for idx in range(len(vals) - 1): | 
 |                 cat_vals.append(torch.cat([vals[idx], vals[idx]])) | 
 |             final_vals = [] | 
 |             for a, b in reversed(list(zip(cat_vals, vals[1:]))): | 
 |                 final_vals.append(a + b) | 
 |             return final_vals | 
 |  | 
 |         fx_g = _trace(f, 2, 4, 8, 16, 32) | 
        self.assertExpectedInline(show_guards(fx_g), """""")


make_fx_failures = {
 |     # unknown | 
 |     xfail('allclose'), | 
 |     xfail('equal'), | 
 |     # empty | 
 |     skip('new_empty'), | 
 |     skip('empty_like'), | 
 |     skip('empty'), | 
 |     skip('empty_permuted'), | 
 |     # flaky | 
 |     skip('linalg.lstsq', 'grad_oriented'), | 
 |     skip('nn.functional.max_unpool1d', '', device_type='cpu'), | 
 |     skip('nn.functional.max_unpool2d', '', device_type='cpu'), | 
 |     skip('nn.functional.max_unpool3d', '', device_type='cpu'), | 
 |     skip('linalg.lstsq'),  # flaky, probably just a precision issue | 
 |  | 
 |     # data-dependent control flow | 
 |     xfail('cov'), | 
 |     xfail('istft'), | 
 |     xfail('nn.functional.gaussian_nll_loss'), | 
 |     xfail('tensor_split'), | 
 |     xfail('corrcoef'), | 
 |     xfail('quantile'), | 
 |     xfail('nanquantile'), | 
 |     xfail('narrow'), | 
 |  | 
    # many complex operators produce incorrect striding / metadata
 |     skip('fft.fft', ''), | 
 |     skip('fft.hfft2', ''), | 
 |     skip('fft.hfft', ''), | 
 |     skip('fft.hfftn', ''), | 
 |     skip('fft.ifft', ''), | 
 |     skip('fft.ihfft2', ''), | 
 |     skip('fft.ihfft', ''), | 
 |     skip('fft.ihfftn', ''), | 
 |     skip('fft.irfft2', ''), | 
 |     skip('fft.irfft', ''), | 
 |     skip('fft.irfftn', ''), | 
 |     skip('fft.rfft2', ''), | 
 |     skip('fft.rfft', ''), | 
 |     skip('fft.rfftn', ''), | 
 |  | 
 |     # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse | 
 |     xfail('sparse.sampled_addmm'), | 
 |     xfail('sparse.mm', 'reduce'), | 
 |  | 
 |     # proxy tensor doesn't support sparse correctly right now | 
 |     skip('to_sparse'), | 
 |     # segfaults | 
 |     skip('block_diag'), | 
 | } | 
 |  | 
 | fake_tensor_failures = { | 
 |     # FakeTensor fallback doesn't work | 
 |     xfail('_segment_reduce', 'lengths'), | 
 |     xfail('multinomial'), | 
 |     xfail('cholesky'), | 
 |     xfail('cholesky_inverse'), | 
 |     # cannot do these as they rely on tensor data | 
 |     xfail('repeat_interleave'), | 
 |     # ASAN failures due to divide by 0 | 
 |     skip('nn.functional.nll_loss'), | 
 |  | 
    xfail('linalg.cond', ''),
    xfail('linalg.matrix_norm'),
    xfail('linalg.norm'),
    xfail('linalg.matrix_rank'),
    xfail('linalg.norm', 'subgradients_at_zero'),
    xfail('linalg.svd'),
    xfail('linalg.svdvals'),

    xfail('norm', 'nuc'),
    xfail('pca_lowrank'),
    xfail('stft'),
    xfail('svd'),
    xfail('svd_lowrank'),
 | } | 
 |  | 
 | symbolic_tensor_failures = { | 
 |     # Needs complex-value support | 
 |     xfail('polar'), | 
 |     xfail('linalg.eig'), | 
 |     xfail('linalg.eigvals'), | 
 |     xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... | 
 |     xfail('combinations', ''), | 
 |     xfail('cumulative_trapezoid', ''),  # aten.slice.Tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('diff', ''),  # aten.empty_like.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('dsplit', ''),  # aten.slice.Tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('frexp', ''),  # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('gradient', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('histc', ''),  # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because... | 
 |     xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c... | 
 |     xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('hsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('index_reduce', ''),  # Float | 
 |     xfail('inner', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('isin', ''),  # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('kron', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('kthvalue', ''),  # aten.kthvalue.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.cond', ''),  # Tensors of type TensorImpl do not have numel | 
 |     xfail('linalg.eigh', ''),  # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.eigvalsh', ''),  # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.householder_product', ''),  # aten.linalg_householder_product.default - couldn't find symbolic meta funct... | 
 |     xfail('linalg.ldl_factor', ''),  # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.ldl_factor_ex', ''),  # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decompos... | 
 |     xfail('linalg.ldl_solve', ''),  # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.lu', ''),  # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.lu_factor', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.lu_factor_ex', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.matrix_power'),  # RuntimeError: Trying to call aten.size on a tensor with symbolic shape | 
 |     xfail('linalg.matrix_rank', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.matrix_rank', 'hermitian'),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.multi_dot', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.pinv', ''),  # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.pinv', 'singular'),  # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.pinv', 'hermitian'),  # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decompo... | 
 |     xfail('linalg.qr', ''),  # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.slogdet', ''),  # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.solve', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.solve_ex', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.solve_triangular', ''),  # aten.linalg_solve_triangular.default - couldn't find symbolic meta function/de... | 
 |     xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('logdet', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('lu', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('lu_unpack', ''),  # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('masked_select', ''),  # aten.masked_select.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('median', ''),  # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... | 
 |     xfail('mode', ''),  # aten.mode.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend. | 
 |     xfail('narrow', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.adaptive_max_pool1d', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.adaptive_max_pool2d', ''),  # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct... | 
 |     xfail('nn.functional.adaptive_max_pool3d', ''),  # argument 'output_size' (position 2) must be tupl... | 
 |     xfail('nn.functional.avg_pool3d', ''),  # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.bilinear', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom... | 
 |     xfail('nn.functional.cosine_similarity', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.embedding_bag', ''),  # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun... | 
 |     xfail('nn.functional.fractional_max_pool2d', ''),  # argument 'size' must be tuple of ints, but found element of t... | 
 |     xfail('nn.functional.fractional_max_pool3d', ''),  # argument 'size' must be tuple of ints, but found element of t... | 
 |     xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos... | 
 |     xfail('nn.functional.interpolate', 'linear'),  # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... | 
 |     xfail('nn.functional.interpolate', 'trilinear'),  # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... | 
 |     xfail('nn.functional.max_pool1d', ''),  # Trying to call aten.size on a tensor with symbolic shapes. | 
 |     xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d... | 
 |     xfail('nn.functional.max_unpool1d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... | 
 |     xfail('nn.functional.max_unpool2d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... | 
 |     xfail('nn.functional.max_unpool3d', 'grad'),  # aten.max_unpool3d.default - couldn't find symbolic meta function/decom... | 
 |     xfail('nn.functional.multi_margin_loss', ''),  # Could not run 'aten::multi_margin_loss' with arguments from the... | 
 |     xfail('nn.functional.multilabel_margin_loss', ''),  # Could not run 'aten::multilabel_margin_loss_forward' with ... | 
 |     xfail('nn.functional.pad', 'reflect'),  # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo... | 
 |     xfail('nn.functional.pad', 'replicate'),  # aten.replication_pad1d.default - couldn't find symbolic meta function/deco... | 
 |     xfail('nn.functional.pdist', ''),  # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... | 
 |     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... | 
 |     xfail('nn.functional.smooth_l1_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('ormqr', ''),  # aten.ormqr.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('pca_lowrank', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('pinverse', ''),  # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition | 
 |     xfail('polygamma', 'polygamma_n_0'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('polygamma', 'polygamma_n_1'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('polygamma', 'polygamma_n_2'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('polygamma', 'polygamma_n_3'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('polygamma', 'polygamma_n_4'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend. | 
 |     xfail('qr', ''),  # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('renorm', ''),  # aten.renorm.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('repeat_interleave', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides | 
 |     xfail('resize_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('resize_as_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('roll', ''),  # Tensors of type TensorImpl do not have numel | 
 |     xfail('searchsorted', ''),  # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ... | 
 |     xfail('_segment_reduce', 'offsets'),  # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('special.airy_ai', ''),  # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('special.bessel_y0', ''),  # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('special.bessel_y1', ''),  # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('special.chebyshev_polynomial_t', ''),  # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me... | 
 |     xfail('special.chebyshev_polynomial_u', ''),  # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me... | 
 |     xfail('special.hermite_polynomial_h', ''),  # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f... | 
 |     xfail('special.hermite_polynomial_he', ''),  # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta... | 
 |     xfail('special.laguerre_polynomial_l', ''),  # aten.special_laguerre_polynomial_l.default - couldn't find symbolic meta... | 
 |     xfail('special.modified_bessel_i0', ''),  # aten.special_modified_bessel_i0.default - couldn't find symbolic meta funct... | 
 |     xfail('special.modified_bessel_i1', ''),  # aten.special_modified_bessel_i1.default - couldn't find symbolic meta funct... | 
 |     xfail('special.modified_bessel_k0', ''),  # aten.special_modified_bessel_k0.default - couldn't find symbolic meta funct... | 
 |     xfail('special.modified_bessel_k1', ''),  # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct... | 
 |     xfail('special.polygamma', 'special_polygamma_n_0'),  # aten.polygamma.default - couldn't find symbolic meta function/... | 
 |     xfail('special.scaled_modified_bessel_k0', ''),  # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... | 
 |     xfail('special.scaled_modified_bessel_k1', ''),  # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... | 
 |     xfail('stft', ''),  # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at... | 
 |     xfail('svd_lowrank', ''),  # aten.mm.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('take_along_dim', ''),  # dtype of indices should be Long but got Float | 
 |     xfail('tensordot', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('trapz', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('trapezoid', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('triangular_solve', ''),  # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('vsplit', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition | 
 | } | 
 | symbolic_tensor_segfaults = { | 
 |     skip('nn.functional.batch_norm')  # Segfault?? | 
 | } | 
 |  | 
 | symbolic_tensor_failures.update(symbolic_tensor_segfaults) | 
 |  | 
 | outplace_symbolic_tensor_failures = { | 
 |     xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('masked_scatter', ''),  # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition | 
 |     xfail('nn.functional.rrelu', ''),  # aten.empty_like.default - couldn't find symbolic meta function/decomposition | 
 | } | 
 |  | 
 | inplace_symbolic_tensor_failures = { | 
 |     # bugs | 
 |     xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double | 
 |     # decomp not implemented | 
 |     xfail('unique', ''), | 
 |     # in-place has a different signature than out-of-place | 
 |     xfail('uniform', ''), | 
 |     # Views | 
 |     xfail('t', ''), | 
 |     xfail('transpose', ''), | 
 | } | 
 |  | 
 | # Copies inputs to inplace operations to avoid inplace modifications | 
 | #   to leaves requiring gradient | 
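#   (e.g. _get_safe_inplace(torch.Tensor.add_)(t, 1) runs add_ on t.clone(), leaving t untouched)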
 | def _get_safe_inplace(inplace_variant): | 
 |     @functools.wraps(inplace_variant) | 
 |     def _fn(t, *args, **kwargs): | 
 |         return inplace_variant(t.clone(), *args, **kwargs) | 
 |  | 
 |     return _fn | 
 |  | 
 | def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False): | 
 |     def f(args, kwargs, extra_args, extra_kwargs): | 
 |         if extra_args: | 
 |             for i, t in extra_args: | 
 |                 args[i] = t.size() | 
 |         if extra_kwargs: | 
 |             for k, t in extra_kwargs.items(): | 
 |                 kwargs[k] = t.size() | 
 |  | 
 |         fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op | 
 |         return fn(*args, **kwargs) | 
 |     sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) | 
 |     new_f = None | 
 |  | 
 |     # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long | 
 |     for sample_input in itertools.islice(sample_inputs_itr, 100): | 
 |         if inplace and sample_input.broadcasts_input: | 
 |             continue | 
 |         args = [sample_input.input] + list(sample_input.args) | 
 |         kwargs = sample_input.kwargs | 
 |  | 
 |         # If any argument is a torch.Size(), maybe get dynamic shapes for it by: | 
 |         # - Create a temporary Tensor whose size is the torch.Size() we want. Note that | 
 |         #   we use an expanded Tensor as we cannot pass "meta" Tensors to make_fx. | 
        # - Pass it to make_fx such that it is converted to a proxy Tensor
 |         # - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in | 
 |         #   symbolic mode, a no-op otherwise) | 
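        # e.g. an arg of torch.Size([3, 4]) becomes torch.empty(3, 4) here and is turned back
        # into a (possibly symbolic) torch.Size via t.size() inside f above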
 |         extra_args = [] | 
 |         extra_kwargs = {} | 
 |         for i, arg in enumerate(args): | 
 |             if isinstance(arg, torch.Size): | 
 |                 extra_args.append((i, torch.empty(arg, device="cpu"))) | 
 |         for key, value in kwargs.items(): | 
 |             if isinstance(value, torch.Size): | 
 |                 extra_kwargs[key] = torch.empty(value, device="cpu") | 
 |  | 
 |         try: | 
 |             new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args, extra_kwargs) | 
        except DynamicOutputShapeException:
 |             self.skipTest("Dynamic output shape operation in trace") | 
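        # Re-randomize float inputs after tracing, presumably to check that the traced graph
        # reads values from its inputs rather than baking in the original data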
 |         for arg in args: | 
 |             if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: | 
 |                 arg.uniform_(0, 1) | 
 |         try: | 
 |             old_out = f(args, kwargs, extra_args, extra_kwargs) | 
 |         except Exception: | 
 |             continue | 
 |         new_out = wrapper_set_seed(new_f, args, kwargs, extra_args, extra_kwargs) | 
 |         self.assertEqual(new_out, old_out) | 
 |  | 
 | class TestProxyTensorOpInfo(TestCase): | 
 |     @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) | 
 |     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures) | 
 |     def test_make_fx_exhaustive(self, device, dtype, op): | 
 |         _test_make_fx_helper(self, device, dtype, op, "real") | 
 |  | 
 |     @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) | 
 |     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive', make_fx_failures.union(fake_tensor_failures)) | 
 |     def test_make_fx_fake_exhaustive(self, device, dtype, op): | 
 |         _test_make_fx_helper(self, device, dtype, op, "fake") | 
 |  | 
 |     @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) | 
 |     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', | 
 |              make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) | 
 |     def test_make_fx_symbolic_exhaustive(self, device, dtype, op): | 
 |         _test_make_fx_helper(self, device, dtype, op, "symbolic") | 
 |  | 
 |     @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,)) | 
 |     @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace', | 
 |              make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures) | 
 |     def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op): | 
 |         if not op.get_inplace(): | 
 |             self.skipTest("No inplace variable for this op") | 
 |         _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True) | 
 |  | 
 |  | 
 | only_for = ("cpu") | 
 | instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for) | 
 |  | 
 |  | 
 | if __name__ == '__main__': | 
 |     run_tests() |