| import torch | 
 | import unittest | 
 | import operator | 
 | import numbers | 
 | import pickle | 
 | import copy | 
 | import sys | 
 | import functools | 
 | import contextlib | 
 | from pathlib import Path | 
 | from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph | 
 | from torch.fx.experimental import GraphManipulation | 
 | from torch.fx.experimental import shape_prop | 
 | from torch.fx.experimental.subgraph_creation_example import split_module | 
 | from torch.fx.immutable_collections import immutable_dict, immutable_list | 
 | from copy import deepcopy | 
 |  | 
 | from torch.fx.proxy import TraceError | 
 |  | 
 | from fx.quantization import Quantizer | 
 |  | 
| from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union |
 | from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS | 
 | from torch.testing._internal.jit_utils import JitTestCase | 
 |  | 
 | try: | 
 |     from torchvision.models import resnet18 | 
 |     HAS_TORCHVISION = True | 
 | except ImportError: | 
 |     HAS_TORCHVISION = False | 
 | skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") | 
 |  | 
 | class SimpleTest(torch.nn.Module): | 
 |     def forward(self, x): | 
 |         return torch.relu(x + 3.0) | 
 |  | 
 | def a_non_torch_leaf(a, b): | 
 |     return a + b | 
 |  | 
 | class Pair(NamedTuple): | 
 |     x : torch.Tensor | 
 |     y : torch.Tensor | 
 |  | 
 | class TestFX(JitTestCase): | 
 |     def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None): | 
 |         """Check that an nn.Module's results match the GraphModule version | 
 |         for a given set of args/kwargs. | 
 |         """ | 
 |         kwargs = kwargs if kwargs else {} | 
 |         ref_outs = m(*args, **kwargs) | 
 |         gm = symbolic_trace(m) | 
 |         gm.graph.lint(gm) | 
 |         test_outs = gm(*args, **kwargs) | 
 |         self.assertEqual(ref_outs, test_outs) | 
 |  | 
 |     def test_graph_module(self): | 
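|         # Trace modules with nested submodules, Parameters, and a mix of tensor |
|         # operators/methods, check the results can be TorchScript-compiled, and |
|         # exercise default-arg / *args / **kwargs forward signatures. |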
 |         class MySub(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.w = torch.nn.Parameter(torch.rand(4, 3)) | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.w + x | 
 |  | 
 |         class MyModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.lin = torch.nn.Linear(4, 3) | 
 |                 self.sub_mod = MySub() | 
 |                 self.w = torch.nn.Parameter(torch.rand(3)) | 
 |  | 
 |             def forward(self, A, B, c): | 
 |                 t = torch.sigmoid(A) + self.lin(c) | 
 |                 return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3)) | 
 |  | 
 |         m = MyModule() | 
 |         gm = symbolic_trace(m) | 
 |  | 
 |         ms = torch.jit.script(gm) | 
 |  | 
 |         class M2(torch.nn.Module): | 
 |             def forward(self, A): | 
 |                 m, idx = torch.max(A, 0) | 
 |                 return m + 1, idx + 1 | 
 |  | 
 |         m2 = M2() | 
 |         gm2 = symbolic_trace(m2) | 
 |  | 
 |         class T(torch.nn.Module): | 
 |  | 
 |             def forward(self, A, b=4, *args, c=5, **kwargs): | 
 |                 x = A + 1 + args[0] + kwargs['3'] | 
 |                 return x | 
 |  | 
 |         t = T() | 
 |         symbolic_trace(t) | 
 |  | 
 |     def test_custom_import(self): | 
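|         # Build a Graph by hand that calls a free function defined at module |
|         # scope (a_non_torch_leaf); the generated GraphModule must resolve and run it. |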
 |         graph = torch.fx.Graph() | 
 |         a = graph.placeholder('x') | 
 |         b = graph.placeholder('y') | 
 |         c = graph.call_function(a_non_torch_leaf, (a, b)) | 
 |         d = graph.call_function(torch.sin, (c,)) | 
 |         graph.output(d) | 
 |         gm = GraphModule(torch.nn.Module(), graph) | 
 |         x, y = torch.rand(1), torch.rand(1) | 
 |         self.assertEqual(torch.sin(x + y), gm(x, y)) | 
 |  | 
 |     def test_args_kwargs(self): | 
 |         class T(torch.nn.Module): | 
 |             def forward(self, *args, **kwargs): | 
 |                 x = args[0] + kwargs['foo'] | 
 |                 return x | 
 |  | 
 |         t = T() | 
 |         self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) | 
 |  | 
 |     def test_args_kwargs_no_self(self): | 
 |         class T(torch.nn.Module): | 
 |             def forward(*args, **kwargs):  # noqa: B902 | 
 |                 self = args[0] | 
 |                 return torch.relu(args[1]) | 
 |  | 
 |         t = T() | 
 |         with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'): | 
 |             self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)}) | 
 |  | 
 |     def test_fx_shifts(self): | 
 |         class MyModule(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return x << 3, x >> 3 | 
 |  | 
 |         input = torch.LongTensor(10).random_(0, 1024) | 
 |  | 
 |         m = MyModule() | 
 |         self.checkGraphModule(m, (input,)) | 
 |  | 
 |     def test_dict(self): | 
 |         class MyDictMod(torch.nn.Module): | 
 |             def forward(self, d): | 
 |                 return d['3'].relu(), {'4' : d['3'].neg()} | 
 |  | 
 |         input_dict = {'3': torch.rand(3, 4)} | 
 |         m = MyDictMod() | 
 |  | 
 |         self.checkGraphModule(m, (input_dict,)) | 
 |  | 
 |     def test_disallow_override(self): | 
|         # Custom tracer that disallows in-place tensor operations |
 |         class NoMutableCallTracer(Tracer): | 
 |             def create_node(self, kind : str, target : Union[str, Callable], | 
 |                             args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, | 
 |                             type_expr : Optional[Any] = None) -> Node: | 
 |                 name = target if isinstance(target, str) else torch.typename(target) | 
 |                 if name[-1] == '_': | 
 |                     raise RuntimeError('In-place operations are not supported') | 
 |                 return super().create_node(kind, target, args, kwargs, name) | 
 |  | 
 |         # Test method | 
 |         class MyInplaceMod(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 x.add_(3.0) | 
 |                 return x | 
 |  | 
 |         m = MyInplaceMod() | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, 'In-place operations'): | 
 |             NoMutableCallTracer().trace(m) | 
 |  | 
 |         # Test free function | 
 |         class MyInplaceMod2(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 torch.log_(x) | 
 |                 return x | 
 |         m2 = MyInplaceMod2() | 
 |         with self.assertRaisesRegex(RuntimeError, 'In-place operations'): | 
 |             NoMutableCallTracer().trace(m2) | 
 |  | 
 |         # Test symbolic node as an arg | 
 |         class MyInplaceMod3(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 y = torch.ones(3, 4) | 
 |                 y.add_(x) | 
 |                 return x | 
 |         m3 = MyInplaceMod3() | 
 |         with self.assertRaisesRegex(RuntimeError, 'In-place operations'): | 
 |             NoMutableCallTracer().trace(m3) | 
 |  | 
 |     def test_leaf_module(self): | 
|         # Custom tracer that treats no modules as leaves, so every module |
|         # call gets traced through rather than recorded as a call_module node |
 |         class NoLeafModulesTracer(Tracer): | 
 |             def is_leaf_module(self, m, qualname): | 
 |                 return False | 
 |  | 
 |         class MyReluMod(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.relu = torch.nn.ReLU() | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.relu(x) | 
 |  | 
 |         mrm = MyReluMod() | 
 |         sym = NoLeafModulesTracer().trace(mrm) | 
 |         for node in sym.nodes: | 
 |             self.assertNotEqual(node.op, 'call_module') | 
 |         sym.lint(sym) | 
 |  | 
 |     def test_graph_edit_with_proxy(self): | 
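|         # Copy a traced graph into a fresh Graph, then wrap the copied output in a |
|         # Proxy to append more operations before registering a new output. |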
 |         class M(torch.nn.Module): | 
 |             def forward(self, a, b): | 
 |                 return a + b | 
 |         m = M() | 
 |         g = symbolic_trace(m).graph | 
 |         new_g = torch.fx.Graph() | 
 |         val_map : Dict[Node, Node] = {} | 
 |         output_val = new_g.graph_copy(g, val_map) | 
 |         t = Proxy(output_val) | 
|         # Test that we can use Proxy objects to generate more graph code |
|         # later for things that do not need to work with modules. |
 |         new_g.output((t + t).node) | 
 |         gm = GraphModule(m, new_g) | 
 |         gm.graph.lint(gm) | 
 |         self.assertEqual(gm(3, 4), 14) | 
 |  | 
 |     def test_graph_unique_names(self): | 
 |         class M(torch.nn.Module): | 
 |             def forward(self, a, b): | 
 |                 return a + b | 
 |         m = M() | 
 |         g = symbolic_trace(m).graph | 
 |         new_g = torch.fx.Graph() | 
 |         val_map : Dict[Node, Node] = {} | 
 |         output_val = new_g.graph_copy(g, val_map) | 
 |         t = Proxy(output_val) | 
|         # Test that we can use Proxy objects to generate more graph code |
|         # later for things that do not need to work with modules. |
 |         new_g.output((t + t).node) | 
 |         gm = GraphModule(m, new_g) | 
 |         seen_names : Set[str] = set() | 
 |         for node in gm.graph.nodes: | 
 |             assert node.name not in seen_names | 
 |             seen_names.add(node.name) | 
 |  | 
 |     def test_graph_unique_names_manual(self): | 
 |         graph : torch.fx.Graph = torch.fx.Graph() | 
 |         a : torch.fx.Node = graph.create_node('placeholder', 'x') | 
 |         b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1') | 
 |         c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1') | 
 |         d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) | 
 |         graph.output(d) | 
 |         graph2 = torch.fx.Graph() | 
 |         val_map : Dict[Node, Node] = {} | 
 |         graph2.graph_copy(graph, val_map) | 
 |         seen_names : Set[str] = set() | 
 |         for node in graph2.nodes: | 
 |             assert node.name not in seen_names | 
 |             seen_names.add(node.name) | 
 |  | 
 |     @skipIfNoTorchVision | 
 |     def test_resnet(self): | 
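|         # End-to-end check on torchvision's resnet18: trace, script, and quantize |
|         # the traced module, comparing outputs against the eager model. |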
 |         resnet = resnet18() | 
 |         resnet.train() | 
 |  | 
 |         res_graph = symbolic_trace(resnet) | 
 |         res_script = torch.jit.script(res_graph) | 
 |  | 
 |         ip = torch.rand(1, 3, 224, 224) | 
 |  | 
 |         a = resnet(ip) | 
 |         b = res_graph(ip) | 
 |         c = res_script(ip) | 
 |         self.assertEqual(a, b) | 
 |         self.assertEqual(a, c) | 
 |  | 
 |         quantizer = Quantizer(res_graph) | 
 |  | 
 |         for i in range(10): | 
 |             quantizer.observe((torch.rand(1, 3, 224, 224),)) | 
 |  | 
 |         qgraph = quantizer.quantize() | 
 |         qgraph.graph.lint(qgraph) | 
 |         qgraph_script = torch.jit.script(qgraph) | 
 |  | 
 |         d = qgraph(ip) | 
 |         e = qgraph_script(ip) | 
 |  | 
 |         assert (a - d).abs().max() < 2 | 
 |         self.assertEqual(d, e) | 
 |  | 
 |     def test_unpack(self): | 
 |         class M(torch.nn.Module): | 
 |             def forward(self, a, b): | 
 |                 c, d = a | 
 |                 return c + d + b | 
 |  | 
 |         a = (torch.rand(1), torch.rand(1)) | 
 |         b = torch.rand(1) | 
 |         m = M() | 
 |         self.checkGraphModule(m, (a, b)) | 
 |  | 
 |     def test_native_callable(self): | 
 |         if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS: | 
 |             raise unittest.SkipTest("non-portable load_library call used in test") | 
 |         torch_root = Path(__file__).resolve().parent.parent | 
 |         p = torch_root / 'build' / 'lib' / 'libtorchbind_test.so' | 
 |         torch.ops.load_library(str(p)) | 
 |         # This test exercises the case where we use FX to translate from Python | 
 |         # code to some native callable object | 
 |         # | 
 |         # For the purposes of testing, we use ElementwiseInterpreter defined | 
 |         # in test_custom_class.cpp. | 
 |         # | 
 |         # We test that we can | 
 |         # 1) Construct a native callable from FX IR | 
 |         # 2) Construct a drop-in replacement module that delegates to the | 
 |         #    native callable rather than the original code | 
 |         # 3) Run both the original code and native callable wrapper with | 
 |         #    equivalent results | 
 |         # 4) TorchScript compile the native callable wrapper and confirm | 
 |         #    equivalent results with the reference | 
 |         # 5) TorchScript serialize and deserialize the native callable | 
 |         #    and confirm equivalent results with the reference | 
 |  | 
 |         # We use this simple Module as a reference computation | 
 |         class MySimpleMod(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return 3.0 * x + x | 
 |  | 
 |         msm = MySimpleMod() | 
 |  | 
 |         # This is what a lowering pass might look like: a function that takes | 
 |         # a valid nn.Module, symbolically traces it, lowers the Module to some | 
 |         # representation, and wraps that representation up into another | 
 |         # nn.Module instance that handles dispatch to the compiled/lowered code. | 
 |         def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module: | 
 |             # ===== Stage 1: Symbolic trace the module ===== | 
 |             mod = symbolic_trace(orig_mod) | 
 |  | 
 |             # ===== Stage 2: Lower GraphModule representation to the C++ | 
 |             #       interpreter's instruction format ====== | 
 |             instructions = [] | 
 |             constant_idx = 0 | 
 |             constants = {} | 
 |             fn_input_names = [] | 
 |  | 
 |             target_to_name = { | 
 |                 operator.add : "add", | 
 |                 operator.mul : "mul" | 
 |             } | 
 |  | 
 |             output_node : Optional[Node] = None | 
 |             # For each instruction, create a triple | 
 |             # (instruction_name : str, inputs : List[str], output : str) | 
 |             # to feed into the C++ interpreter | 
 |             for n in mod.graph.nodes: | 
 |                 target, args, out_name = n.target, n.args, n.name | 
 |                 assert len(n.kwargs) == 0, "kwargs currently not supported" | 
 |  | 
 |                 if n.op == 'placeholder': | 
 |                     # Placeholders specify function argument names. Save these | 
 |                     # for later when we generate the wrapper GraphModule | 
 |                     fn_input_names.append(target) | 
 |                 elif n.op == 'call_function': | 
|                     assert target in target_to_name, f"Unsupported call target {target}" |
 |                     arg_names = [] | 
 |                     for arg in args: | 
 |                         if not isinstance(arg, Node): | 
 |                             # Pull out constants. These constants will later be | 
 |                             # fed to the interpreter C++ object via add_constant() | 
 |                             arg_name = f'constant_{constant_idx}' | 
 |                             constants[arg_name] = torch.Tensor( | 
 |                                 [arg] if isinstance(arg, numbers.Number) else arg) | 
 |                             arg_names.append(arg_name) | 
 |                             constant_idx += 1 | 
 |                         else: | 
 |                             arg_names.append(arg.name) | 
 |                     instructions.append((target_to_name[target], arg_names, out_name)) | 
 |                 elif n.op == 'output': | 
 |                     if output_node is not None: | 
 |                         raise RuntimeError('Multiple output nodes!') | 
 |                     output_node = n | 
 |                 else: | 
 |                     raise RuntimeError('Unsupported opcode ' + n.op) | 
 |  | 
 |             interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter() | 
 |             # Load constants | 
 |             for k, v in constants.items(): | 
 |                 interpreter.add_constant(k, v) | 
 |             # Specify names for positional input arguments | 
 |             interpreter.set_input_names(fn_input_names) | 
 |             # Load instructions | 
 |             interpreter.set_instructions(instructions) | 
 |             # Specify name for single output | 
|             assert output_node is not None, 'Graph has no output node' |
|             assert isinstance(output_node.args[0], torch.fx.Node) |
 |             interpreter.set_output_name(output_node.args[0].name) | 
 |  | 
 |             # ===== Stage 3: Create a wrapper GraphModule around the interpreter ===== | 
 |             class WrapperModule(torch.nn.Module): | 
 |                 def __init__(self, interpreter): | 
 |                     super().__init__() | 
 |                     self.interpreter = interpreter | 
 |  | 
 |             wrapper = WrapperModule(interpreter) | 
 |  | 
 |             # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter | 
|             # 3) Returns the specified return value |
 |  | 
 |             # FIXME: The following code could be greatly simplified by symbolic_trace'ing | 
 |             # the wrapper with a Tracer that considers the Wrapper instance a root | 
 |             # module, however, I can't get `__call__` exposed on TorchBind classes | 
 |             # without it messing up Python `hasattr` for some reason. More digging | 
 |             # into CPython's implementation of hasattr is probably in order... | 
 |  | 
 |             graph = torch.fx.Graph() | 
 |             # Add placeholders for fn inputs | 
 |             placeholder_nodes = [] | 
 |             for name in fn_input_names: | 
 |                 placeholder_nodes.append(graph.create_node('placeholder', name)) | 
 |  | 
 |             # Get the interpreter object | 
 |             interpreter_node = graph.create_node('get_attr', 'interpreter') | 
 |  | 
 |             # Add a node to call the interpreter instance | 
 |             output_node = graph.create_node( | 
 |                 op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes)) | 
 |  | 
 |             # Register output | 
 |             graph.output(output_node) | 
 |  | 
 |             graph.lint(wrapper) | 
 |  | 
 |             # Return final GraphModule!!! | 
 |             return GraphModule(wrapper, graph) | 
 |  | 
 |  | 
 |         # Lower GraphModule to C++ interpreter | 
 |         lowered = lower_to_elementwise_interpreter(msm) | 
 |  | 
 |         # Compare correctness with original module | 
 |         x = torch.rand(3, 4) | 
 |         ref_out = msm(x) | 
 |         test_out = lowered(x) | 
 |         torch.testing.assert_allclose(test_out, ref_out) | 
 |  | 
 |         # Test TorchScript compilation | 
 |         scripted_lowered = torch.jit.script(lowered) | 
 |         script_out = scripted_lowered(x) | 
 |         torch.testing.assert_allclose(script_out, ref_out) | 
 |  | 
 |         # Test TorchScript ser/de | 
 |         import_copy = self.getExportImportCopy(scripted_lowered) | 
 |         imported_out = import_copy(x) | 
 |         torch.testing.assert_allclose(imported_out, ref_out) | 
 |  | 
 |     def test_reserved_getattr(self): | 
 |         """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" | 
 |         class M(torch.nn.Module): | 
 |             def forward(self, a): | 
 |                 return a.foo.bar.baz | 
 |  | 
 |         m = M() | 
 |         m_g = symbolic_trace(m) | 
 |         m_g.graph.lint(m_g) | 
 |         for node in m_g.graph.nodes: | 
 |             self.assertTrue(node.name != "getattr") | 
 |  | 
 |     def test_node_tagging(self): | 
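|         # A Tracer subclass can attach extra metadata to every Node it creates; |
|         # verify the tags survive tracing. |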
 |         class TaggingTracer(Tracer): | 
 |             def create_node(self, kind : str, target : Union[str, Callable], | 
 |                             args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None, | 
 |                             type_expr : Optional[Any] = None) -> Node: | 
 |                 n = super().create_node(kind, target, args, kwargs, name) | 
 |                 n.tag = 'foo' | 
 |                 return n | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def forward(self, a, b): | 
 |                 return a + b | 
 |  | 
 |         m = M() | 
 |         g = TaggingTracer().trace(m) | 
 |         g.lint(m) | 
 |         for n in g.nodes: | 
 |             self.assertTrue(hasattr(n, 'tag')) | 
 |             self.assertEqual(n.tag, 'foo') | 
 |  | 
 |     def test_tensor_attribute(self): | 
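|         # Plain tensor attributes (not Parameters or buffers) must be traceable, |
|         # including when accessed through a submodule's qualified name. |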
 |         class TensorAttribute(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.tensor = torch.rand(3, 4) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.nn.functional.linear(x, self.tensor) | 
 |  | 
 |         ta = TensorAttribute() | 
 |         traced = symbolic_trace(ta) | 
 |         traced(torch.rand(4, 4)) | 
 |  | 
 |         class WrapperForQualname(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.ta = TensorAttribute() | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.nn.functional.linear(x, self.ta.tensor) | 
 |  | 
 |         wfq = WrapperForQualname() | 
 |         traced2 = symbolic_trace(wfq) | 
 |         traced2.graph.lint(traced2) | 
 |         traced2(torch.rand(4, 4)) | 
 |  | 
 |     def test_symbolic_trace_sequential(self): | 
 |         class Simple(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.neg(x) | 
 |  | 
 |         seq = torch.nn.Sequential( | 
 |             Simple(), | 
 |             Simple(), | 
 |             Simple() | 
 |         ) | 
 |         traced = symbolic_trace(seq) | 
 |         traced.graph.lint(traced) | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(traced(x), seq(x)) | 
 |  | 
 |     def test_tensor_constant(self): | 
 |         class ConstTensor(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.nn.functional.linear(x, torch.zeros(3, 4)) | 
 |  | 
 |         ct = ConstTensor() | 
 |         traced = symbolic_trace(ct) | 
 |         traced.graph.lint(traced) | 
 |         traced(torch.rand(4, 4)) | 
 |  | 
 |     def test_pickle_graphmodule(self): | 
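|         # GraphModules must round-trip through pickle and still produce the same results. |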
 |         class Nested(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.st = torch.nn.Linear(4, 4) | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.st(x) | 
 |  | 
 |         n = Nested() | 
 |         traced = symbolic_trace(n) | 
 |         traced.graph.lint(traced) | 
 |         pickled = pickle.dumps(traced) | 
 |         loaded = pickle.loads(pickled) | 
 |         loaded.graph.lint(loaded) | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(loaded(x), traced(x)) | 
 |  | 
 |     def test_deepcopy_graphmodule_with_transform(self): | 
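|         # Copy the traced graph into a new one, append a 'neg' call, wrap it in a |
|         # new GraphModule, then deepcopy and check the copy behaves identically. |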
 |         st = SimpleTest() | 
 |         traced = symbolic_trace(st) | 
 |         traced.graph.lint(traced) | 
 |  | 
 |         def transform(traced): | 
 |             new_graph = torch.fx.Graph() | 
 |             val_map : Dict[Node, Node] = {} | 
 |             output_value = new_graph.graph_copy(traced.graph, val_map) | 
 |             relu_out = new_graph.create_node( | 
 |                 op='call_method', target='neg', args=(output_value,), kwargs={}) | 
 |             new_graph.output(relu_out) | 
 |             return GraphModule(traced, new_graph) | 
 |         transformed = transform(traced) | 
 |         transformed.graph.lint(transformed) | 
 |         copied = copy.deepcopy(transformed) | 
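|         # GraphModule generates a class per instance, so the copy is expected to have its own type |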
 |         self.assertNotEqual(id(type(transformed)), id(type(copied))) | 
 |         x = torch.randn(3, 4) | 
 |         self.assertEqual(copied(x), transformed(x)) | 
 |  | 
 |     def test_deepcopy_with_submods_params(self): | 
 |         class Bar(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(3, 4)) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.relu(x) + self.param | 
 |  | 
 |         class Baz(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(3, 4)) | 
 |                 self.bar = Bar() | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.bar(x) - self.param | 
 |  | 
 |         baz = Baz() | 
 |         traced = symbolic_trace(baz) | 
 |         traced.graph.lint(traced) | 
 |         copied = copy.deepcopy(traced) | 
 |         copied.graph.lint(copied) | 
 |  | 
 |     def test_unpack_list_better_error(self): | 
 |         class SomeArgs(torch.nn.Module): | 
 |             def forward(self, a, b): | 
 |                 return torch.rand(3, 4) | 
 |  | 
 |         class UnpacksList(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.sa = SomeArgs() | 
 |  | 
 |             def forward(self, x : list): | 
 |                 return self.sa(*x) | 
 |  | 
 |         ul = UnpacksList() | 
 |         with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): | 
 |             symbolic_trace(ul) | 
 |  | 
 |     def test_unpack_dict_better_error(self): | 
 |         class SomeKwargs(torch.nn.Module): | 
 |             def forward(self, x=3, y=4): | 
 |                 return torch.rand(3, 4) | 
 |  | 
 |         class UnpacksDict(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.sk = SomeKwargs() | 
 |  | 
 |             def forward(self, x : dict): | 
 |                 return self.sk(**x) | 
 |  | 
 |         ud = UnpacksDict() | 
 |         with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): | 
 |             symbolic_trace(ud) | 
 |  | 
 |     def test_torch_custom_ops(self): | 
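|         # Direct calls into the torch.ops.aten namespace should trace and replay correctly. |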
 |         class M(torch.nn.Module): | 
 |             def forward(self, a): | 
 |                 b = torch.ops.aten.sigmoid(a) | 
 |                 c = torch.ops.aten.cat([a, b]) | 
 |                 return torch.ops.aten.cat((c, c)) | 
 |         m = M() | 
 |         input = torch.randn(3) | 
 |         ref_out = m(input) | 
 |         gm = symbolic_trace(m) | 
 |         gm.graph.lint(gm) | 
 |         out = gm(input) | 
 |         self.assertEqual(out, ref_out) | 
 |  | 
 |     def test_replace_target_nodes_with(self): | 
 |         class testModule(torch.nn.Module): | 
 |             def forward(self, a, b): | 
 |                 return a + b | 
 |         m = testModule() | 
 |         traced = symbolic_trace(m) | 
 |         input1 = torch.randn(1) | 
 |         input2 = torch.randn(1) | 
 |         assert (input1 + input2) == traced(input1, input2) | 
 |         GraphManipulation.replace_target_nodes_with( | 
 |             fx_module=traced, | 
 |             old_op="call_function", | 
 |             old_target=operator.add, | 
 |             new_op="call_function", | 
 |             new_target=operator.mul, | 
 |         ) | 
 |         assert (input1 * input2) == traced(input1, input2) | 
 |  | 
 |     def test_pretty_print(self): | 
 |         st = SimpleTest() | 
 |         traced = symbolic_trace(st) | 
 |         traced.graph.lint(traced) | 
 |         printed = str(traced) | 
 |         assert 'GraphModuleImpl()' in printed | 
 |         assert 'torch.relu' in printed | 
 |  | 
 |     def test_pretty_print_graph(self): | 
 |         class KwargPrintTest(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.squeeze(x + 3.0, dim=2) | 
 |         st = KwargPrintTest() | 
 |         traced = symbolic_trace(st) | 
 |         traced.graph.lint(traced) | 
 |         stringed = str(traced.graph) | 
 |         for s in ['args', 'kwargs', '#users']: | 
 |             assert s in stringed | 
 |  | 
 |     def test_graph_fns(self): | 
 |         g = Graph() | 
 |         a = g.placeholder('a') | 
 |         b = g.call_module('linear', (a,)) | 
 |         c = g.get_attr('bias') | 
 |         d = g.call_method('add', (b, c)) | 
 |         e = g.call_function(torch.sin, (d,)) | 
 |         g.output(e) | 
 |         mod = torch.nn.Module() | 
 |         mod.linear = torch.nn.Linear(3, 4) | 
 |         mod.bias = torch.rand(4) | 
 |         gm = GraphModule(mod, g) | 
 |         gm.graph.lint(gm) | 
 |         input = torch.rand(3) | 
 |         r = gm(input) | 
 |         ref = torch.sin(mod.linear(input) + mod.bias) | 
 |         self.assertEqual(r, ref) | 
 |  | 
 |     def test_construct_root_dict(self): | 
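|         # GraphModule can be constructed from a dict mapping qualified names to |
|         # submodules/tensors instead of an existing root Module. |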
 |         graph : torch.fx.Graph = torch.fx.Graph() | 
 |         a : torch.fx.Node = graph.create_node('placeholder', 'x') | 
 |         b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) | 
 |         c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') | 
 |         d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) | 
 |         graph.output(d) | 
 |  | 
 |         linear_mod : torch.nn.Module = torch.nn.Linear(3, 4) | 
 |         add_param : torch.Tensor = torch.rand(3, 4) | 
 |         gm : torch.fx.GraphModule = torch.fx.GraphModule( | 
 |             {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph) | 
 |         gm.graph.lint(gm) | 
 |  | 
 |         assert 'self.foo.bar.baz' in gm.code | 
 |  | 
 |         x : torch.Tensor = torch.rand(3, 3) | 
 |         out : torch.Tensor = gm(x) | 
 |         ref_out : torch.Tensor = linear_mod(x) + add_param | 
 |         self.assertEqual(out, ref_out) | 
 |  | 
 |     def test_symbolic_trace_assert(self): | 
 |         message = "assert_foobar" | 
 |  | 
 |         class AssertsTensorShape(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 torch.Assert(x.shape[1] > 4, message) | 
 |                 return x | 
 |  | 
 |         m = AssertsTensorShape() | 
 |         # verify traceability | 
 |         traced = symbolic_trace(m) | 
 |         # verify assertion on traced model works correctly at runtime | 
 |         traced(torch.rand(4, 5)) | 
 |         with self.assertRaisesRegex(AssertionError, message): | 
 |             traced(torch.rand(4, 3)) | 
 |  | 
 |     def test_copy_no_remap(self): | 
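|         # Copying nodes without an argument-remapping function leaves args pointing |
|         # at the old graph, which lint() should reject. |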
 |         traced = symbolic_trace(SimpleTest()) | 
 |         g = traced.graph | 
 |         copied = torch.fx.Graph() | 
 |         for node in g.nodes: | 
 |             copied.node_copy(node) | 
 |         with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'): | 
 |             copied.lint() | 
 |  | 
 |     def test_wrong_topo(self): | 
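|         # Moving a node after one of its users breaks topological order; |
|         # lint() should flag the use-before-definition. |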
 |         graph : torch.fx.Graph = torch.fx.Graph() | 
 |         a : torch.fx.Node = graph.create_node('placeholder', 'x') | 
 |         b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) | 
 |         c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') | 
 |         d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) | 
 |         graph.output(d) | 
 |         nodes = list(graph.nodes) | 
 |         nodes[3].append(nodes[2]) | 
 |         with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'): | 
 |             graph.lint() | 
 |  | 
 |     def test_example_shape_prop(self): | 
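|         # ShapeProp runs a sample input through the traced module and records a shape |
|         # on every node; check the recorded output shape matches the real one. |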
 |         class TestCase(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.attr = torch.randn(3, 4) | 
 |                 self.submod = torch.nn.Linear(4, 4) | 
 |  | 
 |             def forward(self, x): | 
 |                 return torch.neg(self.submod(x.relu() + self.attr)) | 
 |         tc = TestCase() | 
 |         tc_traced = symbolic_trace(tc) | 
 |         ref_out = tc_traced(torch.rand(3, 4)) | 
 |         shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4)) | 
 |  | 
 |         # Make sure we're testing all opcodes | 
 |         opcodes = set() | 
|         output_shape : Optional[torch.Size] = None |
 |         for node in tc_traced.graph.nodes: | 
 |             opcodes.add(node.op) | 
 |             if node.op == 'output': | 
 |                 output_shape = node.args[0].shape | 
 |         self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method', | 
 |                                        'call_module', 'output'])) | 
 |  | 
|         # Test shape propagation and make sure the propagated output shape matches the actual result |
 |         self.assertEqual(output_shape, ref_out.shape) | 
 |  | 
 |     def test_fn_type_annotations(self): | 
 |         class Foo(torch.nn.Module): | 
 |             def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]: | 
 |                 return {'a': p.x + p.y + z + i} | 
 |  | 
 |         foo_scripted = torch.jit.script(Foo()) | 
 |         foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) | 
 |  | 
 |         fxed = symbolic_trace(Foo()) | 
 |         fxed_scripted = torch.jit.script(fxed) | 
 |         fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3) | 
 |  | 
 |     def test_fn_type_annotation_empty(self): | 
 |         def forward(a : List[torch.Tensor]): | 
 |             return a[0] | 
 |         torch.jit.script(symbolic_trace(forward)) | 
 |  | 
 |     def test_wrapped_method(self): | 
 |         def wrap_with_relu(fn): | 
 |             @functools.wraps(fn) | 
 |             def wrapper(*args, **kwargs): | 
 |                 return torch.relu(fn(*args, **kwargs)) | 
 |             return wrapper | 
 |  | 
 |         class Foo(torch.nn.Module): | 
 |             @wrap_with_relu | 
 |             def forward(self, x, w): | 
 |                 return torch.matmul(x, w) | 
 |  | 
 |         f = Foo() | 
 |         traced = symbolic_trace(f) | 
 |         x, w = torch.rand(3, 4), torch.rand(4, 4) | 
|         self.assertEqual(traced(x, w), f(x, w)) |
|         self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes)) |
 |  | 
 |     def test_sequential(self): | 
 |         m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) | 
 |         gm = torch.fx.symbolic_trace(m) | 
 |         gm_copy = copy.deepcopy(gm) | 
 |  | 
 |     def test_ctx_mgr(self): | 
 |         @contextlib.contextmanager | 
 |         def do_nothing(): | 
 |             yield | 
 |  | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             @do_nothing() | 
 |             def forward(self, x): | 
 |                 return torch.relu(x) | 
 |  | 
 |         m = M() | 
 |         self.checkGraphModule(m, (torch.rand(3, 4),)) | 
 |  | 
 |     def test_typename_print(self): | 
 |         graph : torch.fx.Graph = torch.fx.Graph() | 
 |         x : torch.fx.Node = graph.create_node('placeholder', 'x') | 
 |         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), | 
 |                                               type_expr=List[float]) | 
 |         output : torch.fx.Node = graph.output(b) | 
 |         self.assertTrue('typing.List[float]' in str(graph)) | 
 |  | 
 |     def test_inf_nan(self): | 
 |         class FooMod(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return x + float('inf'), x + float('-inf'), x + float('nan') | 
 |  | 
 |         fm = FooMod() | 
 |         self.checkGraphModule(fm, (torch.rand(3, 4),)) | 
 |  | 
 |     def test_inf_nan_kwds(self): | 
 |         graph : torch.fx.Graph = torch.fx.Graph() | 
 |         x : torch.fx.Node = graph.create_node('placeholder', 'x') | 
 |         b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf') | 
 |         c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan') | 
 |         graph.output((b, c)) | 
 |  | 
 |         gm = torch.fx.GraphModule(torch.nn.Module(), graph) | 
 |         x = torch.rand(3, 4) | 
 |         self.assertEqual(gm(x), (x + float('inf'), x + float('nan'))) | 
 |  | 
 |     def test_subgraph_creation(self): | 
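|         # split_module partitions the traced graph's nodes round-robin into NPARTITIONS |
|         # submodules; the recombined module must match the original output. |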
 |         class MyModule(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.param = torch.nn.Parameter(torch.rand(3, 4)) | 
 |                 self.linear = torch.nn.Linear(4, 5) | 
 |  | 
 |             def forward(self, x, y): | 
 |                 z = self.linear(x + self.param).clamp(min=0.0, max=1.0) | 
 |                 w = self.linear(y).clamp(min=0.0, max=1.0) | 
 |                 return z + w | 
 |  | 
 |         # symbolically trace model | 
 |         my_module = MyModule() | 
 |         my_module_traced = symbolic_trace(my_module) | 
 |  | 
 |         # random mod partitioning | 
 |         partition_counter = 0 | 
 |         NPARTITIONS = 3 | 
 |  | 
 |         def mod_partition(node: Node): | 
 |             nonlocal partition_counter | 
 |             partition = partition_counter % NPARTITIONS | 
 |             partition_counter = (partition_counter + 1) % NPARTITIONS | 
 |             return partition | 
 |  | 
 |         # split module in module with submodules | 
 |         module_with_submodules = split_module(my_module_traced, my_module, mod_partition) | 
 |  | 
 |         x = torch.rand(3, 4) | 
 |         y = torch.rand(3, 4) | 
 |  | 
 |         orig_out = my_module_traced(x, y) | 
 |         submodules_out = module_with_submodules(x, y) | 
 |  | 
 |         self.assertEqual(orig_out, submodules_out) | 
 |  | 
 |     def test_deepcopy_recursion_depth(self): | 
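|         # Deepcopying a graph deeper than the Python recursion limit must not blow |
|         # the stack, and user edges must map 1:1 onto the copy. |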
 |         depth = sys.getrecursionlimit() + 20 | 
 |  | 
 |         g = torch.fx.Graph() | 
 |         x = g.placeholder('x') | 
 |         for i in range(depth): | 
 |             x = g.call_function(torch.relu, (x,)) | 
 |         g.output(x) | 
 |  | 
 |         copied_graph = copy.deepcopy(g) | 
 |  | 
 |         val_map = {} | 
 |         for orig_node, new_node in zip(g.nodes, copied_graph.nodes): | 
 |             val_map[orig_node] = new_node | 
 |  | 
 |         for orig_node, new_node in zip(g.nodes, copied_graph.nodes): | 
 |             orig_users = set(orig_node.users.keys()) | 
 |             orig_users_equiv = set(val_map[u] for u in orig_users) | 
 |             new_users = set(new_node.users.keys()) | 
 |             self.assertEqual(orig_users_equiv, new_users) | 
 |  | 
 |     @skipIfNoTorchVision | 
 |     def test_replace_uses(self): | 
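|         # Lowering-pass sketch: trace resnet18 with nn.ReLU lowered to functional calls, |
|         # replace each relu call_function with torch.neg, redirect its uses, then erase it. |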
 |         rn18 = resnet18() | 
 |  | 
 |         class LowerReluTracer(torch.fx.Tracer): | 
 |             def is_leaf_module(self, m : torch.nn.Module, qualname : str): | 
 |                 if isinstance(m, torch.nn.ReLU): | 
 |                     return False | 
 |                 return super().is_leaf_module(m, qualname) | 
 |  | 
 |         rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18)) | 
 |  | 
 |         to_erase = [] | 
 |         for node in rn18_traced.graph.nodes: | 
 |             if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]: | 
 |                 kwargs = node.kwargs.copy() | 
 |                 # Neg doesn't have in-place | 
 |                 kwargs.pop('inplace') | 
 |                 with rn18_traced.graph.inserting_before(node): | 
|                     new_node = rn18_traced.graph.call_function( |
|                         the_function=torch.neg, args=node.args, kwargs=kwargs) |
 |                 node.replace_all_uses_with(replace_with=new_node) | 
 |                 to_erase.append(node) | 
 |  | 
 |         for node in to_erase: | 
 |             rn18_traced.graph.erase_node(node) | 
 |  | 
 |     def test_insertion_point(self): | 
 |         graph : torch.fx.Graph = torch.fx.Graph() | 
 |         x : torch.fx.Node = graph.create_node('placeholder', 'x') | 
 |         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) | 
 |         output : torch.fx.Node = graph.output(b) | 
 |  | 
 |         with graph.inserting_before(b): | 
 |             neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) | 
 |             _, *relu_args = b.args | 
 |             b.args = (neg, *relu_args) | 
 |  | 
 |         gm = torch.fx.GraphModule(torch.nn.Module(), graph) | 
 |  | 
 |         input = torch.randn(33, 44) | 
 |         self.assertEqual(gm(input), torch.relu(torch.neg(input))) | 
 |  | 
 |  | 
 |     def test_move_before(self): | 
 |         graph : torch.fx.Graph = torch.fx.Graph() | 
 |         x : torch.fx.Node = graph.create_node('placeholder', 'x') | 
 |         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,)) | 
 |         output : torch.fx.Node = graph.output(b) | 
 |  | 
 |         neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,)) | 
 |         _, *relu_args = b.args | 
 |         b.args = (neg, *relu_args) | 
 |         b.prepend(neg) | 
 |  | 
 |         gm = torch.fx.GraphModule(torch.nn.Module(), graph) | 
 |  | 
 |         input = torch.randn(33, 44) | 
 |         self.assertEqual(gm(input), torch.relu(torch.neg(input))) | 
 |  | 
 |     def test_erase_node_error(self): | 
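|         # Erasing a node that still has users must raise rather than silently |
|         # corrupting the graph. |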
 |         st = SimpleTest() | 
 |         traced = symbolic_trace(st) | 
 |  | 
 |         for node in traced.graph.nodes: | 
 |             # Test deleting with uses both in another Node and at the output | 
 |             if node.target in [operator.add, torch.relu]: | 
 |                 with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'): | 
 |                     traced.graph.erase_node(node) | 
 |  | 
 |     def test_copy_it(self): | 
 |         d = immutable_dict([(3, 4), (5, 6)]) | 
 |         l = immutable_list([(3, 4), (5, 6)]) | 
 |  | 
 |         self.assertEqual(d, deepcopy(d)) | 
 |         self.assertEqual(l, deepcopy(l)) | 
 |  | 
 |     def test_find_uses(self): | 
 |         graph = torch.fx.Graph() | 
 |         x = torch.fx.Proxy(graph.placeholder('x')) | 
 |  | 
 |         y = torch.relu(x) | 
 |         z = x + x | 
 |         u = torch.neg(x) | 
 |         graph.output((y + z + u).node) | 
 |         graph.lint() | 
 |  | 
 |         users_of_x = x.node.users | 
 |         self.assertEqual(len(users_of_x), 3) | 
 |         expected_ops = set(['relu', 'add', 'neg']) | 
 |         for use in users_of_x: | 
 |             assert any(use.name.startswith(prefix) for prefix in expected_ops) | 
 |  | 
 |     def test_inline_graph(self): | 
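|         # Stitch two traced graphs together by mapping the second graph's placeholder |
|         # onto the first graph's output value during graph_copy. |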
 |         class InlineInto(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.relu(x) | 
 |  | 
 |         class ToInline(torch.nn.Module): | 
 |             def forward(self, x): | 
 |                 return torch.neg(x) | 
 |  | 
 |         inline_into = symbolic_trace(InlineInto()) | 
 |         to_inline = symbolic_trace(ToInline()) | 
 |  | 
 |         combined_graph = torch.fx.Graph() | 
 |         output_node = combined_graph.graph_copy(inline_into.graph, {}) | 
 |  | 
 |         input_node = list(to_inline.graph.nodes)[0] | 
 |         assert input_node and input_node.op == 'placeholder' | 
 |  | 
 |         val_map = {input_node : output_node} | 
 |         output = combined_graph.graph_copy(to_inline.graph, val_map) | 
 |         combined_graph.output(output) | 
 |  | 
 |         combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph) | 
 |  | 
 |         input = torch.rand(3, 4) | 
 |         self.assertEqual(combined_module(input), input.relu().neg()) | 
 |  | 
 |     def test_multi_insert_point(self): | 
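|         # Multiple nodes created under one inserting_before() context should all land |
|         # before the anchor node, in creation order. |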
 |         graph = torch.fx.Graph() | 
 |         x = torch.fx.Proxy(graph.placeholder('x')) | 
 |         relu = torch.relu(x) | 
 |  | 
 |         with graph.inserting_before(relu.node): | 
 |             y = torch.neg(x) | 
 |             z = torch.tanh(y) | 
 |  | 
 |         graph.output((relu.node, z.node)) | 
 |         graph.lint() | 
 |  | 
 |         expected_ops = ['x', 'neg', 'tanh', 'relu'] | 
 |         for node, expected in zip(graph.nodes, expected_ops): | 
 |             assert expected in node.name | 
 |  | 
 |     def test_reassign_args_kwargs_uses(self): | 
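|         # Reassigning node.args must keep the users maps of the old and new argument |
|         # nodes up to date. |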
 |         graph = torch.fx.Graph() | 
 |         x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y')) | 
 |         z = x + y | 
 |         zed = z + z + z | 
 |         graph.output(zed.node) | 
 |         graph.lint() | 
 |  | 
 |         # zed = z + z + z -> zed = z + z + x | 
 |         zed.node.args = (zed.node.args[0], x.node) | 
 |         self.assertEqual(x.node.users.keys(), [z.node, zed.node]) | 
 |  | 
 |         # z = x + y -> z = y + y | 
 |         z.node.args = (y.node, y.node) | 
 |         self.assertEqual(x.node.users.keys(), [zed.node]) | 
 |  | 
 |     def test_trace_function(self): | 
 |         def foo(x, y): | 
 |             return torch.relu(x) + y | 
 |  | 
 |         x, y = torch.randn(3, 4), torch.randn(3, 4) | 
 |         self.checkGraphModule(foo, (x, y)) | 
 |  | 
 |     def test_direct_param_use(self): | 
 |         class TransposeTest(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.b = torch.nn.Parameter(torch.rand(4, 3)) | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.b | 
 |  | 
 |         class Foo(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.a = TransposeTest() | 
 |  | 
 |             def forward(self, x): | 
 |                 return self.a.b, self.a.b.t(), self.a.b.view(12) | 
 |  | 
 |         traced = torch.fx.symbolic_trace(Foo()) | 
|         assert all('constant' not in node.target for node in traced.graph.nodes) |
 |  | 
 |     def test_single_default_arg(self): | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             def forward(self, y=1): | 
 |                 return y | 
 |  | 
 |         m = M() | 
 |         self.checkGraphModule(m, ()) | 
 |         self.checkGraphModule(m, (3,)) | 
 |  | 
 |     def test_multiple_default_args(self): | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             def forward(self, y=1, z=2): | 
 |                 return y + z | 
 |  | 
 |         m = M() | 
 |         self.checkGraphModule(m, ()) | 
 |         self.checkGraphModule(m, (3,)) | 
 |         self.checkGraphModule(m, (3, 4)) | 
 |  | 
 |     def test_regular_and_default_args(self): | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             def forward(self, x, y=1): | 
 |                 return x + y | 
 |  | 
 |         m = M() | 
 |         self.checkGraphModule(m, (2,)) | 
 |         self.checkGraphModule(m, (2, 3)) | 
 |  | 
 |     def test_string_literal_return(self): | 
 |         class M(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |  | 
 |             def forward(self): | 
 |                 return "foo" | 
 |  | 
 |         m = M() | 
 |         self.checkGraphModule(m, ()) | 
 |  | 
 |  | 
 | if __name__ == '__main__': | 
 |     run_tests() |