| import builtins |
| import contextlib |
| import copy |
| import functools |
| import math |
| import numbers |
| import operator |
| import os |
| import pickle |
| import sys |
| import torch |
| import traceback |
| import warnings |
| import unittest |
| from math import sqrt |
| from pathlib import Path |
| from torch.multiprocessing import Process |
| from torch.testing import FileCheck |
| from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap |
| from torch.fx.node import Target, Argument |
| from torch.fx.passes import shape_prop |
| from torch.fx.immutable_collections import immutable_dict, immutable_list |
| from torch.fx.experimental.rewriter import RewritingTracer |
| from copy import deepcopy |
| |
| from torch.fx.proxy import TraceError |
| |
| from fx.quantization import Quantizer |
| from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401 |
| from fx.test_dce_pass import TestDCE # noqa: F401 |
| |
| from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union |
| from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| from fx.named_tup import MyNamedTup |
| |
| try: |
| from torchvision.models import resnet18 |
| HAS_TORCHVISION = True |
| except ImportError: |
| HAS_TORCHVISION = False |
| skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") |
| |
class SimpleTest(torch.nn.Module):
    """Minimal tracing fixture computing ``relu(x + 3.0)``."""

    def forward(self, x):
        shifted = x + 3.0
        return torch.relu(shifted)
| |
def a_non_torch_leaf(a, b):
    """Return ``a + b`` using only the Python ``+`` operator (no torch call)."""
    total = a + b
    return total
| |
# wrap() is exercised with a function name (this block) as well as with a
# function object passed directly (a_lifted_leaf2).
def a_lifted_leaf(a, b):
    """Return ``a[0] + a[1] + b``."""
    first, second = a[0], a[1]
    return first + second + b


wrap('a_lifted_leaf')
# Registering the same name a second time must not break anything
wrap('a_lifted_leaf')
| |
def a_lifted_leaf2(a, b):
    """Return ``a[0] + a[1] + b``; registered with wrap() by passing the function itself."""
    first, second = a[0], a[1]
    return first + second + b


wrap(a_lifted_leaf2)

# The builtin ``len`` is registered by name as well.
wrap('len')
| |
@wrap
def wrapped_via_decorator(a):
    """Return ``a + 1``; registered through the decorator form of wrap()."""
    incremented = a + 1
    return incremented
| |
| |
# Extra references to the original callables. The tests compare the
# module-level names against these with assertIs after tracing, to check
# that the bindings still point at the same objects.
real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
_sqrt = sqrt
| |
# wrap() receives the name as a string here, before the function with that
# name is defined just below.
wrap('wrapper_fn')

def wrapper_fn(x):
    # NOTE(review): torch.foo does not look like a real torch API; presumably
    # this body is not meant to execute un-patched -- confirm against the
    # test that uses it (not visible in this part of the file).
    return torch.foo(x)
| |
# Named tuple with two tensor-typed fields, ``x`` and ``y``.
class Pair(NamedTuple):
    x : torch.Tensor
    y : torch.Tensor
| |
| class TestFX(JitTestCase): |
def setUp(self):
    """Load ``libtorchbind_test.so`` unless running on ROCm, Sandcastle, Windows or macOS."""
    skip_load = TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS
    if not skip_load:
        # The library is looked up under build/lib, two directories above this file.
        root_dir = Path(__file__).resolve().parent.parent
        lib_path = root_dir / 'build' / 'lib' / 'libtorchbind_test.so'
        torch.ops.load_library(str(lib_path))
| |
def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
    """Assert that ``m`` and its symbolically traced GraphModule agree.

    The module is called eagerly with ``args``/``kwargs``, traced, linted,
    and the traced module is called with the same inputs; both results
    must compare equal.
    """
    call_kwargs = kwargs or {}
    expected = m(*args, **call_kwargs)
    traced = symbolic_trace(m)
    traced.graph.lint()
    actual = traced(*args, **call_kwargs)
    self.assertEqual(expected, actual)
| |
| def test_graph_module(self): |
| class MySub(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.w = torch.nn.Parameter(torch.rand(4, 3)) |
| |
| def forward(self, x): |
| return self.w + x |
| |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lin = torch.nn.Linear(4, 3) |
| self.sub_mod = MySub() |
| self.w = torch.nn.Parameter(torch.rand(3)) |
| |
| def forward(self, A, B, c): |
| t = torch.sigmoid(A) + self.lin(c) |
| return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3)) |
| |
| m = MyModule() |
| gm = symbolic_trace(m) |
| |
| ms = torch.jit.script(gm) |
| |
| class M2(torch.nn.Module): |
| def forward(self, A): |
| m, idx = torch.max(A, 0) |
| return m + 1, idx + 1 |
| |
| m2 = M2() |
| gm2 = symbolic_trace(m2) |
| |
| class T(torch.nn.Module): |
| |
| def forward(self, A, b=4, *args, c=5, **kwargs): |
| x = A + 1 + args[0] + kwargs['3'] |
| return x |
| |
| t = T() |
| symbolic_trace(t) |
| |
def test_custom_import(self):
    """A hand-built graph can call the module-level ``a_non_torch_leaf`` function."""
    g = torch.fx.Graph()
    lhs = g.placeholder('x')
    rhs = g.placeholder('y')
    summed = g.call_function(a_non_torch_leaf, (lhs, rhs))
    sine = g.call_function(torch.sin, (summed,))
    g.output(sine)
    gm = GraphModule(torch.nn.Module(), g)
    x = torch.rand(1)
    y = torch.rand(1)
    self.assertEqual(torch.sin(x + y), gm(x, y))
| |
def test_args_kwargs(self):
    """Trace a forward that takes only ``*args`` and ``**kwargs``."""
    class T(torch.nn.Module):
        def forward(self, *args, **kwargs):
            x = args[0] + kwargs['foo']
            return x

    positional = (torch.rand(1), torch.rand(1))
    keyword = {'foo': torch.rand(1)}
    self.checkGraphModule(T(), positional, keyword)
| |
def test_args_kwargs_no_self(self):
    """A forward whose ``self`` arrives through ``*args`` is rejected with a RuntimeError."""
    class T(torch.nn.Module):
        def forward(*args, **kwargs):  # noqa: B902
            self = args[0]
            return torch.relu(args[1])

    module = T()
    positional = (torch.rand(1), torch.rand(1))
    keyword = {'foo': torch.rand(1)}
    with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
        self.checkGraphModule(module, positional, keyword)
| |
def test_fx_shifts(self):
    """The ``<<`` and ``>>`` operators trace correctly on an integer tensor."""
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x << 3, x >> 3

    values = torch.LongTensor(10).random_(0, 1024)

    self.checkGraphModule(MyModule(), (values,))
| |
def test_dict(self):
    """A forward that reads from a dict input and returns a dict traces correctly."""
    class MyDictMod(torch.nn.Module):
        def forward(self, d):
            return d['3'].relu(), {'4' : d['3'].neg()}

    payload = {'3': torch.rand(3, 4)}
    module = MyDictMod()

    self.checkGraphModule(module, (payload,))
| |
| def test_disallow_override(self): |
| # Custom delegate to disallow in-place tensor operations |
| class NoMutableCallTracer(Tracer): |
| def create_node(self, kind : str, target : Union[str, Callable], |
| args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, |
| type_expr : Optional[Any] = None) -> Node: |
| name = target if isinstance(target, str) else torch.typename(target) |
| if name[-1] == '_': |
| raise RuntimeError('In-place operations are not supported') |
| return super().create_node(kind, target, args, kwargs, name) |
| |
| # Test method |
| class MyInplaceMod(torch.nn.Module): |
| def forward(self, x): |
| x.add_(3.0) |
| return x |
| |
| m = MyInplaceMod() |
| |
| with self.assertRaisesRegex(RuntimeError, 'In-place operations'): |
| NoMutableCallTracer().trace(m) |
| |
| # Test free function |
| class MyInplaceMod2(torch.nn.Module): |
| def forward(self, x): |
| torch.log_(x) |
| return x |
| m2 = MyInplaceMod2() |
| with self.assertRaisesRegex(RuntimeError, 'In-place operations'): |
| NoMutableCallTracer().trace(m2) |
| |
| # Test symbolic node as an arg |
| class MyInplaceMod3(torch.nn.Module): |
| def forward(self, x): |
| y = torch.ones(3, 4) |
| y.add_(x) |
| return x |
| m3 = MyInplaceMod3() |
| with self.assertRaisesRegex(RuntimeError, 'In-place operations'): |
| NoMutableCallTracer().trace(m3) |
| |
| def test_leaf_module(self): |
| # Custom delegate to make it so that there are no leaf modules, everything |
| # should get traced through |
| class NoLeafModulesTracer(Tracer): |
| def is_leaf_module(self, m, qualname): |
| return False |
| |
| class MyReluMod(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(x) |
| |
| mrm = MyReluMod() |
| sym = NoLeafModulesTracer().trace(mrm) |
| for node in sym.nodes: |
| self.assertNotEqual(node.op, 'call_module') |
| sym.lint() |
| |
def test_wrap(self):
    """A function wrapped by name appears as a call in the traced code."""
    # Calling it directly still gives the plain Python result.
    self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

    def to_trace(y):
        return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

    traced = symbolic_trace(to_trace)
    self.assertIn('a_lifted_leaf', traced.code)
    self.assertEqual(27, traced(2))
    # After tracing, the module-level name is still bound to the same object.
    self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
| |
def test_wrap_fn_directly(self):
    """A function passed to wrap() as an object appears as a call in the traced code."""
    # Calling it directly still gives the plain Python result.
    self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

    def to_trace(y):
        return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

    traced = symbolic_trace(to_trace)
    self.assertIn('a_lifted_leaf2', traced.code)
    self.assertEqual(27, traced(2))
    # After tracing, the module-level name is still bound to the same object.
    self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
| |
def test_wrapped_via_decorator(self):
    """A function registered with the ``@wrap`` decorator appears in the traced code."""
    self.assertEqual(wrapped_via_decorator(0), 1)

    def to_trace(y):
        return wrapped_via_decorator(y)

    traced = symbolic_trace(to_trace)
    self.assertIn('wrapped_via_decorator', traced.code)
    self.assertEqual(traced(0), 1)
    # The module-level binding is unchanged and carries no patch marker attribute.
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
| |
| def test_graph_edit_with_proxy(self): |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| return a + b |
| m = M() |
| g = symbolic_trace(m).graph |
| new_g = torch.fx.Graph() |
| val_map : Dict[Node, Node] = {} |
| output_val = new_g.graph_copy(g, val_map) |
| t = Proxy(output_val) |
| # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. |
| new_g.output((t + t).node) |
| gm = GraphModule(m, new_g) |
| gm.graph.lint() |
| self.assertEqual(gm(3, 4), 14) |
| |
| def test_graph_unique_names(self): |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| return a + b |
| m = M() |
| g = symbolic_trace(m).graph |
| new_g = torch.fx.Graph() |
| val_map : Dict[Node, Node] = {} |
| output_val = new_g.graph_copy(g, val_map) |
| t = Proxy(output_val) |
| # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules. |
| new_g.output((t + t).node) |
| gm = GraphModule(m, new_g) |
| seen_names : Set[str] = set() |
| for node in gm.graph.nodes: |
| assert node.name not in seen_names |
| seen_names.add(node.name) |
| |
| def test_stack_traces(self): |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| return a + b |
| |
| tracer = torch.fx.Tracer() |
| tracer.record_stack_traces = True |
| |
| graph = tracer.trace(M()) |
| for node in graph.nodes: |
| if node.op == 'output': |
| continue |
| self.assertTrue(node.stack_trace is not None) |
| assert 'test_fx.py' in node.stack_trace |
| |
| def test_graph_unique_names_manual(self): |
| graph : torch.fx.Graph = torch.fx.Graph() |
| a : torch.fx.Node = graph.create_node('placeholder', 'x') |
| b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1') |
| c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1') |
| d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) |
| graph.output(d) |
| graph2 = torch.fx.Graph() |
| val_map : Dict[Node, Node] = {} |
| graph2.graph_copy(graph, val_map) |
| seen_names : Set[str] = set() |
| for node in graph2.nodes: |
| assert node.name not in seen_names |
| seen_names.add(node.name) |
| |
@skipIfNoTorchVision
def test_resnet(self):
    """Trace, script and quantize resnet18, checking that outputs stay consistent."""
    model = resnet18()
    model.train()

    traced = symbolic_trace(model)
    scripted = torch.jit.script(traced)

    sample = torch.rand(1, 3, 224, 224)

    eager_out = model(sample)
    traced_out = traced(sample)
    scripted_out = scripted(sample)
    self.assertEqual(eager_out, traced_out)
    self.assertEqual(eager_out, scripted_out)

    quantizer = Quantizer(traced)

    # Call observe() on ten random batches before quantizing.
    for _ in range(10):
        quantizer.observe((torch.rand(1, 3, 224, 224),))

    quantized = quantizer.quantize()
    quantized.graph.lint()
    quantized_scripted = torch.jit.script(quantized)

    quantized_out = quantized(sample)
    quantized_scripted_out = quantized_scripted(sample)

    # The quantized output only has to stay within 2 of the eager output.
    assert (eager_out - quantized_out).abs().max() < 2
    self.assertEqual(quantized_out, quantized_scripted_out)
| |
def test_unpack(self):
    """Tuple-unpacking an input inside forward traces correctly."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            c, d = a
            return c + d + b

    pair = (torch.rand(1), torch.rand(1))
    extra = torch.rand(1)
    self.checkGraphModule(M(), (pair, extra))
| |
def test_native_callable(self):
    """Lower a traced module to a native callable and compare it with the original.

    Raises:
        unittest.SkipTest: on ROCm, Sandcastle, Windows and macOS, because
            the test relies on a non-portable ``load_library`` call.
    """
    if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS:
        raise unittest.SkipTest("non-portable load_library call used in test")
    # This test exercises the case where we use FX to translate from Python
    # code to some native callable object
    #
    # For the purposes of testing, we use ElementwiseInterpreter defined
    # in test_custom_class.cpp.
    #
    # We test that we can
    # 1) Construct a native callable from FX IR
    # 2) Construct a drop-in replacement module that delegates to the
    #    native callable rather than the original code
    # 3) Run both the original code and native callable wrapper with
    #    equivalent results
    # 4) TorchScript compile the native callable wrapper and confirm
    #    equivalent results with the reference
    # 5) TorchScript serialize and deserialize the native callable
    #    and confirm equivalent results with the reference

    # We use this simple Module as a reference computation
    class MySimpleMod(torch.nn.Module):
        def forward(self, x):
            return 3.0 * x + x

    msm = MySimpleMod()

    # This is what a lowering pass might look like: a function that takes
    # a valid nn.Module, symbolically traces it, lowers the Module to some
    # representation, and wraps that representation up into another
    # nn.Module instance that handles dispatch to the compiled/lowered code.
    def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
        # ===== Stage 1: Symbolic trace the module =====
        mod = symbolic_trace(orig_mod)

        # ===== Stage 2: Lower GraphModule representation to the C++
        #       interpreter's instruction format ======
        instructions = []
        constant_idx = 0
        constants = {}
        fn_input_names = []

        target_to_name = {
            operator.add : "add",
            operator.mul : "mul"
        }

        output_node : Optional[Node] = None
        # For each instruction, create a triple
        # (instruction_name : str, inputs : List[str], output : str)
        # to feed into the C++ interpreter
        for n in mod.graph.nodes:
            target, args, out_name = n.target, n.args, n.name
            assert len(n.kwargs) == 0, "kwargs currently not supported"

            if n.op == 'placeholder':
                # Placeholders specify function argument names. Save these
                # for later when we generate the wrapper GraphModule
                fn_input_names.append(target)
            elif n.op == 'call_function':
                # ``target`` is a callable here, so it is formatted rather than
                # concatenated: ``str + callable`` would raise TypeError and
                # hide the real assertion message.
                assert target in target_to_name, f"Unsupported call target {target}"
                arg_names = []
                for arg in args:
                    if not isinstance(arg, Node):
                        # Pull out constants. These constants will later be
                        # fed to the interpreter C++ object via add_constant()
                        arg_name = f'constant_{constant_idx}'
                        constants[arg_name] = torch.Tensor(
                            [arg] if isinstance(arg, numbers.Number) else arg)
                        arg_names.append(arg_name)
                        constant_idx += 1
                    else:
                        arg_names.append(arg.name)
                instructions.append((target_to_name[target], arg_names, out_name))
            elif n.op == 'output':
                if output_node is not None:
                    raise RuntimeError('Multiple output nodes!')
                output_node = n
            else:
                raise RuntimeError('Unsupported opcode ' + n.op)

        # Fail with a clear message instead of an AttributeError on None below.
        if output_node is None:
            raise RuntimeError('No output node!')

        interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
        # Load constants
        for k, v in constants.items():
            interpreter.add_constant(k, v)
        # Specify names for positional input arguments
        interpreter.set_input_names(fn_input_names)
        # Load instructions
        interpreter.set_instructions(instructions)
        # Specify name for single output
        assert isinstance(output_node.args[0], torch.fx.Node)
        interpreter.set_output_name(output_node.args[0].name)

        # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
        class WrapperModule(torch.nn.Module):
            def __init__(self, interpreter):
                super().__init__()
                self.interpreter = interpreter

        wrapper = WrapperModule(interpreter)

        # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
        # 3) Returns the specified return value

        # FIXME: The following code could be greatly simplified by symbolic_trace'ing
        # the wrapper with a Tracer that considers the Wrapper instance a root
        # module, however, I can't get `__call__` exposed on TorchBind classes
        # without it messing up Python `hasattr` for some reason. More digging
        # into CPython's implementation of hasattr is probably in order...

        graph = torch.fx.Graph()
        # Add placeholders for fn inputs
        placeholder_nodes = [graph.create_node('placeholder', name) for name in fn_input_names]

        # Get the interpreter object
        interpreter_node = graph.create_node('get_attr', 'interpreter')

        # Add a node to call the interpreter instance. It gets its own name so
        # it is not confused with the traced graph's output node above.
        call_node = graph.create_node(
            op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

        # Register output
        graph.output(call_node)

        graph.lint()

        # Return final GraphModule!!!
        return GraphModule(wrapper, graph)

    # Lower GraphModule to C++ interpreter
    lowered = lower_to_elementwise_interpreter(msm)

    # Compare correctness with original module
    x = torch.rand(3, 4)
    ref_out = msm(x)
    test_out = lowered(x)
    torch.testing.assert_allclose(test_out, ref_out)

    # Test TorchScript compilation
    scripted_lowered = torch.jit.script(lowered)
    script_out = scripted_lowered(x)
    torch.testing.assert_allclose(script_out, ref_out)

    # Test TorchScript ser/de
    import_copy = self.getExportImportCopy(scripted_lowered)
    imported_out = import_copy(x)
    torch.testing.assert_allclose(imported_out, ref_out)
| |
| def test_reserved_getattr(self): |
| """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" |
| class M(torch.nn.Module): |
| def forward(self, a): |
| return a.foo.bar.baz |
| |
| m = M() |
| m_g = symbolic_trace(m) |
| m_g.graph.lint() |
| for node in m_g.graph.nodes: |
| self.assertTrue(node.name != "getattr") |
| |
| def test_node_tagging(self): |
| class TaggingTracer(Tracer): |
| def create_node(self, kind : str, target : Union[str, Callable], |
| args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None, |
| type_expr : Optional[Any] = None) -> Node: |
| n = super().create_node(kind, target, args, kwargs, name) |
| n.tag = 'foo' |
| return n |
| |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| return a + b |
| |
| m = M() |
| g = TaggingTracer().trace(m) |
| g.lint() |
| for n in g.nodes: |
| self.assertTrue(hasattr(n, 'tag')) |
| self.assertEqual(n.tag, 'foo') |
| |
| def test_tensor_attribute(self): |
| class TensorAttribute(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.tensor = torch.rand(3, 4) |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.tensor) |
| |
| ta = TensorAttribute() |
| traced = symbolic_trace(ta) |
| traced(torch.rand(4, 4)) |
| |
| class WrapperForQualname(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.ta = TensorAttribute() |
| |
| def forward(self, x): |
| return torch.nn.functional.linear(x, self.ta.tensor) |
| |
| wfq = WrapperForQualname() |
| traced2 = symbolic_trace(wfq) |
| traced2.graph.lint() |
| traced2(torch.rand(4, 4)) |
| |
def test_symbolic_trace_sequential(self):
    """An ``nn.Sequential`` of three modules traces to an equivalent GraphModule."""
    class Simple(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    layers = [Simple(), Simple(), Simple()]
    seq = torch.nn.Sequential(*layers)
    traced = symbolic_trace(seq)
    traced.graph.lint()
    sample = torch.rand(3, 4)
    self.assertEqual(traced(sample), seq(sample))
| |
| def test_tensor_constant(self): |
| class ConstTensor(torch.nn.Module): |
| def forward(self, x): |
| return torch.nn.functional.linear(x, torch.zeros(3, 4)) |
| |
| ct = ConstTensor() |
| traced = symbolic_trace(ct) |
| traced.graph.lint() |
| traced(torch.rand(4, 4)) |
| |
def test_pickle_graphmodule(self):
    """A traced GraphModule survives a pickle round trip and gives the same output."""
    class Nested(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.st = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.st(x)

    traced = symbolic_trace(Nested())
    traced.graph.lint()
    restored = pickle.loads(pickle.dumps(traced))
    restored.graph.lint()
    sample = torch.rand(3, 4)
    self.assertEqual(restored(sample), traced(sample))
| |
def test_pickle_custom_import(self):
    """A GraphModule calling ``a_non_torch_leaf`` survives a pickle round trip."""
    g = torch.fx.Graph()
    lhs = g.placeholder('x')
    rhs = g.placeholder('y')
    summed = g.call_function(a_non_torch_leaf, (lhs, rhs))
    sine = g.call_function(torch.sin, (summed,))
    g.output(sine)
    gm = GraphModule(torch.nn.Module(), g)
    restored = pickle.loads(pickle.dumps(gm))
    restored.graph.lint()
    x = torch.rand(1)
    y = torch.rand(1)
    self.assertEqual(restored(x, y), gm(x, y))
| |
| def test_all_input_nodes(self): |
| graph : torch.fx.Graph = torch.fx.Graph() |
| a : torch.fx.Node = graph.placeholder('x') |
| b : torch.fx.Node = graph.call_module('linear_mod', args=(a,)) |
| c : torch.fx.Node = graph.get_attr('y_attr') |
| d : torch.fx.Node = graph.call_function(operator.add, args=(b, c)) |
| e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0)) |
| graph.output(e) |
| graph.lint() |
| |
| self.assertEqual(b.all_input_nodes, [a]) |
| self.assertEqual(c.all_input_nodes, []) |
| self.assertEqual(d.all_input_nodes, [b, c]) |
| self.assertEqual(e.all_input_nodes, [d]) |
| |
def test_deepcopy_graphmodule_with_transform(self):
    """A transformed GraphModule can be deep-copied; the copy has its own class."""
    traced = symbolic_trace(SimpleTest())
    traced.graph.lint()

    def transform(traced):
        # Copy the graph and append a ``neg`` call on its output value.
        rebuilt = torch.fx.Graph()
        node_map : Dict[Node, Node] = {}
        copied_output = rebuilt.graph_copy(traced.graph, node_map)
        negated = rebuilt.create_node(
            op='call_method', target='neg', args=(copied_output,), kwargs={})
        rebuilt.output(negated)
        return GraphModule(traced, rebuilt)

    transformed = transform(traced)
    transformed.graph.lint()
    duplicate = copy.deepcopy(transformed)
    self.assertNotEqual(id(type(transformed)), id(type(duplicate)))
    sample = torch.randn(3, 4)
    self.assertEqual(duplicate(sample), transformed(sample))
| |
| def test_deepcopy_with_submods_params(self): |
| class Bar(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| |
| def forward(self, x): |
| return torch.relu(x) + self.param |
| |
| class Baz(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| self.bar = Bar() |
| |
| def forward(self, x): |
| return self.bar(x) - self.param |
| |
| baz = Baz() |
| traced = symbolic_trace(baz) |
| traced.graph.lint() |
| copied = copy.deepcopy(traced) |
| copied.graph.lint() |
| |
| def test_unpack_list_better_error(self): |
| class SomeArgs(torch.nn.Module): |
| def forward(self, a, b): |
| return torch.rand(3, 4) |
| |
| class UnpacksList(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sa = SomeArgs() |
| |
| def forward(self, x : list): |
| return self.sa(*x) |
| |
| ul = UnpacksList() |
| with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): |
| symbolic_trace(ul) |
| |
| def test_unpack_dict_better_error(self): |
| class SomeKwargs(torch.nn.Module): |
| def forward(self, x=3, y=4): |
| return torch.rand(3, 4) |
| |
| class UnpacksDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sk = SomeKwargs() |
| |
| def forward(self, x : dict): |
| return self.sk(**x) |
| |
| ud = UnpacksDict() |
| with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'): |
| symbolic_trace(ud) |
| |
| def test_pretty_print_targets(self): |
| # Test that Graph pretty-print prints friendly name for targets |
| # in `operator` and `builtins` |
| |
| class SomeMod(torch.nn.Module): |
| def forward(self, x): |
| return torch.add(x.foo + x.bar, 3.0) |
| |
| traced = symbolic_trace(SomeMod()) |
| graph_str = str(traced.graph) |
| self.assertIn('builtins.getattr', graph_str) |
| self.assertIn('operator.add', graph_str) |
| self.assertIn('torch.add', graph_str) |
| |
| def test_pretty_print_node(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.param: torch.nn.Parameter = torch.nn.Parameter( |
| torch.rand(3, 4)) |
| self.linear = torch.nn.Linear(4, 5) |
| |
| def forward(self, x: torch.Tensor, y: int = 2): |
| return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0) |
| |
| traced = symbolic_trace(M()) |
| |
| all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes]) |
| |
| FileCheck().check("x").check("placeholder") \ |
| .check("y").check("placeholder") \ |
| .check("getitem").check("call_function") \ |
| .check("param").check("get_attr") \ |
| .check("add").check("call_function") \ |
| .check("linear").check("call_module") \ |
| .check("clamp").check("call_method") \ |
| .run(all_formatted) |
| |
| def test_script_tensor_constant(self): |
| # TorchScript seems to ignore attributes that start with `__`. |
| # We used to call anonymous Tensor values `__tensor_constant*`, but |
| # they were getting ignored by script. Now they're called |
| # `_tensor_constant*` |
| class IHaveATensorConstant(torch.nn.Module): |
| def forward(self, x): |
| return x + torch.rand(3, 4) |
| |
| traced = torch.fx.symbolic_trace(IHaveATensorConstant()) |
| torch.jit.script(traced) |
| |
def test_torch_fx_len(self):
    """``len`` works on traced values, stays scriptable, and is left unpatched afterwards."""
    class FXLenTest(torch.nn.Module):
        def forward(self, x):
            return len(x)

    traced = symbolic_trace(FXLenTest())
    self.assertEqual(traced(torch.rand(3, 4)), 3)

    # Test scriptability: both the eager module and the traced module.
    eager_scripted = torch.jit.script(FXLenTest())
    self.assertEqual(eager_scripted(torch.rand(3)), 3)

    traced_scripted = torch.jit.script(traced)
    self.assertEqual(traced_scripted(torch.rand(3)), 3)

    # Test non-proxy len
    class FXLenTest2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = [3, 4, 5]

        def forward(self, x):
            return x + len(self.l)

    traced_plain = symbolic_trace(FXLenTest2())
    sample = torch.rand(3, 4)
    self.assertEqual(traced_plain(sample), sample + 3.0)
    # The name ``len`` must still refer to the builtin.
    self.assertIs(len, builtins.len)
| |
def test_sqrt(self):
    """``sqrt`` and ``math.sqrt`` work inside traced code and are unchanged afterwards."""
    class Sqrt1(torch.nn.Module):
        def forward(self, x):
            return sqrt(x.size(0))

    class Sqrt2(torch.nn.Module):
        def forward(self, x):
            return math.sqrt(x.size(0))

    class Sqrt3(torch.nn.Module):
        def forward(self, x):
            return x + math.sqrt(2) + sqrt(2)

    for module_cls in (Sqrt1, Sqrt2, Sqrt3):
        self.checkGraphModule(module_cls(), [torch.zeros(8)])
    # Both spellings must still be the original function object.
    self.assertIs(sqrt, _sqrt)
    self.assertIs(math.sqrt, _sqrt)
| |
def test_torch_custom_ops(self):
    """Calls made through ``torch.ops.aten`` trace to an equivalent GraphModule."""
    class M(torch.nn.Module):
        def forward(self, a):
            b = torch.ops.aten.sigmoid(a)
            c = torch.ops.aten.cat([a, b])
            return torch.ops.aten.cat((c, c))
    module = M()
    sample = torch.randn(3)
    expected = module(sample)
    gm = symbolic_trace(module)
    gm.graph.lint()
    actual = gm(sample)
    self.assertEqual(actual, expected)
| |
def test_pickle_torch_custom_ops(self):
    """A GraphModule using ``torch.ops.aten`` calls survives a pickle round trip."""
    class M(torch.nn.Module):
        def forward(self, a):
            b = torch.ops.aten.sigmoid(a)
            c = torch.ops.aten.cat([a, b])
            return torch.ops.aten.cat((c, c))
    module = M()
    sample = torch.randn(3)
    # Run the eager module once; its result is not compared below.
    module(sample)
    gm = symbolic_trace(module)
    gm.graph.lint()
    restored = pickle.loads(pickle.dumps(gm))
    self.assertEqual(restored(sample), gm(sample))
| |
def test_pretty_print(self):
    """``str()`` of a traced module mentions the original class and the traced call."""
    st = SimpleTest()
    traced = symbolic_trace(st)
    traced.graph.lint()
    printed = str(traced)
    # unittest assertions report the searched text on failure and are not
    # stripped when Python runs with -O.
    self.assertIn('SimpleTest()', printed)
    self.assertIn('torch.relu', printed)
| |
def test_pretty_print_graph(self):
    """``str()`` of a traced graph shows args, kwargs and user counts."""
    class KwargPrintTest(torch.nn.Module):
        def forward(self, x):
            return torch.squeeze(x + 3.0, dim=2)
    st = KwargPrintTest()
    traced = symbolic_trace(st)
    traced.graph.lint()
    stringed = str(traced.graph)
    for s in ['args', 'kwargs', '#users']:
        # unittest assertions report the missing fragment on failure and
        # are not stripped when Python runs with -O.
        self.assertIn(s, stringed)
| |
def test_graph_fns(self):
    """Build a graph with the Graph helper methods and check the result numerically."""
    g = Graph()
    inp = g.placeholder('a')
    lin_out = g.call_module('linear', (inp,))
    bias = g.get_attr('bias')
    biased = g.call_method('add', (lin_out, bias))
    sine = g.call_function(torch.sin, (biased,))
    g.output(sine)
    # Root module that provides the ``linear`` and ``bias`` attributes the graph refers to.
    root = torch.nn.Module()
    root.linear = torch.nn.Linear(3, 4)
    root.bias = torch.rand(4)
    gm = GraphModule(root, g)
    gm.graph.lint()
    sample = torch.rand(3)
    actual = gm(sample)
    expected = torch.sin(root.linear(sample) + root.bias)
    self.assertEqual(actual, expected)
| |
| def test_remove_uses(self): |
| g : torch.fx.Graph = Graph() |
| x : torch.fx.Node = g.placeholder('x') |
| relu : torch.fx.Node = g.call_function(torch.relu, (x,)) |
| neg : torch.fx.Node = g.call_function(torch.neg, (relu,)) |
| g.output(neg) |
| |
| neg.replace_all_uses_with(relu) |
| g.erase_node(neg) |
| |
| self.assertTrue(neg not in relu.users) |
| |
| def test_nonetype_annotation(self): |
| eb = torch.nn.EmbeddingBag(3, 4) |
| symbolic_trace(eb) |
| |
def test_pickle_nonetype_annotation(self):
    """A traced ``nn.EmbeddingBag`` survives a pickle round trip."""
    embedding_bag = torch.nn.EmbeddingBag(10, 3, mode='sum')
    traced = symbolic_trace(embedding_bag)
    restored = pickle.loads(pickle.dumps(traced))
    restored.graph.lint()
    indices = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.LongTensor([0, 4])
    self.assertEqual(restored(indices, offsets), traced(indices, offsets))
| |
| def test_return_tuple(self): |
| class M(torch.nn.Module): |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| return (x, x + x) |
| |
| |
| original = M() |
| traced = symbolic_trace(original) |
| self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1))) |
| |
def test_construct_root_dict(self):
    """A GraphModule can take its root as a dict keyed by dotted qualified names."""
    graph : torch.fx.Graph = torch.fx.Graph()
    a : torch.fx.Node = graph.create_node('placeholder', 'x')
    b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
    c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
    d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
    graph.output(d)

    linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
    add_param : torch.Tensor = torch.rand(3, 4)
    gm : torch.fx.GraphModule = torch.fx.GraphModule(
        {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
    gm.graph.lint()

    # A unittest assertion reports the generated code on failure and is not
    # stripped when Python runs with -O.
    self.assertIn('self.foo.bar.baz', gm.code)

    x : torch.Tensor = torch.rand(3, 3)
    out : torch.Tensor = gm(x)
    ref_out : torch.Tensor = linear_mod(x) + add_param
    self.assertEqual(out, ref_out)
| |
| def test_symbolic_trace_assert(self): |
| |
| class AssertsTensorShape(torch.nn.Module): |
| def forward(self, x): |
| torch._assert(x.shape[1] > 4, "assert_foobar") |
| return x |
| |
| m = AssertsTensorShape() |
| # verify traceability |
| traced = symbolic_trace(m) |
| # verify assertion on traced model works correctly at runtime |
| traced(torch.rand(4, 5)) |
| with self.assertRaisesRegex(AssertionError, "assert_foobar"): |
| traced(torch.rand(4, 3)) |
| # verify the symbolically traced module is scriptable |
| ms = torch.jit.script(m) |
| with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"): |
| ms(torch.rand(4, 3)) |
| |
def test_trace_fn_constant(self):
    """A free function that closes over a tensor traces to an equivalent module."""
    some_constant = torch.rand(3, 4)

    def add_const(x):
        return some_constant + x

    traced = symbolic_trace(add_const)

    sample = torch.rand(3, 4)
    self.assertEqual(traced(sample), add_const(sample))
| |
def test_copy_no_remap(self):
    """Copying nodes without remapping their arguments makes ``lint`` fail."""
    source_graph = symbolic_trace(SimpleTest()).graph
    copied = torch.fx.Graph()
    # node_copy is called without an argument-mapping function.
    for source_node in source_graph.nodes:
        copied.node_copy(source_node)
    with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
        copied.lint()
| |
| def test_wrong_topo(self): |
| graph : torch.fx.Graph = torch.fx.Graph() |
| a : torch.fx.Node = graph.create_node('placeholder', 'x') |
| b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,)) |
| c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam') |
| d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c)) |
| graph.output(d) |
| nodes = list(graph.nodes) |
| nodes[3].append(nodes[2]) |
| with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'): |
| graph.lint() |
| |
def test_example_shape_prop(self):
    """Run shape propagation over a module that exercises every opcode.

    The shape recorded for the value feeding the output node must equal the
    shape of the tensor actually produced by the traced module.
    """
    class TestCase(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr = torch.randn(3, 4)
            self.submod = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.neg(self.submod(x.relu() + self.attr))
    tc = TestCase()
    tc_traced = symbolic_trace(tc)
    ref_out = tc_traced(torch.rand(3, 4))
    shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))

    # Make sure we're testing all opcodes
    opcodes = set()
    # ``torch.Size`` is the real shape type (``torch.Shape`` does not exist).
    output_shape : Optional[torch.Size] = None
    for node in tc_traced.graph.nodes:
        opcodes.add(node.op)
        if node.op == 'output':
            output_shape = node.args[0].shape
    self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
                               'call_module', 'output'})

    # Test shape propagation and make sure results match actual
    self.assertEqual(output_shape, ref_out.shape)
| |
def test_interpreter(self):
    """``Interpreter.run`` must match both the GraphModule and the eager module."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    eager = MyModule()
    graph_module = torch.fx.symbolic_trace(eager)
    interp = Interpreter(graph_module)

    sample = torch.randn(3, 4)
    # Compare a fresh interpreter run against each reference in turn.
    for reference in (graph_module, eager):
        self.assertEqual(interp.run(sample), reference(sample))
| |
| def test_interpreter_run_node_override(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| self.linear = torch.nn.Linear(4, 5) |
| |
| def forward(self, x): |
| return self.linear(x + self.param).clamp(min=0.0, max=1.0) |
| |
| m = MyModule() |
| gm = torch.fx.symbolic_trace(m) |
| |
| class RunNodeInterpreter(Interpreter): |
| def __init__(self, module): |
| super().__init__(module) |
| |
| def run_node(self, n : Node) -> Any: |
| result = super().run_node(n) |
| n.cached_value = result |
| return result |
| |
| input = torch.randn(3, 4) |
| RunNodeInterpreter(gm).run(input) |
| for node in gm.graph.nodes: |
| assert hasattr(node, 'cached_value') |
| |
def test_interpreter_onthefly_swap(self):
    """Swap ``sigmoid`` and ``neg`` while interpreting a traced function.

    The traced function computes ``sigmoid(x).neg()``; the interpreter
    subclass replaces the ``torch.sigmoid`` call with ``torch.neg`` and the
    ``neg`` method call with ``sigmoid``, so the result must equal
    ``neg(x).sigmoid()``.
    """

    def fn(x):
        return torch.sigmoid(x).neg()

    gm = torch.fx.symbolic_trace(fn)

    class NegSigmSwapInterpreter(Interpreter):
        def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
            if target == torch.sigmoid:
                return torch.neg(*args, **kwargs)
            # Any other function is delegated unchanged. The original code
            # passed an undefined name ``n`` here.
            return super().call_function(target, args, kwargs)

        def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
            if target == 'neg':
                call_self, *args_tail = args
                return call_self.sigmoid(*args_tail, **kwargs)
            # Same fix as above: forward the real arguments.
            return super().call_method(target, args, kwargs)

    input = torch.randn(3, 4)
    result = NegSigmSwapInterpreter(gm).run(input)
    self.assertEqual(result, torch.neg(input).sigmoid())
| |
def test_interpreter_partial_eval(self):
    """Seed the interpreter with a value for the ``linear`` node via ``initial_env``."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    def seeded_value():
        # Integers 0..11 arranged 3x4, shifted down by 6.
        return torch.arange(0, 12, 1).reshape(3, 4) - 6.0

    graph_module = torch.fx.symbolic_trace(MyModule())
    interp = Interpreter(graph_module)

    # Locate the first call to the ``linear`` submodule, if any.
    linear_node = next(
        (n for n in graph_module.graph.nodes
         if n.op == 'call_module' and n.target == 'linear'),
        None,
    )
    env = {}
    if linear_node is not None:
        env[linear_node] = seeded_value()
    assert len(env) == 1

    sample = torch.randn(3, 4)
    result = interp.run(sample, initial_env=env)
    # Only the trailing clamp is applied on top of the seeded value.
    self.assertEqual(result, seeded_value().clamp(0.0, 1.0))
| |
def test_interpreter_star_args(self):
    """Interpret a traced function that takes ``*args``."""
    def with_star_args(x, *args):
        return x + args[0]

    graph_module = torch.fx.symbolic_trace(with_star_args)
    interp = Interpreter(graph_module)

    # Three positional inputs are supplied; only the first two are added.
    inputs = (torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
    result = interp.run(*inputs)
    self.assertEqual(result, torch.ones(3, 4) * 2.0)
| |
@skipIfNoTorchVision
def test_interpreter_noop_resnet18(self):
    """An unmodified ``Transformer`` pass over ResNet-18 must keep its output."""
    rn18 = resnet18()
    traced = symbolic_trace(rn18)
    transformed = torch.fx.Transformer(traced).transform()

    batch = torch.randn(5, 3, 224, 224)
    self.assertEqual(transformed(batch), rn18(batch))
| |
def test_transformer_noop(self):
    """A default ``Transformer`` must produce an equivalent GraphModule."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    eager = MyModule()
    original_gm = torch.fx.symbolic_trace(eager)
    rebuilt_gm = Transformer(original_gm).transform()

    sample = torch.randn(3, 4)
    self.assertEqual(rebuilt_gm(sample), original_gm(sample))
| |
def test_transformer_op_swap(self):
    """Swap ``sigmoid`` and ``neg`` while transforming a traced function.

    The traced function computes ``sigmoid(x).neg()``; the transformer
    subclass rewrites it so that the transformed module must compute
    ``neg(x).sigmoid()``.
    """

    def fn(x):
        return torch.sigmoid(x).neg()

    gm = torch.fx.symbolic_trace(fn)

    class NegSigmSwapXformer(Transformer):
        def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
            if target == torch.sigmoid:
                return torch.neg(*args, **kwargs)
            # Any other function is delegated unchanged. The original code
            # passed an undefined name ``n`` here.
            return super().call_function(target, args, kwargs)

        def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
            if target == 'neg':
                call_self, *args_tail = args
                return call_self.sigmoid(*args_tail, **kwargs)
            # Same fix as above: forward the real arguments.
            return super().call_method(target, args, kwargs)

    transformed = NegSigmSwapXformer(gm).transform()
    input = torch.randn(3, 4)
    self.assertEqual(transformed(input), torch.neg(input).sigmoid())
| |
def test_transformer_multi_outputs(self):
    """A default ``Transformer`` must handle a module that returns a tuple."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            x = x + self.param
            out = self.linear(x)
            return x, out

    eager = MyModule()
    original_gm = torch.fx.symbolic_trace(eager)
    rebuilt_gm = Transformer(original_gm).transform()

    sample = torch.randn(3, 4)
    self.assertEqual(rebuilt_gm(sample), original_gm(sample))
| |
def test_fn_type_annotations(self):
    """Type annotations on ``forward`` must survive tracing so scripting works."""
    class Foo(torch.nn.Module):
        def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
            return {'a': p.x + p.y + z + i}

    # Script and run the eager module first ...
    scripted_eager = torch.jit.script(Foo())
    scripted_eager(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

    # ... then do the same with the symbolically traced version.
    traced = symbolic_trace(Foo())
    scripted_traced = torch.jit.script(traced)
    scripted_traced(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
| |
| def test_fn_type_annotation_empty(self): |
| def forward(a : List[torch.Tensor]): |
| return a[0] |
| torch.jit.script(symbolic_trace(forward)) |
| |
def test_wrapped_method(self):
    """Trace a ``forward`` that is wrapped by a ``functools.wraps`` decorator.

    The decorator applies ``torch.relu`` to the wrapped result, so the traced
    graph must contain a ``torch.relu`` node and the traced module must
    compute the same value as the eager module.
    """
    def wrap_with_relu(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return torch.relu(fn(*args, **kwargs))
        return wrapper

    class Foo(torch.nn.Module):
        @wrap_with_relu
        def forward(self, x, w):
            return torch.matmul(x, w)

    f = Foo()
    traced = symbolic_trace(f)
    x, w = torch.rand(3, 4), torch.rand(4, 4)
    self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
    # The inputs above were previously created but never used; actually run
    # both modules so a wrong graph is caught numerically as well.
    self.assertEqual(traced(x, w), f(x, w))
| |
| def test_empty_graph_codegen(self): |
| graph = torch.fx.Graph() |
| gm = torch.fx.GraphModule(torch.nn.Module(), graph) |
| self.assertEqual(gm(), None) |
| |
| def test_sequential(self): |
| m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) |
| gm = torch.fx.symbolic_trace(m) |
| gm_copy = copy.deepcopy(gm) |
| |
def test_ctx_mgr(self):
    """Trace a ``forward`` decorated with a context manager used as a decorator."""
    @contextlib.contextmanager
    def do_nothing():
        yield

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        @do_nothing()
        def forward(self, x):
            return torch.relu(x)

    example_args = (torch.rand(3, 4),)
    self.checkGraphModule(M(), example_args)
| |
| def test_typename_print(self): |
| graph : torch.fx.Graph = torch.fx.Graph() |
| x : torch.fx.Node = graph.create_node('placeholder', 'x') |
| b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,), |
| type_expr=List[float]) |
| output : torch.fx.Node = graph.output(b) |
| self.assertTrue('typing.List[float]' in str(graph)) |
| |
def test_ellipsis(self):
    """Indexing with slices and ``...`` must be traceable."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return x + y[:, 1:10, ...]

    traced = symbolic_trace(M())

    x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
    traced_result = traced(x, y)
    expected = x + y[:, 1:10, ...]
    self.assertEqual(traced_result, expected)
| |
def test_inf_nan(self):
    """Non-finite float constants (inf, -inf, nan) must round-trip through tracing."""
    class FooMod(torch.nn.Module):
        def forward(self, x):
            return x + float('inf'), x + float('-inf'), x + float('nan')

    example_args = (torch.rand(3, 4),)
    self.checkGraphModule(FooMod(), example_args)
| |
def test_inf_nan_kwds(self):
    """Nodes explicitly named ``inf`` and ``nan`` must not break code generation."""
    graph : torch.fx.Graph = torch.fx.Graph()
    placeholder : torch.fx.Node = graph.create_node('placeholder', 'x')
    plus_inf : torch.fx.Node = graph.create_node(
        'call_function', operator.add, (placeholder, float('inf')), {}, name='inf')
    plus_nan : torch.fx.Node = graph.create_node(
        'call_function', operator.add, (placeholder, float('nan')), {}, name='nan')
    graph.output((plus_inf, plus_nan))

    module = torch.fx.GraphModule(torch.nn.Module(), graph)
    sample = torch.rand(3, 4)
    expected = (sample + float('inf'), sample + float('nan'))
    self.assertEqual(module(sample), expected)
| |
| def test_deepcopy_recursion_depth(self): |
| depth = sys.getrecursionlimit() + 20 |
| |
| g = torch.fx.Graph() |
| x = g.placeholder('x') |
| for i in range(depth): |
| x = g.call_function(torch.relu, (x,)) |
| g.output(x) |
| |
| copied_graph = copy.deepcopy(g) |
| |
| val_map = {} |
| for orig_node, new_node in zip(g.nodes, copied_graph.nodes): |
| val_map[orig_node] = new_node |
| |
| for orig_node, new_node in zip(g.nodes, copied_graph.nodes): |
| orig_users = set(orig_node.users.keys()) |
| orig_users_equiv = set(val_map[u] for u in orig_users) |
| new_users = set(new_node.users.keys()) |
| self.assertEqual(orig_users_equiv, new_users) |
| |
@skipIfNoTorchVision
def test_replace_uses(self):
    """Replace every functional ReLU call in a traced ResNet-18 with ``torch.neg``.

    Each matching node gets a ``torch.neg`` node inserted before it, all of
    its uses are redirected to the new node, and the old node is erased.
    """
    rn18 = resnet18()

    class LowerReluTracer(torch.fx.Tracer):
        # Do not treat nn.ReLU as a leaf, so the tracer descends into it.
        def is_leaf_module(self, m : torch.nn.Module, qualname : str):
            if isinstance(m, torch.nn.ReLU):
                return False
            return super().is_leaf_module(m, qualname)

    rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))

    to_erase = []
    for node in rn18_traced.graph.nodes:
        if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
            # Work on a copy so the original node's kwargs are left untouched.
            kwargs = dict(node.kwargs)
            # Neg doesn't have in-place; tolerate calls that never passed it.
            kwargs.pop('inplace', None)
            with rn18_traced.graph.inserting_before(node):
                # Pass the filtered kwargs. The original code passed
                # ``node.kwargs`` here and so kept ``inplace``.
                new_node = rn18_traced.graph.call_function(
                    the_function=torch.neg, args=node.args, kwargs=kwargs)
            node.replace_all_uses_with(replace_with=new_node)
            to_erase.append(node)

    # Erase after the loop so the node list is not mutated while iterating.
    for node in to_erase:
        rn18_traced.graph.erase_node(node)
| |
def test_insertion_point(self):
    """Insert a ``neg`` node in front of an existing ``relu`` via ``inserting_before``."""
    graph : torch.fx.Graph = torch.fx.Graph()
    placeholder : torch.fx.Node = graph.create_node('placeholder', 'x')
    relu : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(placeholder,))
    output : torch.fx.Node = graph.output(relu)

    with graph.inserting_before(relu):
        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(placeholder,))
        # Swap relu's first argument for the new node, keeping any others.
        _first, *remaining = relu.args
        relu.args = (neg, *remaining)

    module = torch.fx.GraphModule(torch.nn.Module(), graph)

    sample = torch.randn(33, 44)
    self.assertEqual(module(sample), torch.relu(torch.neg(sample)))
| |
| |
def test_move_before(self):
    """Append a ``neg`` node at the end, then move it before ``relu`` with ``prepend``."""
    graph : torch.fx.Graph = torch.fx.Graph()
    placeholder : torch.fx.Node = graph.create_node('placeholder', 'x')
    relu : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(placeholder,))
    output : torch.fx.Node = graph.output(relu)

    neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(placeholder,))
    # Swap relu's first argument for the new node, keeping any others.
    _first, *remaining = relu.args
    relu.args = (neg, *remaining)
    # The new node was created after its user; relocate it in front.
    relu.prepend(neg)

    module = torch.fx.GraphModule(torch.nn.Module(), graph)

    sample = torch.randn(33, 44)
    self.assertEqual(module(sample), torch.relu(torch.neg(sample)))
| |
def test_erase_node_error(self):
    """Erasing a node that still has users must raise."""
    traced = symbolic_trace(SimpleTest())

    # Covers a node used by another node (add) and one used by the output (relu).
    still_used = (operator.add, torch.relu)
    for node in traced.graph.nodes:
        if node.target not in still_used:
            continue
        with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
            traced.graph.erase_node(node)
| |
| def test_copy_it(self): |
| d = immutable_dict([(3, 4), (5, 6)]) |
| l = immutable_list([(3, 4), (5, 6)]) |
| |
| self.assertEqual(d, deepcopy(d)) |
| self.assertEqual(l, deepcopy(l)) |
| |
| def test_find_uses(self): |
| graph = torch.fx.Graph() |
| x = torch.fx.Proxy(graph.placeholder('x')) |
| |
| y = torch.relu(x) |
| z = x + x |
| u = torch.neg(x) |
| graph.output((y + z + u).node) |
| graph.lint() |
| |
| users_of_x = x.node.users |
| self.assertEqual(len(users_of_x), 3) |
| expected_ops = set(['relu', 'add', 'neg']) |
| for use in users_of_x: |
| assert any(use.name.startswith(prefix) for prefix in expected_ops) |
| |
def test_inline_graph(self):
    """Chain two traced graphs into one using ``graph_copy`` with a value map."""
    class InlineInto(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    class ToInline(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    first = symbolic_trace(InlineInto())
    second = symbolic_trace(ToInline())

    merged = torch.fx.Graph()
    # Copy the first graph; the return value stands for its output.
    first_output = merged.graph_copy(first.graph, {})

    second_input = list(second.graph.nodes)[0]
    assert second_input and second_input.op == 'placeholder'

    # Feed the first graph's output in place of the second graph's input.
    second_output = merged.graph_copy(second.graph, {second_input : first_output})
    merged.output(second_output)

    merged_module = torch.fx.GraphModule(torch.nn.Module(), merged)

    sample = torch.rand(3, 4)
    self.assertEqual(merged_module(sample), sample.relu().neg())
| |
| def test_multi_insert_point(self): |
| graph = torch.fx.Graph() |
| x = torch.fx.Proxy(graph.placeholder('x')) |
| relu = torch.relu(x) |
| |
| with graph.inserting_before(relu.node): |
| y = torch.neg(x) |
| z = torch.tanh(y) |
| |
| graph.output((relu.node, z.node)) |
| graph.lint() |
| |
| expected_ops = ['x', 'neg', 'tanh', 'relu'] |
| for node, expected in zip(graph.nodes, expected_ops): |
| assert expected in node.name |
| |
def test_reassign_args_kwargs_uses(self):
    """Re-assigning ``Node.args`` must keep the ``users`` bookkeeping in sync."""
    graph = torch.fx.Graph()
    x = Proxy(graph.placeholder('x'))
    y = Proxy(graph.placeholder('y'))
    z = x + y
    zed = z + z + z
    graph.output(zed.node)
    graph.lint()

    # zed = z + z + z -> zed = z + z + x
    first_operand = zed.node.args[0]
    zed.node.args = (first_operand, x.node)
    self.assertEqual(x.node.users.keys(), [z.node, zed.node])

    # z = x + y -> z = y + y
    z.node.args = (y.node, y.node)
    self.assertEqual(x.node.users.keys(), [zed.node])
| |
def test_trace_function(self):
    """A plain two-argument function must trace and match its eager result."""
    def foo(x, y):
        return torch.relu(x) + y

    example_args = (torch.randn(3, 4), torch.randn(3, 4))
    self.checkGraphModule(foo, example_args)
| |
| def test_trace_dict_int_keys(self): |
| class ModWithDictArg(torch.nn.Module): |
| def forward(self, d : Dict[int, torch.Tensor]): |
| return d[42] |
| |
| class CallsModWithDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.m = ModWithDictArg() |
| |
| def forward(self, x): |
| return self.m({42: x}) |
| |
| class MyTracer(torch.fx.Tracer): |
| def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: |
| return isinstance(m, ModWithDictArg) |
| |
| traced_graph = MyTracer().trace(CallsModWithDict()) |
| |
| def test_trace_dict_proxy_keys(self): |
| class ModWithDictArg(torch.nn.Module): |
| def forward(self, d : Dict[torch.Tensor, torch.Tensor]): |
| return d[42] |
| |
| class CallsModWithDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.m = ModWithDictArg() |
| |
| def forward(self, x): |
| return self.m({x: x}) |
| |
| class MyTracer(torch.fx.Tracer): |
| def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool: |
| return isinstance(m, ModWithDictArg) |
| |
| with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'): |
| traced_graph = MyTracer().trace(CallsModWithDict()) |
| |
| def test_direct_param_use(self): |
| class TransposeTest(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.b = torch.nn.Parameter(torch.rand(4, 3)) |
| |
| def forward(self, x): |
| return self.b |
| |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.a = TransposeTest() |
| |
| def forward(self, x): |
| return self.a.b, self.a.b.t(), self.a.b.view(12) |
| |
| traced = torch.fx.symbolic_trace(Foo()) |
| assert(all('constant' not in node.target for node in traced.graph.nodes)) |
| |
def test_single_default_arg(self):
    """A ``forward`` with one defaulted argument works with and without it."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, y=1):
            return y

    module = M()
    for call_args in ((), (3,)):
        self.checkGraphModule(module, call_args)
| |
def test_multiple_default_args(self):
    """A ``forward`` with two defaulted arguments works for 0, 1 and 2 inputs."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, y=1, z=2):
            return y + z

    module = M()
    for call_args in ((), (3,), (3, 4)):
        self.checkGraphModule(module, call_args)
| |
def test_regular_and_default_args(self):
    """A ``forward`` mixing a required and a defaulted argument must trace."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y=1):
            return x + y

    module = M()
    for call_args in ((2,), (2, 3)):
        self.checkGraphModule(module, call_args)
| |
def test_string_literal_return(self):
    """A ``forward`` that returns a string literal must trace."""
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self):
            return "foo"

    no_args = ()
    self.checkGraphModule(M(), no_args)
| |
def test_namedtuple_return_qualname(self):
    """Returning a namedtuple imported from another module must survive tracing."""
    class NamedTupReturn(torch.nn.Module):
        def forward(self, x):
            return MyNamedTup(x, x)

    traced = symbolic_trace(NamedTupReturn())

    sample = torch.rand(3, 4)
    expected = MyNamedTup(sample, sample)
    self.assertEqual(traced(sample), expected)
| |
def test_update_args_kwargs_yells_at_you(self):
    """Accessing ``__update_args_kwargs`` on a node from outside must raise AttributeError."""
    traced = symbolic_trace(SimpleTest())
    first_node = next(iter(traced.graph.nodes))
    with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
        first_node.__update_args_kwargs((), {})
| |
def test_torchbind_class_attribute_in_fx(self):
    """Trace a module that holds a torchbind ``_StackString`` and calls ``top()`` on it."""
    unsupported_env = TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS
    if unsupported_env:
        self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")

    class FooBar1234(torch.nn.Module):
        def __init__(self):
            super(FooBar1234, self).__init__()
            self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

        def forward(self):
            return self.f.top()

    no_args = ()
    self.checkGraphModule(FooBar1234(), no_args)
| |
def test_torchbind_class_attribute_in_fx_tensor_arg(self):
    """Trace a module that passes a tensor to a torchbind ``_ReLUClass`` method."""
    unsupported_env = TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS
    if unsupported_env:
        self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")

    class FooBar2341(torch.nn.Module):
        def __init__(self):
            super(FooBar2341, self).__init__()
            self.f = torch.classes._TorchScriptTesting._ReLUClass()

        def forward(self, x):
            return self.f.run(x)

    eager = FooBar2341()
    traced = symbolic_trace(eager)

    sample = torch.randn(3, 4)
    self.assertEqual(traced(sample), eager(sample))

    # The torchbind call must be recorded as a call_method node.
    recorded_ops = [n.op for n in traced.graph.nodes]
    self.assertTrue('call_method' in recorded_ops)
| |
def test_script_method_trace(self):
    """Trace a module whose submodule has already been scripted."""
    class Scripted(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    class Holder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.s = torch.jit.script(Scripted())

        def forward(self, x):
            return self.s(x)

    holder = Holder()
    traced = symbolic_trace(holder)

    sample = torch.randn(3, 4)
    self.assertEqual(traced(sample), holder(sample))

    # The call into the scripted submodule must be recorded as call_method.
    recorded_ops = [n.op for n in traced.graph.nodes]
    self.assertTrue('call_method' in recorded_ops)
| |
def test_namedtuple_return_trace(self):
    """Returning a file-level namedtuple (``Pair``) must survive tracing."""
    class NamedTupReturn(torch.nn.Module):
        def forward(self, x):
            return Pair(x, x)

    traced = symbolic_trace(NamedTupReturn())

    sample = torch.rand(3, 4)
    expected = Pair(sample, sample)
    self.assertEqual(traced(sample), expected)
| |
| def test_return_type_exists(self): |
| class ReturnTypeModule(torch.nn.Module): |
| def other(self, x: List[str]) -> List[str]: |
| return x |
| |
| def forward(self, x: List[str]) -> List[str]: |
| return self.other(x) |
| |
| traced = symbolic_trace(ReturnTypeModule()) |
| self.assertIn("-> typing_List[str]", traced._code) |
| scripted = torch.jit.script(traced) |
| self.assertIn("-> List[str]", scripted.code) |
| |
def getitem_inner(self):
    """Check three ``__getitem__`` patterns on a registered buffer.

    Shared by ``test_getitem`` and by the subprocess entry point.
    """
    class GetItemBase(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer('pe', torch.randn(8, 8))

    class GetItem1(GetItemBase):
        def forward(self, x):
            return self.pe[:, :x.size(0)]

    class GetItem2(GetItemBase):
        def forward(self, x):
            return self.pe[x.size(0)]

    class GetItem3(GetItemBase):
        def forward(self, x):
            return self.pe[4]  # fx creates `self._tensor_constant0` here

    # Each variant is instantiated and checked in turn with a fresh input.
    for module_cls in (GetItem1, GetItem2, GetItem3):
        self.checkGraphModule(module_cls(), [torch.zeros(4)])
| |
@unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
                     "Will be checked in test_getitem_subproc")
def test_getitem(self):
    """Run the getitem checks in-process; only when FX_PATCH_GETITEM is "1"."""
    self.getitem_inner()
| |
def test_getitem_subproc(self):
    """Run the getitem checks in a child process and require a clean exit."""
    # need to run this test in a subproc to work around:
    #   https://github.com/pytorch/pytorch/issues/50710
    child = Process(target=run_getitem_target)
    child.start()
    child.join()
    self.assertEqual(child.exitcode, 0)
| |
| |
def test_user_friendly_call_provenance_with_function(self):
    """Scripting a traced function must report where ``wrapper_fn`` was called from."""
    def fn(x):
        return wrapper_fn(x)

    traced = torch.fx.symbolic_trace(fn)

    expected_message = ("'wrapper_fn' is "
                        "being compiled since it was called"
                        " from 'fn.forward'")
    with self.assertRaisesRegex(RuntimeError, expected_message):
        scripted = torch.jit.script(traced)
| |
def test_user_friendly_call_provenance_with_module(self):
    """Scripting a traced module must report where ``wrapper_fn`` was called from."""
    class M(torch.nn.Module):
        def forward(self, x):
            return wrapper_fn(x)

    traced = torch.fx.symbolic_trace(M())

    expected_message = ("'wrapper_fn' is "
                        "being compiled since it was called"
                        " from 'M.forward'")
    with self.assertRaisesRegex(RuntimeError, expected_message):
        scripted = torch.jit.script(traced)
| |
| def test_snake_case(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.activations = torch.nn.ModuleDict([ |
| ["snake_case", torch.nn.ReLU()], |
| ["PascalCase", torch.nn.LeakyReLU()], |
| ["ALL_CAPS", torch.nn.PReLU()] |
| ]) |
| |
| def forward(self, x): |
| a = self.activations["snake_case"](x) |
| b = self.activations["PascalCase"](x) |
| c = self.activations["ALL_CAPS"](x) |
| return a, b, c |
| |
| traced = symbolic_trace(M()) |
| |
| check = [ |
| ("activations_snake_case", "activations.snake_case"), |
| ("activations_pascal_case", "activations.PascalCase"), |
| ("activations_all_caps", "activations.ALL_CAPS") |
| ] |
| |
| i = 0 |
| for node in traced.graph.nodes: |
| if node.op == "placeholder" or node.op == "output": |
| continue |
| name = check[i][0] |
| target = check[i][1] |
| self.assertEqual(name, node.name) |
| self.assertEqual(target, node.target) |
| i += 1 |
| self.assertEqual(i, 3) |
| |
def test_no_mutation(self):
    """Item assignment on an ``immutable_list`` must raise with a hint mentioning ``new_args``."""
    from torch.fx.immutable_collections import immutable_list
    frozen = immutable_list([3, 4])
    with self.assertRaisesRegex(NotImplementedError, "new_args"):
        frozen[0] = 4
| |
def test_partial_trace(self):
    """Specialise a data-dependent branch by tracing with ``concrete_args``."""
    class Foo(torch.nn.Module):
        def forward(self, x, y):
            if y:
                return 2 * x
            else:
                return x

    module = Foo()
    # One trace per branch of the ``if y`` condition.
    doubled = symbolic_trace(module, concrete_args={'y': True})
    passthrough = symbolic_trace(module, concrete_args={'y': False})

    self.assertEqual(doubled(3), 6)
    self.assertEqual(passthrough(3), 3)
| |
def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
    """An error raised inside generated code must print the FX-specific banner to stderr."""
    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.W = torch.nn.Parameter(torch.randn(5))

        def forward(self, x):
            return torch.dot(self.W, x)

    traced = torch.fx.symbolic_trace(M())

    # Append a ``.relu()`` call on whatever currently feeds the output node.
    output_node = [n for n in traced.graph.nodes if n.op == "output"][-1]
    with traced.graph.inserting_before(output_node):
        relu_out = traced.graph.call_method(method_name='relu',
                                            args=(output_node.args[0],))
    output_node.args = (relu_out,)

    traced.recompile()

    # Calling with a plain int is expected to raise TypeError.
    with self.capture_stderr() as captured:
        with self.assertRaises(TypeError):
            traced(5)

    banner = ("Call using an FX-traced Module, line 4 of the "
              "traced Module’s generated forward function:")
    self.assertIn(banner, captured[0])
| |
| def test_custom_traceback_not_raised_when_exception_source_is_submodule(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 4) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| traced = torch.fx.symbolic_trace(M()) |
| |
| # Do not change this to `capture_stderr` or another context |
| # manager without ensuring that the output is as expected |
| try: |
| traced(torch.rand(5, 5)) |
| except RuntimeError: |
| captured = traceback.format_exc() |
| |
| self.assertNotIn("Call using an FX-traced Module, line 4 of the" |
| " traced Module’s generated forward function:", |
| captured) |
| |
| def test_ast_rewriter_rewrites_assert(self): |
| class M(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: int, z: int): |
| assert y == z |
| return torch.add(x, x) |
| |
| ast_rewriter = RewritingTracer() |
| graph = ast_rewriter.trace(M()) |
| traced = GraphModule(ast_rewriter.root, graph, "gm") |
| |
| traced.graph.lint() |
| |
| def test_ast_rewriter_rewrites_assert_with_message(self): |
| class M(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: int, z: int): |
| assert y == z, "msg" |
| return torch.add(x, x) |
| |
| ast_rewriter = RewritingTracer() |
| graph = ast_rewriter.trace(M()) |
| traced = GraphModule(ast_rewriter.root, graph, "gm") |
| |
| traced.graph.lint() |
| |
| def test_ast_rewriter_reassigns_submodules(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.bn = torch.nn.BatchNorm2d(100) |
| |
| def forward(self, x: torch.Tensor): |
| return torch.add(x, x) |
| |
| ast_rewriter = RewritingTracer() |
| graph = ast_rewriter.trace(M()) |
| traced = GraphModule(ast_rewriter.root, graph, "gm") |
| |
| traced.graph.lint() |
| |
| def test_submodule_manipulation_API(self): |
| class C(torch.nn.Module): |
| def __init__(self): |
| super(C, self).__init__() |
| self.conv = torch.nn.Conv2d(16, 33, 3, stride=2) |
| self.param = torch.nn.Parameter(torch.rand(2, 3)) |
| |
| def forward(self, x): |
| return self.conv(torch.cat([self.param, x])) |
| |
| class B(torch.nn.Module): |
| def __init__(self): |
| super(B, self).__init__() |
| self.linear = torch.nn.Linear(100, 200) |
| self.register_buffer("buf", torch.randn(2, 3)) |
| self.net_c = C() |
| |
| def forward(self, x): |
| return self.linear(torch.cat([self.buf, self.net_c(x)])) |
| |
| class A(torch.nn.Module): |
| def __init__(self): |
| super(A, self).__init__() |
| self.net_b = B() |
| self.param = torch.nn.Parameter(torch.rand(2, 3)) |
| |
| def forward(self, x): |
| return self.net_b(x) + self.param |
| |
| a = symbolic_trace(A()) |
| |
| a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2)) |
| |
| conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1] |
| with a.graph.inserting_before(conv): |
| dropout = a.graph.call_module(module_name="net_b.net_c.dropout", |
| args=conv.args) |
| |
| conv.replace_all_uses_with(dropout) |
| a.graph.erase_node(conv) |
| a.recompile() |
| |
| def module_exists(gm: GraphModule, path: str) -> bool: |
| return any(path == name for name, _ in gm.named_modules()) |
| |
| def parameter_exists(gm: GraphModule, path: str) -> bool: |
| return (any(path == name for name, _ in gm.named_parameters()) |
| and any(path == name for name in gm.state_dict().keys())) |
| |
| def buffer_exists(gm: GraphModule, path: str) -> bool: |
| return (any(path == name for name, _ in gm.named_buffers()) |
| and any(path == name for name in gm.state_dict().keys())) |
| |
| # Test that we added the "dropout" submodule |
| self.assertTrue(module_exists(a, "net_b.net_c.dropout")) |
| |
| # Test `get_submodule` with an added submodule |
| self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout")) |
| |
| # Test that the "conv" submodule is still there |
| self.assertTrue(module_exists(a, "net_b.net_c.conv")) |
| |
| # Test `get_submodule` with an original module |
| self.assertIsNotNone(a.get_submodule("net_b.net_c.conv")) |
| |
| # Test that the "conv" node is NOT still there |
| conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"] |
| self.assertEqual(conv, []) |
| |
| a.delete_submodule("net_b.net_c.conv") |
| |
| # Test that the "conv" submodule is now gone |
| self.assertFalse(module_exists(a, "net_b.net_c.conv")) |
| |
| # Test `get_submodule` with a deleted submodule |
| with self.assertRaisesRegex(AttributeError, "has no attribute " |
| "`conv`"): |
| self.assertIsNone(a.get_submodule("net_b.net_c.conv")) |
| |
| # Test `get_attr` warnings |
| cat = [n for n in a.graph.nodes if n.target == torch.cat][-1] |
| |
| with a.graph.inserting_before(cat): |
| |
| with warnings.catch_warnings(record=True) as w: |
| param = a.graph.get_attr(qualified_name="net_b.net_c.param") |
| self.assertEqual(len(w), 0) |
| |
| with self.assertWarnsRegex(UserWarning, "Attempted to " |
| "insert a get_attr Node with no " |
| "underlying reference in the " |
| "owning GraphModule"): |
| bad_param = a.graph.get_attr(qualified_name="net_b.param") |
| a.graph.erase_node(bad_param) |
| |
| cat.args = (*cat.args, param) |
| |
| a.recompile() |
| |
| a.graph.lint() |
| |
| # Test `get_parameter` |
| a.get_parameter("net_b.net_c.param") |
| with self.assertRaisesRegex(AttributeError, "is not an " |
| "nn.Parameter"): |
| a.get_parameter("net_b.buf") |
| with self.assertRaisesRegex(AttributeError, "has no attribute " |
| "`param`"): |
| a.get_parameter("net_b.param") |
| |
| # Test `get_buffer` |
| a.get_buffer("net_b.buf") |
| with self.assertRaisesRegex(AttributeError, "is not a " |
| "buffer"): |
| a.get_buffer("net_b.net_c.param") |
| with self.assertRaisesRegex(AttributeError, "has no attribute " |
| "`buf`"): |
| a.get_buffer("net_b.net_c.buf") |
| |
| # Test non-nested attributes |
| a.get_submodule("") |
| a.get_parameter("param") |
| |
| # Insert some unused submodules |
| a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3)) |
| a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3)) |
| a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2)) |
| a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100)) |
| |
| # Garbage collection |
| a.delete_all_unused_submodules() |
| |
| # Test that all the unused submodules are gone |
| self.assertFalse(module_exists(a, "net_b.embedding")) |
| self.assertFalse(module_exists(a, "net_b.net_c.embedding")) |
| self.assertFalse(module_exists(a, "net_b.net_c.rnn")) |
| self.assertFalse(module_exists(a, "batch_norm_2d")) |
| |
| # Test that we didn't delete any unused Parameters or buffers |
| self.assertTrue(parameter_exists(a, "net_b.net_c.param")) |
| self.assertTrue(buffer_exists(a, "net_b.buf")) |
| |
| a.graph.lint() |
| |
def run_getitem_target():
    """Subprocess entry point: run the getitem checks with ``Tensor.__getitem__`` patched.

    The extra patch entry is removed again even if the checks raise.
    """
    from torch.fx.symbolic_trace import _wrapped_methods_to_patch

    getitem_patch = (torch.Tensor, "__getitem__")
    _wrapped_methods_to_patch.append(getitem_patch)
    try:
        test_case = TestFX()
        test_case.getitem_inner()
    finally:
        _wrapped_methods_to_patch.pop()
| |
| |
# Script entry point: when this file is executed directly, hand control to
# the `run_tests` helper imported from torch.testing._internal.common_utils.
if __name__ == '__main__':
    run_tests()