# Owner(s): ["module: fx"]

import builtins
import contextlib
import copy
import functools
import inspect
import math
import numbers
import io
import operator
import os
import pickle
import sys
import torch
import traceback
import typing
import types
import warnings
import unittest
import torch.nn.utils._stateless as _stateless
from math import sqrt
from torch.multiprocessing import Process
from torch.testing import FileCheck
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
from torch.fx.node import Target, Argument, _format_arg
from torch.fx.passes import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.operator_schemas import get_signature_for_torch_op
from copy import deepcopy
from collections import namedtuple

from torch.fx.proxy import TraceError
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMATIBLITY

from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
from fx.test_dce_pass import TestDCE  # noqa: F401
from fx.test_fx_const_fold import TestConstFold  # noqa: F401
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
from fx.test_pass_infra import TestPassManager  # noqa: F401
from fx.test_common_passes import TestCommonPass  # noqa: F401
from fx.test_cse_pass import TestCSEPass  # noqa: F401

if sys.version_info >= (3, 7):
    from fx.test_gradual_type import AnnotationsTest  # noqa: F401
    from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    find_library_location,
    run_tests,
    skipIfSlowGradcheckEnv,
)
from torch.testing._internal.jit_utils import JitTestCase

from fx.named_tup import MyNamedTup

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

class SimpleTest(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 3.0)

def a_non_torch_leaf(a, b):
    return a + b

# Used for test_autowrap_functions. Autowrapped functions need to be global.
def fx_int(x: float) -> int:
    return int(x)

def fx_int_x2(x: float) -> int:
    return int(x) * 2

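# Editor's sketch (illustrative, not part of the original suite): how the two
# helpers above get used. Passing them to Tracer(autowrap_functions=...) makes
# them show up as opaque call_function nodes instead of being traced into.
def _sketch_autowrap_usage():
    class M(torch.nn.Module):
        def forward(self, x):
            return fx_int(x.shape[0] / 2)

    tracer = Tracer(autowrap_functions=(fx_int,))
    gm = GraphModule(tracer.root, tracer.trace(M()))
    assert gm(torch.rand(4)) == 2
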
# Used in test_pytree. Defined at module level because pickling a GraphModule
# that uses Point errors out if Point is local to the function.
Point = namedtuple('Point', ['x', 'y'])

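# Editor's sketch (illustrative, not part of the original suite): pickle looks
# classes and functions up by module-level qualified name, which is why Point
# must live here. A minimal GraphModule pickling round-trip looks like:
def _sketch_graphmodule_pickle_roundtrip():
    gm = symbolic_trace(SimpleTest())
    loaded = pickle.loads(pickle.dumps(gm))
    x = torch.rand(3)
    assert torch.equal(loaded(x), gm(x))
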
# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    return a[0] + a[1] + b

wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')

def a_lifted_leaf2(a, b):
    return a[0] + a[1] + b

wrap(a_lifted_leaf2)

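# Editor's sketch (illustrative, not part of the original suite): the effect of
# the wrap() calls above. Calls to a wrapped leaf survive tracing as a single
# call_function node, visible by name in the generated code.
def _sketch_wrap_effect():
    def f(y):
        return a_lifted_leaf2((y, y), 1)

    gm = symbolic_trace(f)
    assert 'a_lifted_leaf2' in gm.code
    assert gm(torch.tensor(3)) == f(torch.tensor(3))
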
wrap('len')

wrap('getattr')

def wrapped_named_tup(p1, *, p2):
    return p1.x + p2.y

wrap(wrapped_named_tup)

@wrap
def wrapped_via_decorator(a):
    return a + 1

wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    return batchnorm1d(x)

def my_decorator(f):
    @functools.wraps(f)
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper_inside_decorator

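# Editor's note: my_decorator above is a plain pass-through decorator
# (functools.wraps copies the wrapped function's metadata onto the wrapper);
# wrapped_decorated_fn below checks that torch.fx.wrap still works through it.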
@wrap
@my_decorator
def wrapped_decorated_fn(x):
    return x

real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifted_leaf = a_lifted_leaf
real_a_lifted_leaf2 = a_lifted_leaf2
_sqrt = sqrt

wrap('wrapper_fn')

def wrapper_fn(x):
    return torch.foo(x)

class Pair(NamedTuple):
    x : torch.Tensor
    y : torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"

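# Editor's note: _custom_fx_repr_fn above is the hook that FX's arg formatter
# (_format_arg, imported from torch.fx.node) calls when rendering a value, so
# Pair instances print as "Pair(x=..., y=...)" in formatted node output.
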
# for testing pytrees
class Foo(object):  # noqa: B209
    def __init__(self, a, b):
        self.a = a
        self.b = b

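# Editor's sketch (illustrative, not part of the original suite): the pytree
# helpers imported above flatten nested containers into a flat list of leaves
# plus a spec that can rebuild the original structure.
def _sketch_pytree_roundtrip():
    values, spec = pytree.tree_flatten({'a': [1, 2], 'b': 3})
    assert values == [1, 2, 3]
    assert pytree.tree_unflatten(values, spec) == {'a': [1, 2], 'b': 3}
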
class TestFX(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature-flagged;
        # enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

        if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
            lib_file_path = find_library_location('libtorchbind_test.so')
            torch.ops.load_library(str(lib_file_path))

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint()
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)

    def test_graph_module(self):
        class MySub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)

        ms = torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        m2 = M2()
        gm2 = symbolic_trace(m2)

        class T(torch.nn.Module):
            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

        # test for issue described at https://github.com/pytorch/pytorch/issues/63883
        class M3(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        m3 = M3()
        gm3 = symbolic_trace(m3)
        new_instance = gm3.__new__(type(gm3))
        new_instance.__init__(gm3, gm3.graph)

        x = torch.randn(5, 3)
        torch.testing.assert_allclose(new_instance(x), torch.relu(x))

    def test_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))

    def test_args_kwargs(self):
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_args_kwargs_no_self(self):
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_fx_shifts(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_fx_and_or(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x & x, x | x

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_dict(self):
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}
        m = MyDictMod()

        self.checkGraphModule(m, (input_dict,))

    def test_matmul_tracing(self):
        const = torch.randn(3)

        def matmul_f(x):
            return x @ const

        mod = symbolic_trace(matmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), matmul_f(inp))

        def rmatmul_f(x):
            return const @ x

        mod = symbolic_trace(rmatmul_f)
        inp = torch.randn(3)
        self.assertEqual(mod(inp), rmatmul_f(inp))

    def test_disallow_override(self):
        # Custom tracer to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                name = target if isinstance(target, str) else torch.typename(target)
                if name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name)

        # Test method
        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()

        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        # Test free function
        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x
        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x
        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)

    def test_leaf_module(self):
        # Custom tracer so that there are no leaf modules; everything should
        # get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
        sym.lint()

    def test_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf, real_a_lifted_leaf)

    def test_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

        m = symbolic_trace(to_trace)
        self.assertIn('a_lifted_leaf2', m.code)
        self.assertEqual(27, m(2))
        self.assertIs(a_lifted_leaf2, real_a_lifted_leaf2)

    def test_wrapped_via_decorator(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

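    # Editor's note: the assertions above and in the next test pin down two
    # invariants of the wrap() patcher: tracing must restore the original
    # global (assertIs), and the "__fx_already_patched" sentinel must not
    # leak onto the user-visible function.
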
    def test_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(m).transform()
        self.assertIn('wrapped_via_decorator', transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_wrap_with_submodule(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        m = symbolic_trace(M())

        self.assertIn("wrapped_with_submodule", m.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), m(input))

    def test_wrapped_retrace(self):
        def to_trace(y):
            return wrapped_via_decorator(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_via_decorator', m.code)
        self.assertEqual(m(0), 1)

        retraced = symbolic_trace(m)
        self.assertIn('wrapped_via_decorator', retraced.code)
        self.assertEqual(retraced(0), 1)

    def test_wrap_decorated_function(self):
        def to_trace(y):
            return wrapped_decorated_fn(y)

        m = symbolic_trace(to_trace)
        self.assertIn('wrapped_decorated_fn', m.code)
        self.assertEqual(m(1), 1)

    def test_graph_edit_with_proxy(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        gm.graph.lint()
        self.assertEqual(gm(3, 4), 14)

    def test_concrete_arg_none_assert(self):
        class Foo(torch.nn.Module):
            def forward(self, x, val=None):
                return x if val is None else x + val

        f = Foo()
        traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
        with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
            traced(torch.randn(5), torch.randn(5))

        x = torch.randn(5)
        torch.testing.assert_close(traced(x), f(x))

    def test_trace_multiple_funcs(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

            def minus_forward(self, x, y):
                return x - y

            def multiply_forward(self, x, y):
                return x * y

        f = Foo()
        x, y = torch.randn(5), torch.randn(5)
        tracer = Tracer()
        torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

        tracer.traced_func_name = "minus_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.minus_forward(x, y),
        )

        tracer.traced_func_name = "multiply_forward"
        torch.testing.assert_close(
            GraphModule(f, tracer.trace(f))(x, y),
            f.multiply_forward(x, y),
        )

        tracer.traced_func_name = "add_forward"
        with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
            tracer.trace(f)

    def test_graph_unique_names(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b
        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    def test_stack_traces(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        # saving the original list because we will insert new nodes as a part of a test
        orig_graph_nodes = list(graph.nodes)
        for node in orig_graph_nodes:
            if node.op == 'output':
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

            # verify that copying the node does not lose the stack trace
            new_node = graph.node_copy(node)
            self.assertTrue(new_node.stack_trace is not None)
            assert 'test_fx.py' in new_node.stack_trace

    def test_stack_traces_with_transformer(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        tracer = torch.fx.Tracer()
        tracer.record_stack_traces = True

        graph = tracer.trace(M())
        gm = GraphModule(tracer.root, graph)
        new_gm = Transformer(gm).transform()

        # nodes after Transformer should still preserve the original node's stack trace
        for node in new_gm.graph.nodes:
            if node.op in {'placeholder', 'output'}:
                continue
            self.assertTrue(node.stack_trace is not None)
            assert 'test_fx.py' in node.stack_trace

    def test_graph_unique_names_manual(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    def test_unpack(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))

    def test_native_callable(self):
        if IS_FBCODE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            # interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul"
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    assert target in target_to_name, "Unsupported call target " + target
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node('placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            output_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(output_node)

            graph.lint()

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)

        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_close(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_close(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_close(imported_out, ref_out)

    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint()
        for node in m_g.graph.nodes:
            self.assertTrue(node.name != "getattr")

    @unittest.skip("Hotfix for SEV remediation")
    def test_trace_buffer_slice(self):
        bs, d_hid = 10, 23

        class ExampleCode(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)
                self.register_buffer('buffer', torch.randn(bs + 100, d_hid))

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                skip_connection = x
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
                x = self.lin(x)
                x = torch.relu(x)
                x = x + skip_connection
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        ec = ExampleCode()

        traced = torch.fx.symbolic_trace(ec)

        x = torch.randn(bs, d_hid)
        torch.testing.assert_allclose(ec(x), traced(x))

    def test_node_tagging(self):
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                n = super().create_node(kind, target, args, kwargs, name)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint()
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')

    def test_tensor_attribute(self):
        class TensorAttribute(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint()
        traced2(torch.rand(4, 4))

    def test_tensor_attribute_coalesced(self):
        def count_attrs(fx_module):
            targets = set()
            for node in fx_module.graph.nodes:
                if node.op == 'get_attr':
                    targets.add(node.target)
            return len(targets)

        val = torch.tensor(5)

        def f(x):
            return x + val + val
        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 1)

        val2 = torch.tensor(5)

        def f(x):
            val = torch.tensor(5)
            return x + val + val2

        traced = symbolic_trace(f)
        traced.graph.lint()
        self.assertEqual(count_attrs(traced), 2)

    def test_symbolic_trace_sequential(self):
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))

    def test_tensor_constant(self):
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint()
        traced(torch.rand(4, 4))

    def test_pickle_graphmodule(self):
        class Nested(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint()
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))

    def test_pickle_custom_import(self):
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        pickled = pickle.dumps(gm)
        loaded = pickle.loads(pickled)
        loaded.graph.lint()
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(loaded(x, y), gm(x, y))

    def test_all_input_nodes(self):
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.placeholder('x')
        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
        c : torch.fx.Node = graph.get_attr('y_attr')
        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
        graph.output(e)
        graph.lint()

        self.assertEqual(b.all_input_nodes, [a])
        self.assertEqual(c.all_input_nodes, [])
        self.assertEqual(d.all_input_nodes, [b, c])
        self.assertEqual(e.all_input_nodes, [d])

    def test_deepcopy_graphmodule_with_transform(self):
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint()

        def transform(traced):
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            relu_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(relu_out)
            return GraphModule(traced, new_graph)
        transformed = transform(traced)
        transformed.graph.lint()
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))

    def test_deepcopy_with_submods_params(self):
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint()
        copied = copy.deepcopy(traced)
        copied.graph.lint()

    def test_deepcopy_graph_with_tracer_cls(self):
        class TestTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        g = Graph(tracer_cls=TestTracer)
        x = g.placeholder("x")
        g.output(x)

        h = copy.deepcopy(g)
        self.assertIsNotNone(h._tracer_cls)
        self.assertTrue(g._tracer_cls == h._tracer_cls)

    def test_unpack_list_better_error(self):
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)

    def test_unpack_dict_better_error(self):
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)

    def test_pretty_print_targets(self):
        # Test that Graph pretty-print prints friendly name for targets
        # in `operator` and `builtins`

        class SomeMod(torch.nn.Module):
            def forward(self, x):
                return torch.add(x.foo + x.bar, 3.0)

        traced = symbolic_trace(SomeMod())
        graph_str = str(traced.graph)
        self.assertIn('builtins.getattr', graph_str)
        self.assertIn('operator.add', graph_str)
        self.assertIn('torch.add', graph_str)

    def test_pretty_print_node(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param: torch.nn.Parameter = torch.nn.Parameter(
                    torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x: torch.Tensor, y: int = 2):
                return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)

        traced = symbolic_trace(M())

        all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])

        FileCheck().check("x").check("placeholder") \
            .check("y").check("placeholder") \
            .check("getitem").check("call_function") \
            .check("param").check("get_attr") \
            .check("add").check("call_function") \
            .check("linear").check("call_module") \
            .check("clamp").check("call_method") \
            .run(all_formatted)

    def test_script_tensor_constant(self):
        # TorchScript seems to ignore attributes that start with `__`.
        # We used to call anonymous Tensor values `__tensor_constant*`, but
        # they were getting ignored by script. Now they're called
        # `_tensor_constant*`
        class IHaveATensorConstant(torch.nn.Module):
            def forward(self, x):
                return x + torch.rand(3, 4)

        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)

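    # Editor's sketch (illustrative, not part of the original suite): where the
    # `_tensor_constant*` names mentioned above come from. Tracing over a free
    # tensor that is neither a Parameter nor a buffer installs it on the traced
    # module under a `_tensor_constant{i}` attribute.
    def _sketch_tensor_constant_naming(self):
        const = torch.rand(3)

        class M(torch.nn.Module):
            def forward(self, x):
                return x + const

        traced = torch.fx.symbolic_trace(M())
        assert hasattr(traced, '_tensor_constant0')
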
Alexander Soare219ba652021-08-12 17:35:02 -07001102 def test_autowrap_functions(self):
1103 class AutowrapFnTest(torch.nn.Module):
1104 def forward(self, x):
1105 return fx_int(x.shape[0] / 2)
1106
1107 class AutowrapFnTest2(torch.nn.Module):
1108 def forward(self, x):
1109 return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)
1110
1111 # Check function(s) are wrapped
1112 # `int` would normally throw a TypeError as argument can't be `Proxy`
1113 tracer = Tracer(autowrap_functions=(fx_int,))
1114 graph = tracer.trace(AutowrapFnTest())
1115 traced = GraphModule(tracer.root, graph, 'test')
1116 tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
1117 tracer_2.trace(AutowrapFnTest2())
1118
1119 # Test scriptability
1120 traced_scripted = torch.jit.script(traced)
1121 self.assertEqual(traced_scripted(torch.rand(4)), 2)
1122
James Reed6a44efa2022-03-14 16:05:50 -07001123 def test_tuple_no_subscript(self):
1124 def foo(x : Tuple):
1125 return x[0]
1126
1127 traced = torch.fx.symbolic_trace(foo)
1128 x = (torch.randn(5, 3),)
1129 torch.testing.assert_allclose(traced(x), x[0])
1130
1131 bio = io.BytesIO()
1132
1133 torch.save(traced, bio)
1134
1135 bio.seek(0)
1136
1137 loaded = torch.load(bio)
1138
1139 torch.testing.assert_allclose(loaded(x), x[0])
1140
Hui Guoe2e44bb2020-12-18 16:42:04 -08001141 def test_torch_fx_len(self):
1142 class FXLenTest(torch.nn.Module):
1143 def forward(self, x):
James Reed0291f352021-01-15 17:42:30 -08001144 return len(x)
Hui Guoe2e44bb2020-12-18 16:42:04 -08001145
1146 traced = symbolic_trace(FXLenTest())
James Reed0291f352021-01-15 17:42:30 -08001147 self.assertEqual(traced(torch.rand(3, 4)), 3)
1148
1149 # Test scriptability
1150 scripted = torch.jit.script(FXLenTest())
1151 self.assertEqual(scripted(torch.rand(3)), 3)
1152
1153 traced_scripted = torch.jit.script(traced)
1154 self.assertEqual(traced_scripted(torch.rand(3)), 3)
1155
1156 # Test non-proxy len
1157 class FXLenTest2(torch.nn.Module):
1158 def __init__(self):
1159 super().__init__()
1160 self.l = [3, 4, 5]
1161
1162 def forward(self, x):
1163 return x + len(self.l)
1164
1165 traced2 = symbolic_trace(FXLenTest2())
1166 inp = torch.rand(3, 4)
1167 self.assertEqual(traced2(inp), inp + 3.0)
Jason Ansela66851a2021-01-22 15:03:09 -08001168 self.assertIs(len, builtins.len)
1169
Patrick Huc6505cc2021-09-01 10:49:39 -07001170 def test_torch_fx_getattr(self):
1171 class FXGetattrTest(torch.nn.Module):
1172 def forward(self, x):
1173 return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))
1174
1175 traced = symbolic_trace(FXGetattrTest())
1176 self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
1177
Jason Ansela66851a2021-01-22 15:03:09 -08001178 def test_sqrt(self):
1179 class Sqrt1(torch.nn.Module):
1180 def forward(self, x):
1181 return sqrt(x.size(0))
1182
1183 class Sqrt2(torch.nn.Module):
1184 def forward(self, x):
1185 return math.sqrt(x.size(0))
1186
1187 class Sqrt3(torch.nn.Module):
1188 def forward(self, x):
1189 return x + math.sqrt(2) + sqrt(2)
1190
1191 self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
1192 self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
1193 self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
1194 self.assertIs(sqrt, _sqrt)
1195 self.assertIs(math.sqrt, _sqrt)
Hui Guoe2e44bb2020-12-18 16:42:04 -08001196
Lu Fangf15e2722020-09-02 16:06:42 -07001197 def test_torch_custom_ops(self):
1198 class M(torch.nn.Module):
1199 def forward(self, a):
1200 b = torch.ops.aten.sigmoid(a)
1201 c = torch.ops.aten.cat([a, b])
1202 return torch.ops.aten.cat((c, c))
1203 m = M()
1204 input = torch.randn(3)
1205 ref_out = m(input)
1206 gm = symbolic_trace(m)
Ansley Ussery85109ce2021-03-04 14:50:34 -08001207 gm.graph.lint()
Lu Fangf15e2722020-09-02 16:06:42 -07001208 out = gm(input)
1209 self.assertEqual(out, ref_out)
1210
anjali411beda4e82022-03-08 13:41:49 -08001211 def test_torch_op_overloads(self):
1212 class M(torch.nn.Module):
1213 def forward(self, a):
1214 b = torch.ops.aten.add.Tensor(a, a)
1215 return b
1216 m = M()
1217 input = torch.randn(3)
1218 ref_out = m(input)
1219 gm = symbolic_trace(m)
1220 gm.graph.lint()
1221 out = gm(input)
1222 self.assertEqual(out, ref_out)
1223
1224 for node in gm.graph.nodes:
1225 if node.op == 'call_function':
1226 assert isinstance(node.target, torch._ops.OpOverload)
1227 assert node.target.__name__ == 'add.Tensor'
1228
Michael Suoecf3ca02021-02-23 13:33:22 -08001229 def test_pickle_torch_custom_ops(self):
1230 class M(torch.nn.Module):
1231 def forward(self, a):
1232 b = torch.ops.aten.sigmoid(a)
1233 c = torch.ops.aten.cat([a, b])
1234 return torch.ops.aten.cat((c, c))
1235 m = M()
1236 input = torch.randn(3)
1237 ref_out = m(input)
1238 gm = symbolic_trace(m)
Ansley Ussery85109ce2021-03-04 14:50:34 -08001239 gm.graph.lint()
Michael Suoecf3ca02021-02-23 13:33:22 -08001240 pickled = pickle.dumps(gm)
1241 loaded = pickle.loads(pickled)
1242 self.assertEqual(loaded(input), gm(input))
1243
James Reedaf13faf2020-09-04 10:44:20 -07001244 def test_pretty_print(self):
1245 st = SimpleTest()
1246 traced = symbolic_trace(st)
Ansley Ussery85109ce2021-03-04 14:50:34 -08001247 traced.graph.lint()
James Reedaf13faf2020-09-04 10:44:20 -07001248 printed = str(traced)
Shen Li10224432021-08-12 11:39:31 -07001249 assert 'SimpleTest()' in printed
1250 assert 'torch.relu' in printed
James Reedaf13faf2020-09-04 10:44:20 -07001251
1252 def test_pretty_print_graph(self):
1253 class KwargPrintTest(torch.nn.Module):
1254 def forward(self, x):
1255 return torch.squeeze(x + 3.0, dim=2)
1256 st = KwargPrintTest()
1257 traced = symbolic_trace(st)
Ansley Ussery85109ce2021-03-04 14:50:34 -08001258 traced.graph.lint()
James Reedaf13faf2020-09-04 10:44:20 -07001259 stringed = str(traced.graph)
Shen Li10224432021-08-12 11:39:31 -07001260 for s in ['args', 'kwargs', '#users']:
James Reedaf13faf2020-09-04 10:44:20 -07001261 assert s in stringed
1262
James Reed00156d42021-05-14 14:05:44 -07001263 def test_custom_proxy_type(self):
1264 class TensorPair:
1265 def __init__(self, left, right):
1266 self.left, self.right = left, right
1267
1268 def add(self, other):
1269 l = self.left + other.left
1270 r = self.right + other.right
1271 return TensorPair(l, r)
1272
1273 def mul(self, other):
1274 l = self.left * other.left
1275 r = self.right * other.right
1276 return TensorPair(l, r)
1277
Shen Li10224432021-08-12 11:39:31 -07001278 def use_tensor_pair(x : TensorPair, y : TensorPair):
James Reed00156d42021-05-14 14:05:44 -07001279 s = x.add(y)
1280 return s.mul(x)
1281
1282 x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1283 y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1284
1285 ref_out = use_tensor_pair(x, y)
1286
1287 traced = symbolic_trace(use_tensor_pair)
1288
1289 traced_out = traced(x, y)
1290 self.assertEqual(traced_out.left, ref_out.left)
1291 self.assertEqual(traced_out.right, ref_out.right)
1292
1293 def test_custom_proxy_type_literal(self):
1294 class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
1295 def __init__(self, left, right):
1296 self.left, self.right = left, right
1297
1298 def add(self, other):
1299 l = self.left + other.left
1300 r = self.right + other.right
1301 return TensorPair(l, r)
1302
1303 def mul(self, other):
1304 l = self.left * other.left
1305 r = self.right * other.right
1306 return TensorPair(l, r)
1307
Shen Li10224432021-08-12 11:39:31 -07001308 def use_tensor_pair_literal(x : TensorPair):
James Reed00156d42021-05-14 14:05:44 -07001309 s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
1310 return s.mul(x)
1311
1312 x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1313
1314 ref_out = use_tensor_pair_literal(x)
1315
1316 traced = symbolic_trace(use_tensor_pair_literal)
1317
1318 traced_out = traced(x)
1319 self.assertEqual(traced_out.left, ref_out.left)
1320 self.assertEqual(traced_out.right, ref_out.right)
1321
1322 def test_custom_proxy_dynamic_value(self):
1323 class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
1324 def __init__(self, left, right):
1325 self.left, self.right = left, right
1326
1327 def add(self, other):
1328 l = self.left + other.left
1329 r = self.right + other.right
1330 return TensorPair(l, r)
1331
1332 def mul(self, other):
1333 l = self.left * other.left
1334 r = self.right * other.right
1335 return TensorPair(l, r)
1336
Shen Li10224432021-08-12 11:39:31 -07001337 def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
James Reed00156d42021-05-14 14:05:44 -07001338 s = x.add(TensorPair(y, y))
1339 return s.mul(x)
1340
1341 x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
1342 y = torch.randn(5, 3)
1343 ref_out = use_tensor_pair_ctor(x, y)
1344
1345 traced = symbolic_trace(use_tensor_pair_ctor)
1346
1347 traced_out = traced(x, y)
1348 self.assertEqual(traced_out.left, ref_out.left)
1349 self.assertEqual(traced_out.right, ref_out.right)
1350
1351 def test_custom_proxy_input_dependent_control_flow(self):
1352 class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
1353 def __init__(self, inp):
1354 if inp.sum() == 0:
1355 self.is_zero = True
1356 self.tensor = torch.tensor([])
1357 else:
1358 self.is_zero = False
1359 self.tensor = inp
1360
1361 def add(self, other):
1362 if self.is_zero:
1363 return ZeroTensor(other.tensor)
1364 elif other.is_zero:
1365 return self
     else:
         # Neither operand is all-zero: fall back to a real sum
         # so add() never silently returns None.
         return ZeroTensor(self.tensor + other.tensor)
1366
Shen Li10224432021-08-12 11:39:31 -07001367 def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
James Reed00156d42021-05-14 14:05:44 -07001368 return ZeroTensor(x + y)
1369
1370 x, y = torch.randn(5, 3), torch.randn(5, 3)
1371
1372 ref_out = use_zero_tensor(x, y)
1373
1374 traced = symbolic_trace(use_zero_tensor)
1375
1376 traced_out = traced(x, y)
1377
1378 self.assertEqual(traced_out.is_zero, ref_out.is_zero)
1379 self.assertEqual(traced_out.tensor, ref_out.tensor)
1380
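    # Graph exposes one builder method per opcode (placeholder, call_module,
    # get_attr, call_method, call_function, output). The graph assembled
    # below corresponds to generated Python along these lines (a sketch,
    # not the exact codegen):
    #
    #     def forward(self, a):
    #         linear = self.linear(a)
    #         bias = self.bias
    #         add = linear.add(bias)
    #         sin = torch.sin(add)
    #         return sin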
Zachary DeVito2c1b2152020-09-15 15:49:55 -07001381 def test_graph_fns(self):
1382 g = Graph()
Shen Li10224432021-08-12 11:39:31 -07001383 a = g.placeholder('a')
1384 b = g.call_module('linear', (a,))
1385 c = g.get_attr('bias')
1386 d = g.call_method('add', (b, c))
Zachary DeVito2c1b2152020-09-15 15:49:55 -07001387 e = g.call_function(torch.sin, (d,))
1388 g.output(e)
1389 mod = torch.nn.Module()
1390 mod.linear = torch.nn.Linear(3, 4)
1391 mod.bias = torch.rand(4)
1392 gm = GraphModule(mod, g)
Ansley Ussery85109ce2021-03-04 14:50:34 -08001393 gm.graph.lint()
James Reed29664e62020-09-16 18:41:35 -07001394 input = torch.rand(3)
Zachary DeVito2c1b2152020-09-15 15:49:55 -07001395 r = gm(input)
1396 ref = torch.sin(mod.linear(input) + mod.bias)
James Reed29664e62020-09-16 18:41:35 -07001397 self.assertEqual(r, ref)
Lu Fangf15e2722020-09-02 16:06:42 -07001398
James Reeddbfee422020-11-11 10:54:01 -08001399 def test_remove_uses(self):
Shen Li10224432021-08-12 11:39:31 -07001400 g : torch.fx.Graph = Graph()
1401 x : torch.fx.Node = g.placeholder('x')
1402 relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1403 neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
James Reeddbfee422020-11-11 10:54:01 -08001404 g.output(neg)
1405
1406 neg.replace_all_uses_with(relu)
1407 g.erase_node(neg)
1408
1409 self.assertTrue(neg not in relu.users)
1410
Jay Banerjee5332d872022-03-04 10:32:06 -08001411 def test_remove_uses_with_custom_filter(self):
1412 g : torch.fx.Graph = Graph()
1413 x : torch.fx.Node = g.placeholder('x')
1414 relu : torch.fx.Node = g.call_function(torch.relu, (x,))
1415 neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
1416 g.output(neg)
1417
1418 neg.replace_all_uses_with(relu, lambda x: x != neg)
1419
1420 self.assertTrue(neg in relu.users)
1421
1422
James Reed5205cc12021-01-19 23:10:59 -08001423 def test_nonetype_annotation(self):
1424 eb = torch.nn.EmbeddingBag(3, 4)
1425 symbolic_trace(eb)
James Reeddbfee422020-11-11 10:54:01 -08001426
Michael Suoecf3ca02021-02-23 13:33:22 -08001427 def test_pickle_nonetype_annotation(self):
Shen Li10224432021-08-12 11:39:31 -07001428 eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
Michael Suoecf3ca02021-02-23 13:33:22 -08001429 traced = symbolic_trace(eb)
1430 pickled = pickle.dumps(traced)
1431 loaded = pickle.loads(pickled)
Ansley Ussery85109ce2021-03-04 14:50:34 -08001432 loaded.graph.lint()
Michael Suoecf3ca02021-02-23 13:33:22 -08001433 input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
1434 offsets = torch.LongTensor([0, 4])
1435 self.assertEqual(loaded(input, offsets), traced(input, offsets))
1436
Michael Suo958d9a82021-02-23 22:37:14 -08001437 def test_return_tuple(self):
1438 class M(torch.nn.Module):
1439 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
1440 return (x, x + x)
1441
Shen Li10224432021-08-12 11:39:31 -07001442
Michael Suo958d9a82021-02-23 22:37:14 -08001443 original = M()
1444 traced = symbolic_trace(original)
1445 self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
1446
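    # GraphModule also accepts a plain dict as its root, mapping the dotted
    # qualified names referenced by the graph to modules or tensors; the
    # constructor materializes the implied submodule hierarchy (here
    # foo.bar.baz and zip.zap.zam).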
James Reede9c64492020-09-16 12:21:01 -07001447 def test_construct_root_dict(self):
Shen Li10224432021-08-12 11:39:31 -07001448 graph : torch.fx.Graph = torch.fx.Graph()
1449 a : torch.fx.Node = graph.create_node('placeholder', 'x')
1450 b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1451 c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1452 d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
James Reede9c64492020-09-16 12:21:01 -07001453 graph.output(d)
1454
Shen Li10224432021-08-12 11:39:31 -07001455 linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
1456 add_param : torch.Tensor = torch.rand(3, 4)
1457 gm : torch.fx.GraphModule = torch.fx.GraphModule(
1458 {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
Ansley Ussery85109ce2021-03-04 14:50:34 -08001459 gm.graph.lint()
James Reede9c64492020-09-16 12:21:01 -07001460
Shen Li10224432021-08-12 11:39:31 -07001461 assert 'self.foo.bar.baz' in gm.code
James Reede9c64492020-09-16 12:21:01 -07001462
Shen Li10224432021-08-12 11:39:31 -07001463 x : torch.Tensor = torch.rand(3, 3)
1464 out : torch.Tensor = gm(x)
1465 ref_out : torch.Tensor = linear_mod(x) + add_param
James Reede9c64492020-09-16 12:21:01 -07001466 self.assertEqual(out, ref_out)
1467
Vasiliy Kuznetsoveee7dad2020-09-25 13:44:53 -07001468 def test_symbolic_trace_assert(self):
Shen Li10224432021-08-12 11:39:31 -07001469
Vasiliy Kuznetsoveee7dad2020-09-25 13:44:53 -07001470 class AssertsTensorShape(torch.nn.Module):
1471 def forward(self, x):
Vasiliy Kuznetsovdea23372020-11-16 11:39:33 -08001472 torch._assert(x.shape[1] > 4, "assert_foobar")
Vasiliy Kuznetsoveee7dad2020-09-25 13:44:53 -07001473 return x
1474
1475 m = AssertsTensorShape()
1476 # verify traceability
1477 traced = symbolic_trace(m)
1478 # verify assertion on traced model works correctly at runtime
1479 traced(torch.rand(4, 5))
Vasiliy Kuznetsovdea23372020-11-16 11:39:33 -08001480 with self.assertRaisesRegex(AssertionError, "assert_foobar"):
Vasiliy Kuznetsoveee7dad2020-09-25 13:44:53 -07001481 traced(torch.rand(4, 3))
Vasiliy Kuznetsovdea23372020-11-16 11:39:33 -08001482 # verify the symbolically traced module is scriptable
1483 ms = torch.jit.script(m)
1484 with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
1485 ms(torch.rand(4, 3))
1486
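    # __fx_create_arg__ is a hook a custom class can define to control how
    # its instances are lowered into graph arguments when they cross a
    # leaf-module boundary; here it emits a call_function node that rebuilds
    # the object from its recursively-converted fields.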
Bradley Davis8880f3d2021-07-21 11:16:14 -07001487 def test_fx_create_arg(self):
1488 class CustomArgObject:
1489 def __init__(self, x, y):
1490 self.x = x
1491 self.y = y
1492
1493 def __fx_create_arg__(self, tracer: torch.fx.Tracer):
1494 return tracer.create_node(
1495 "call_function",
1496 CustomArgObject,
1497 args=(
1498 tracer.create_arg(self.x),
1499 tracer.create_arg(self.y),
1500 ),
1501 kwargs={},
1502 )
1503
1504 class HasCustomArgObjectWhenLeaf(torch.nn.Module):
1505 def forward(self, o: CustomArgObject):
1506 # Not normally traceable; good reason to make
1507 # this module a leaf.
1508 for x in o.x:
1509 o.y += x
1510 return o.y
1511
1512 class Root(torch.nn.Module):
1513 def __init__(self):
1514 super().__init__()
1515 self.inner = HasCustomArgObjectWhenLeaf()
1516
1517 def forward(self, x, y):
1518 o = CustomArgObject(x, y)
1519 return self.inner(o)
1520
1521 class CreateArgTracer(torch.fx.Tracer):
1522 def is_leaf_module(self, m, module_qualified_name):
1523 return type(m) is HasCustomArgObjectWhenLeaf
1524
1525 m = Root()
1526 graph = CreateArgTracer().trace(m)
1527 gm = torch.fx.GraphModule(m, graph)
1528 assert "CustomArgObject(" in gm.code
1529
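    # Free variables captured by a traced function are lifted into the
    # resulting GraphModule as attributes fetched via get_attr nodes (FX
    # names them along the lines of _tensor_constant0; the exact name is an
    # implementation detail).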
James Reed38c45bd2021-01-19 19:17:58 -08001530 def test_trace_fn_constant(self):
1531 some_constant = torch.rand(3, 4)
1532
1533 def add_const(x):
1534 return some_constant + x
1535
1536 traced = symbolic_trace(add_const)
1537
1538 input = torch.rand(3, 4)
1539 self.assertEqual(traced(input), add_const(input))
Vasiliy Kuznetsoveee7dad2020-09-25 13:44:53 -07001540
James Reed6bdb8712020-09-28 22:50:49 -07001541 def test_copy_no_remap(self):
1542 traced = symbolic_trace(SimpleTest())
1543 g = traced.graph
1544 copied = torch.fx.Graph()
1545 for node in g.nodes:
1546 copied.node_copy(node)
Shen Li10224432021-08-12 11:39:31 -07001547 with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
James Reed6bdb8712020-09-28 22:50:49 -07001548 copied.lint()
1549
1550 def test_wrong_topo(self):
Shen Li10224432021-08-12 11:39:31 -07001551 graph : torch.fx.Graph = torch.fx.Graph()
1552 a : torch.fx.Node = graph.create_node('placeholder', 'x')
1553 b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
1554 c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
1555 d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
James Reed6bdb8712020-09-28 22:50:49 -07001556 graph.output(d)
Zachary DeVito88dcb952020-10-12 18:18:06 -07001557 nodes = list(graph.nodes)
1558 nodes[3].append(nodes[2])
Shen Li10224432021-08-12 11:39:31 -07001559 with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
James Reed6bdb8712020-09-28 22:50:49 -07001560 graph.lint()
1561
Patrick Hu18cb3fc2021-08-27 13:37:38 -07001562 def test_wrong_target_type(self):
1563 graph : torch.fx.Graph = torch.fx.Graph()
1564 with self.assertRaises(ValueError):
1565 n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
1566 args=(), kwargs={})
1567
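    # ShapeProp is an Interpreter subclass that runs example inputs through
    # the graph and records TensorMetadata (shape, dtype, stride,
    # memory_format, ...) on each node. Typical usage, as in the tests
    # below:
    #
    #     gm = symbolic_trace(mod)
    #     shape_prop.ShapeProp(gm).propagate(example_input)
    #     meta = node.meta['tensor_meta']  # per-node metadata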
James Reed78b95b62020-10-01 01:05:40 -07001568 def test_example_shape_prop(self):
1569 class TestCase(torch.nn.Module):
1570 def __init__(self):
1571 super().__init__()
1572 self.attr = torch.randn(3, 4)
1573 self.submod = torch.nn.Linear(4, 4)
1574
1575 def forward(self, x):
1576 return torch.neg(self.submod(x.relu() + self.attr))
1577 tc = TestCase()
1578 tc_traced = symbolic_trace(tc)
1579 ref_out = tc_traced(torch.rand(3, 4))
James Reed53aea602020-10-02 17:05:42 -07001580 shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
James Reed78b95b62020-10-01 01:05:40 -07001581
1582 # Make sure we're testing all opcodes
1583 opcodes = set()
Shen Li10224432021-08-12 11:39:31 -07001584 output_shape : Optional[torch.Size] = None
1585 output_stride : Optional[Tuple[int, ...]] = None
James Reed78b95b62020-10-01 01:05:40 -07001586 for node in tc_traced.graph.nodes:
1587 opcodes.add(node.op)
Shen Li10224432021-08-12 11:39:31 -07001588 if node.op == 'output':
1589 output_shape = node.args[0].meta['tensor_meta'].shape
1590 output_stride = node.args[0].meta['tensor_meta'].stride
1591 self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method',
1592 'call_module', 'output']))
James Reed78b95b62020-10-01 01:05:40 -07001593
Tim Gates3a87b472022-07-14 04:20:26 +00001594 # Test shape propagation and make sure results match the actual output
James Reed53aea602020-10-02 17:05:42 -07001595 self.assertEqual(output_shape, ref_out.shape)
James Reed641d4ff2021-04-02 19:55:49 -07001596 self.assertEqual(output_stride, ref_out.stride())
James Reed78b95b62020-10-01 01:05:40 -07001597
James Reed8bdea142021-04-13 16:36:42 -07001598 def test_shape_prop_layout(self):
1599 class ConvTest(torch.nn.Module):
1600 def __init__(self):
1601 super().__init__()
1602 self.conv_mod = torch.nn.Conv2d(5, 5, 3)
1603
1604 def forward(self, x):
1605 return self.conv_mod(x)
1606
1607 # contiguous layout
1608 test_mod = ConvTest()
1609 traced = symbolic_trace(test_mod)
1610 x = torch.randn(5, 5, 224, 224)
1611 shape_prop.ShapeProp(traced).propagate(x)
1612
Shen Li10224432021-08-12 11:39:31 -07001613 assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1614 for node in traced.graph.nodes))
James Reed8bdea142021-04-13 16:36:42 -07001615
1616 x_channels_last = x.contiguous(memory_format=torch.channels_last)
1617 traced.to(memory_format=torch.channels_last)
1618 shape_prop.ShapeProp(traced).propagate(x_channels_last)
1619 for node in traced.graph.nodes:
1620 # NB: the implementation of conv may not preserve the memory format,
1621 # unfortunately. The best we can do is just check that the placeholder
1622 # node is channels-last
Shen Li10224432021-08-12 11:39:31 -07001623 if node.op in {'placeholder'}:
1624 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
James Reed8bdea142021-04-13 16:36:42 -07001625
James Reedd02919d2021-04-16 18:57:00 -07001626 def test_shape_prop_aggregate(self):
1627 class ReturnTwo(torch.nn.Module):
1628 def forward(self, x):
1629 return (3, torch.sum(x))
1630
1631 class UnderTest(torch.nn.Module):
1632 def __init__(self):
1633 super().__init__()
1634 self.rt = ReturnTwo()
1635
1636 def forward(self, x):
1637 return self.rt(x)
1638
1639 ut = UnderTest()
1640
1641 class RTTracer(torch.fx.Tracer):
1642 def is_leaf_module(self, m, module_qualified_name):
1643 return type(m) is ReturnTwo
1644
1645 graph = RTTracer().trace(ut)
1646 mod = torch.fx.GraphModule(ut, graph)
1647
1648 shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))
1649
1650 for node in mod.graph.nodes:
Shen Li10224432021-08-12 11:39:31 -07001651 if node.op == 'call_module':
1652 assert 'tensor_meta' in node.meta
1653 tensor_meta = node.meta['tensor_meta']
James Reedd02919d2021-04-16 18:57:00 -07001654 assert tensor_meta[0] == 3
1655 assert tensor_meta[1].shape == torch.Size([])
James Reed8bdea142021-04-13 16:36:42 -07001656
1657 def test_shape_prop_layout_3d(self):
1658 class ConvTest3d(torch.nn.Module):
1659 def __init__(self):
1660 super().__init__()
1661 self.conv_mod = torch.nn.Conv3d(5, 5, 3)
1662
1663 def forward(self, x):
1664 return self.conv_mod(x)
1665
1666 test_mod_3d = ConvTest3d()
1667 traced_3d = symbolic_trace(test_mod_3d)
1668 x_3d = torch.randn(5, 5, 224, 224, 15)
1669 shape_prop.ShapeProp(traced_3d).propagate(x_3d)
Shen Li10224432021-08-12 11:39:31 -07001670 assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
1671 for node in traced_3d.graph.nodes))
James Reed8bdea142021-04-13 16:36:42 -07001672
1673 x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
1674 traced_3d.to(memory_format=torch.channels_last_3d)
1675 shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
1676 for node in traced_3d.graph.nodes:
1677 # NB: the implementation of conv may not preserve the memory format,
1678 # unfortunately. The best we can do is just check that the placeholder
1679 # node is channels-last
Shen Li10224432021-08-12 11:39:31 -07001680 if node.op in {'placeholder'}:
1681 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
James Reed8bdea142021-04-13 16:36:42 -07001682
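    # Interpreter executes a GraphModule node by node in Python. Each opcode
    # dispatches to an overridable handler (run_node, call_function,
    # call_method, ...), which the override tests below rely on; a minimal
    # swap looks roughly like:
    #
    #     class MyInterp(Interpreter):
    #         def call_function(self, target, args, kwargs):
    #             if target is torch.sigmoid:
    #                 return torch.neg(*args, **kwargs)
    #             return super().call_function(target, args, kwargs)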
James Reed609f76f2021-02-01 11:34:54 -08001683 def test_interpreter(self):
1684 class MyModule(torch.nn.Module):
1685 def __init__(self):
1686 super().__init__()
1687 self.param = torch.nn.Parameter(torch.rand(3, 4))
1688 self.linear = torch.nn.Linear(4, 5)
1689
1690 def forward(self, x):
1691 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1692
1693 m = MyModule()
1694 gm = torch.fx.symbolic_trace(m)
1695
1696 interpreter = Interpreter(gm)
1697 input = torch.randn(3, 4)
1698 self.assertEqual(interpreter.run(input), gm(input))
1699 self.assertEqual(interpreter.run(input), m(input))
1700
1701 def test_interpreter_run_node_override(self):
1702 class MyModule(torch.nn.Module):
1703 def __init__(self):
1704 super().__init__()
1705 self.param = torch.nn.Parameter(torch.rand(3, 4))
1706 self.linear = torch.nn.Linear(4, 5)
1707
1708 def forward(self, x):
1709 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1710
1711 m = MyModule()
1712 gm = torch.fx.symbolic_trace(m)
1713
1714 class RunNodeInterpreter(Interpreter):
1715 def __init__(self, module):
1716 super().__init__(module)
1717
Shen Li10224432021-08-12 11:39:31 -07001718 def run_node(self, n : Node) -> Any:
James Reed609f76f2021-02-01 11:34:54 -08001719 result = super().run_node(n)
1720 n.cached_value = result
1721 return result
1722
1723 input = torch.randn(3, 4)
1724 RunNodeInterpreter(gm).run(input)
1725 for node in gm.graph.nodes:
Shen Li10224432021-08-12 11:39:31 -07001726 assert hasattr(node, 'cached_value')
James Reed609f76f2021-02-01 11:34:54 -08001727
1728 def test_interpreter_onthefly_swap(self):
Shen Li10224432021-08-12 11:39:31 -07001729
James Reed609f76f2021-02-01 11:34:54 -08001730 def fn(x):
1731 return torch.sigmoid(x).neg()
1732
1733 gm = torch.fx.symbolic_trace(fn)
1734
1735 class NegSigmSwapInterpreter(Interpreter):
Shen Li10224432021-08-12 11:39:31 -07001736 def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
James Reed609f76f2021-02-01 11:34:54 -08001737 if target == torch.sigmoid:
1738 return torch.neg(*args, **kwargs)
1739 return super().call_function(target, args, kwargs)
1740
Shen Li10224432021-08-12 11:39:31 -07001741 def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1742 if target == 'neg':
James Reed609f76f2021-02-01 11:34:54 -08001743 call_self, *args_tail = args
1744 return call_self.sigmoid(*args_tail, **kwargs)
1745 return super().call_method(target, args, kwargs)
1746
1747 input = torch.randn(3, 4)
1748 result = NegSigmSwapInterpreter(gm).run(input)
1749 self.assertEqual(result, torch.neg(input).sigmoid())
1750
1751 def test_interpreter_partial_eval(self):
1752 class MyModule(torch.nn.Module):
1753 def __init__(self):
1754 super().__init__()
1755 self.param = torch.nn.Parameter(torch.rand(3, 4))
1756 self.linear = torch.nn.Linear(4, 5)
1757
1758 def forward(self, x):
1759 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1760
1761 gm = torch.fx.symbolic_trace(MyModule())
1762 interp = Interpreter(gm)
1763 env = {}
1764 for node in gm.graph.nodes:
Shen Li10224432021-08-12 11:39:31 -07001765 if node.op == 'call_module' and node.target == 'linear':
James Reed609f76f2021-02-01 11:34:54 -08001766 env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
1767 break
1768 assert len(env) == 1
1769 x = torch.randn(3, 4)
1770 result = interp.run(x, initial_env=env)
Shen Li10224432021-08-12 11:39:31 -07001771 self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
James Reed609f76f2021-02-01 11:34:54 -08001772
1773 def test_interpreter_star_args(self):
1774 def with_star_args(x, *args):
1775 return x + args[0]
1776
1777 gm = torch.fx.symbolic_trace(with_star_args)
1778 interp = Interpreter(gm)
1779 result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
1780 self.assertEqual(result, torch.ones(3, 4) * 2.0)
1781
James Reedd4e84b02021-02-09 20:18:47 -08001782 @skipIfNoTorchVision
1783 def test_interpreter_noop_resnet18(self):
Suraj Subramanian78022aa2021-04-22 08:52:45 -07001784 rn18 = torchvision_models.resnet18()
James Reedd4e84b02021-02-09 20:18:47 -08001785 transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
1786 inp = torch.randn(5, 3, 224, 224)
1787 self.assertEqual(transformed(inp), rn18(inp))
1788
James Reeda28c7db2021-03-25 20:33:23 -07001789 @skipIfNoTorchVision
1790 def test_interpreter_gc_values(self):
Suraj Subramanian78022aa2021-04-22 08:52:45 -07001791 rn18 = torchvision_models.resnet18()
James Reeda28c7db2021-03-25 20:33:23 -07001792 interp = Interpreter(symbolic_trace(rn18))
1793 inp = torch.rand(5, 3, 224, 224)
1794 out = interp.run(inp)
1795 env_key_names = set(n.name for n in interp.env.keys())
Shen Li10224432021-08-12 11:39:31 -07001796 self.assertEqual(env_key_names, set(['output']))
James Reeda28c7db2021-03-25 20:33:23 -07001797
James Reed3f6643e2022-02-03 17:38:53 -08001798 def test_interpreter_default_args(self):
1799 class Model(torch.nn.Module):
1800 def forward(self, x, y=3.14159):
1801 return x + y
1802
1803 model = Model()
1804 gm = torch.fx.symbolic_trace(model)
1805
1806 interp = Interpreter(gm)
1807 x = torch.randn(5, 3)
1808 out = interp.run(x)
1809 torch.testing.assert_allclose(out, x + 3.14159)
1810
1811 def test_interpreter_not_enough_args(self):
1812 class Model(torch.nn.Module):
1813 def forward(self, x, y):
1814 return x + y
1815
1816 model = Model()
1817 gm = torch.fx.symbolic_trace(model)
1818
1819 interp = Interpreter(gm)
1820 x = torch.randn(5, 3)
1821 with self.assertRaisesRegex(RuntimeError,
1822 'Expected positional argument for parameter y, but one was not passed in'):
1823 out = interp.run(x)
1824
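    # Transformer is an Interpreter that, instead of computing values, emits
    # a transformed copy of the module: overridden handlers return Proxys
    # that are recorded into a new Graph, and .transform() packages the
    # result as a fresh GraphModule.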
James Reed609f76f2021-02-01 11:34:54 -08001825 def test_transformer_noop(self):
1826 class MyModule(torch.nn.Module):
1827 def __init__(self):
1828 super().__init__()
1829 self.param = torch.nn.Parameter(torch.rand(3, 4))
1830 self.linear = torch.nn.Linear(4, 5)
1831
1832 def forward(self, x):
1833 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
1834
1835 m = MyModule()
1836 gm = torch.fx.symbolic_trace(m)
1837
1838 new_gm = Transformer(gm).transform()
1839
1840 input = torch.randn(3, 4)
1841 self.assertEqual(new_gm(input), gm(input))
1842
1843 def test_transformer_op_swap(self):
Shen Li10224432021-08-12 11:39:31 -07001844
James Reed609f76f2021-02-01 11:34:54 -08001845 def fn(x):
1846 return torch.sigmoid(x).neg()
1847
1848 gm = torch.fx.symbolic_trace(fn)
1849
1850 class NegSigmSwapXformer(Transformer):
Shen Li10224432021-08-12 11:39:31 -07001851 def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
James Reed609f76f2021-02-01 11:34:54 -08001852 if target == torch.sigmoid:
1853 return torch.neg(*args, **kwargs)
1854 return super().call_function(target, args, kwargs)
1855
Shen Li10224432021-08-12 11:39:31 -07001856 def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
1857 if target == 'neg':
James Reed609f76f2021-02-01 11:34:54 -08001858 call_self, *args_tail = args
1859 return call_self.sigmoid(*args_tail, **kwargs)
1860 return super().call_method(target, args, kwargs)
1861
1862 transformed = NegSigmSwapXformer(gm).transform()
1863 input = torch.randn(3, 4)
1864 self.assertEqual(transformed(input), torch.neg(input).sigmoid())
1865
Shiyan Deng238b0bb2021-02-23 19:22:30 -08001866 def test_transformer_multi_outputs(self):
1867 class MyModule(torch.nn.Module):
1868 def __init__(self):
1869 super().__init__()
1870 self.param = torch.nn.Parameter(torch.rand(3, 4))
1871 self.linear = torch.nn.Linear(4, 5)
1872
1873 def forward(self, x):
1874 x = x + self.param
1875 out = self.linear(x)
1876 return x, out
1877
1878 m = MyModule()
1879 gm = torch.fx.symbolic_trace(m)
1880
1881 new_gm = Transformer(gm).transform()
1882
1883 input = torch.randn(3, 4)
1884 self.assertEqual(new_gm(input), gm(input))
1885
James Reed00b8ebe2020-10-07 21:32:51 -07001886 def test_fn_type_annotations(self):
1887 class Foo(torch.nn.Module):
Shen Li10224432021-08-12 11:39:31 -07001888 def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
1889 return {'a': p.x + p.y + z + i}
James Reed00b8ebe2020-10-07 21:32:51 -07001890
1891 foo_scripted = torch.jit.script(Foo())
1892 foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
1893
1894 fxed = symbolic_trace(Foo())
1895 fxed_scripted = torch.jit.script(fxed)
1896 fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
1897
Zachary DeVito70d34712020-11-09 10:33:08 -08001898 def test_fn_type_annotation_empty(self):
Shen Li10224432021-08-12 11:39:31 -07001899 def forward(a : List[torch.Tensor]):
Zachary DeVito70d34712020-11-09 10:33:08 -08001900 return a[0]
1901 torch.jit.script(symbolic_trace(forward))
1902
James Reed9ccf85b2020-10-22 11:52:31 -07001903 def test_wrapped_method(self):
1904 def wrap_with_relu(fn):
1905 @functools.wraps(fn)
1906 def wrapper(*args, **kwargs):
1907 return torch.relu(fn(*args, **kwargs))
1908 return wrapper
1909
1910 class Foo(torch.nn.Module):
1911 @wrap_with_relu
1912 def forward(self, x, w):
1913 return torch.matmul(x, w)
1914
1915 f = Foo()
1916 traced = symbolic_trace(f)
1917 x, w = torch.rand(3, 4), torch.rand(4, 4)
1918 self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
1919
James Reedeb8003d2021-01-06 15:43:37 -08001920 def test_empty_graph_codegen(self):
1921 graph = torch.fx.Graph()
1922 gm = torch.fx.GraphModule(torch.nn.Module(), graph)
1923 self.assertEqual(gm(), None)
1924
James Reed069232a2020-10-28 10:20:04 -07001925 def test_sequential(self):
1926 m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
1927 gm = torch.fx.symbolic_trace(m)
1928 gm_copy = copy.deepcopy(gm)
1929
James Reed9ccf85b2020-10-22 11:52:31 -07001930 def test_ctx_mgr(self):
1931 @contextlib.contextmanager
1932 def do_nothing():
1933 yield
1934
1935 class M(torch.nn.Module):
1936 def __init__(self):
1937 super().__init__()
1938
1939 @do_nothing()
1940 def forward(self, x):
1941 return torch.relu(x)
1942
1943 m = M()
1944 self.checkGraphModule(m, (torch.rand(3, 4),))
1945
James Reed00b8ebe2020-10-07 21:32:51 -07001946 def test_typename_print(self):
Shen Li10224432021-08-12 11:39:31 -07001947 graph : torch.fx.Graph = torch.fx.Graph()
1948 x : torch.fx.Node = graph.create_node('placeholder', 'x')
1949 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
1950 type_expr=List[float])
1951 output : torch.fx.Node = graph.output(b)
Ansley Ussery5268b5a2021-05-25 12:11:15 -07001952
Shen Li10224432021-08-12 11:39:31 -07001953 self.assertTrue('typing.List[float]' in str(graph))
James Reed00b8ebe2020-10-07 21:32:51 -07001954
Yinghai Lu6b0aa292021-10-04 19:55:42 -07001955 def test_layout(self):
1956 class M(torch.nn.Module):
1957 def __init__(self):
1958 super().__init__()
1959
1960 def forward(self, x):
1961 return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)
1962
1963 traced = symbolic_trace(M())
1964 x = torch.rand(5, 9, 3, 4)
1965 self.assertEqual(traced(x), torch.zeros_like(x))
1966
James Reeda3353d12021-02-01 18:49:56 -08001967 def test_ellipsis(self):
1968 class M(torch.nn.Module):
1969 def __init__(self):
1970 super().__init__()
1971
1972 def forward(self, x, y):
1973 return x + y[:, 1:10, ...]
1974
1975 traced = symbolic_trace(M())
1976 x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
1977 self.assertEqual(traced(x, y), x + y[:, 1:10, ...])
1978
James Reed67c1dc62020-10-27 17:49:29 -07001979 def test_inf_nan(self):
1980 class FooMod(torch.nn.Module):
1981 def forward(self, x):
Shen Li10224432021-08-12 11:39:31 -07001982 return x + float('inf'), x + float('-inf'), x + float('nan')
James Reed67c1dc62020-10-27 17:49:29 -07001983
1984 fm = FooMod()
1985 self.checkGraphModule(fm, (torch.rand(3, 4),))
1986
1987 def test_inf_nan_kwds(self):
Shen Li10224432021-08-12 11:39:31 -07001988 graph : torch.fx.Graph = torch.fx.Graph()
1989 x : torch.fx.Node = graph.create_node('placeholder', 'x')
1990 b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
1991 c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
James Reed67c1dc62020-10-27 17:49:29 -07001992 graph.output((b, c))
1993
1994 gm = torch.fx.GraphModule(torch.nn.Module(), graph)
1995 x = torch.rand(3, 4)
Shen Li10224432021-08-12 11:39:31 -07001996 self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))
James Reed67c1dc62020-10-27 17:49:29 -07001997
James Reed27009322020-10-22 11:52:31 -07001998 def test_deepcopy_recursion_depth(self):
1999 depth = sys.getrecursionlimit() + 20
2000
2001 g = torch.fx.Graph()
Shen Li10224432021-08-12 11:39:31 -07002002 x = g.placeholder('x')
James Reed27009322020-10-22 11:52:31 -07002003 for i in range(depth):
2004 x = g.call_function(torch.relu, (x,))
2005 g.output(x)
2006
2007 copied_graph = copy.deepcopy(g)
2008
2009 val_map = {}
2010 for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2011 val_map[orig_node] = new_node
2012
2013 for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
2014 orig_users = set(orig_node.users.keys())
2015 orig_users_equiv = set(val_map[u] for u in orig_users)
2016 new_users = set(new_node.users.keys())
2017 self.assertEqual(orig_users_equiv, new_users)
2018
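    # The canonical FX rewrite recipe, exercised below: iterate over
    # graph.nodes, open an insertion point with graph.inserting_before(node),
    # build the replacement with the graph's builder methods, redirect
    # consumers via node.replace_all_uses_with(new_node), and only then
    # erase_node the now-unused original -- erasing a node that still has
    # users raises, as test_erase_node_error checks.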
James Reedb04ae952020-10-05 17:05:07 -07002019 @skipIfNoTorchVision
2020 def test_replace_uses(self):
Suraj Subramanian78022aa2021-04-22 08:52:45 -07002021 rn18 = torchvision_models.resnet18()
James Reedb04ae952020-10-05 17:05:07 -07002022
2023 class LowerReluTracer(torch.fx.Tracer):
Shen Li10224432021-08-12 11:39:31 -07002024 def is_leaf_module(self, m : torch.nn.Module, qualname : str):
James Reedb04ae952020-10-05 17:05:07 -07002025 if isinstance(m, torch.nn.ReLU):
2026 return False
2027 return super().is_leaf_module(m, qualname)
2028
2029 rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))
2030
2031 to_erase = []
2032 for node in rn18_traced.graph.nodes:
Shen Li10224432021-08-12 11:39:31 -07002033 if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
Zachary DeVitofc1d6bf2020-10-14 15:49:30 -07002034 kwargs = node.kwargs.copy()
James Reedb04ae952020-10-05 17:05:07 -07002035 # torch.neg doesn't take an 'inplace' kwarg, so drop it before swapping
Shen Li10224432021-08-12 11:39:31 -07002036 kwargs.pop('inplace')
Zachary DeVito88dcb952020-10-12 18:18:06 -07002037 with rn18_traced.graph.inserting_before(node):
James Reedb04ae952020-10-05 17:05:07 -07002038 new_node = rn18_traced.graph.call_function(
Shen Li10224432021-08-12 11:39:31 -07002039 the_function=torch.neg, args=node.args, kwargs=kwargs)
James Reedb04ae952020-10-05 17:05:07 -07002040 node.replace_all_uses_with(replace_with=new_node)
2041 to_erase.append(node)
2042
2043 for node in to_erase:
2044 rn18_traced.graph.erase_node(node)
2045
Shen Li10224432021-08-12 11:39:31 -07002046
Allen (Congcong) Chen798dd462021-04-23 11:35:55 -07002047 def test_replace_input(self):
Shen Li10224432021-08-12 11:39:31 -07002048 graph : torch.fx.Graph = torch.fx.Graph()
2049 x : torch.fx.Node = graph.create_node('placeholder', 'x')
2050 y : torch.fx.Node = graph.create_node('placeholder', 'y')
2051 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2052 output : torch.fx.Node = graph.output(b)
Allen (Congcong) Chen798dd462021-04-23 11:35:55 -07002053
2054 b.replace_input_with(x, y)
2055
2056 gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2057
2058 input_x = torch.randn(33, 44)
2059 input_y = torch.randn(11, 22)
2060 self.assertEqual(gm(input_x, input_y), torch.relu(input_y))
2061
James Reedb04ae952020-10-05 17:05:07 -07002062 def test_insertion_point(self):
Shen Li10224432021-08-12 11:39:31 -07002063 graph : torch.fx.Graph = torch.fx.Graph()
2064 x : torch.fx.Node = graph.create_node('placeholder', 'x')
2065 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2066 output : torch.fx.Node = graph.output(b)
James Reedb04ae952020-10-05 17:05:07 -07002067
Zachary DeVito88dcb952020-10-12 18:18:06 -07002068 with graph.inserting_before(b):
Shen Li10224432021-08-12 11:39:31 -07002069 neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
James Reedb04ae952020-10-05 17:05:07 -07002070 _, *relu_args = b.args
2071 b.args = (neg, *relu_args)
2072
2073 gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2074
2075 input = torch.randn(33, 44)
2076 self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2077
James Reed36adc3f2021-05-19 14:52:41 -07002078 def test_update_args_api(self):
Shen Li10224432021-08-12 11:39:31 -07002079 graph : torch.fx.Graph = torch.fx.Graph()
2080 x : torch.fx.Node = graph.create_node('placeholder', 'x')
2081 y : torch.fx.Node = graph.create_node('placeholder', 'y')
2082 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2083 output : torch.fx.Node = graph.output(b)
James Reed36adc3f2021-05-19 14:52:41 -07002084
2085 orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2086 inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2087 self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2088
Shen Li10224432021-08-12 11:39:31 -07002089
James Reed36adc3f2021-05-19 14:52:41 -07002090 b.update_arg(0, y)
2091 new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2092 self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
2093
2094 def test_update_kwargs_api(self):
Shen Li10224432021-08-12 11:39:31 -07002095 graph : torch.fx.Graph = torch.fx.Graph()
2096 x : torch.fx.Node = graph.create_node('placeholder', 'x')
2097 y : torch.fx.Node = graph.create_node('placeholder', 'y')
2098 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
2099 output : torch.fx.Node = graph.output(b)
James Reed36adc3f2021-05-19 14:52:41 -07002100
2101 orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2102 inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
2103 self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
2104
Shen Li10224432021-08-12 11:39:31 -07002105
2106 b.update_kwarg('input', y)
James Reed36adc3f2021-05-19 14:52:41 -07002107 new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2108 self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
James Reedb04ae952020-10-05 17:05:07 -07002109
James Reeda8d9fbb2022-03-04 10:34:25 -08002110 def test_immutable_list_pytree_ops(self):
2111 rand_tensor = torch.randn(5, 3)
2112 l = immutable_list([3, [rand_tensor, 42]])
2113
2114 flattened, spec = pytree.tree_flatten(l)
2115 assert flattened == [3, rand_tensor, 42]
2116
2117 unflattened = pytree.tree_unflatten(flattened, spec)
2118 assert unflattened == l
2119 assert isinstance(unflattened, immutable_list)
2120
2121 def test_immutable_dict_pytree_ops(self):
2122 rand_tensor = torch.randn(5, 3)
2123 d = immutable_dict({'a': 3, 'b': [rand_tensor, 42]})
2124
2125 flattened, spec = pytree.tree_flatten(d)
2126 assert flattened == [3, rand_tensor, 42]
2127
2128 unflattened = pytree.tree_unflatten(flattened, spec)
2129 assert unflattened == d
2130 assert isinstance(unflattened, immutable_dict)
2131
James Reedb04ae952020-10-05 17:05:07 -07002132 def test_move_before(self):
Shen Li10224432021-08-12 11:39:31 -07002133 graph : torch.fx.Graph = torch.fx.Graph()
2134 x : torch.fx.Node = graph.create_node('placeholder', 'x')
2135 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2136 output : torch.fx.Node = graph.output(b)
James Reedb04ae952020-10-05 17:05:07 -07002137
Shen Li10224432021-08-12 11:39:31 -07002138 neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
James Reedb04ae952020-10-05 17:05:07 -07002139 _, *relu_args = b.args
2140 b.args = (neg, *relu_args)
Zachary DeVito88dcb952020-10-12 18:18:06 -07002141 b.prepend(neg)
James Reedb04ae952020-10-05 17:05:07 -07002142
2143 gm = torch.fx.GraphModule(torch.nn.Module(), graph)
2144
2145 input = torch.randn(33, 44)
2146 self.assertEqual(gm(input), torch.relu(torch.neg(input)))
2147
Shiyan Deng4b9464f2021-10-27 10:48:30 -07002148 def test_prepend_self(self):
2149 graph : torch.fx.Graph = torch.fx.Graph()
2150 x : torch.fx.Node = graph.create_node('placeholder', 'x')
2151 b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
2152 output : torch.fx.Node = graph.output(b)
2153
2154 b.prepend(b)
2155 x.append(b)
2156 self.assertEqual(len(graph.nodes), 3)
2157
James Reedb04ae952020-10-05 17:05:07 -07002158 def test_erase_node_error(self):
2159 st = SimpleTest()
2160 traced = symbolic_trace(st)
2161
2162 for node in traced.graph.nodes:
2163 # Test deleting with uses both in another Node and at the output
2164 if node.target in [operator.add, torch.relu]:
Shen Li10224432021-08-12 11:39:31 -07002165 with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
James Reedb04ae952020-10-05 17:05:07 -07002166 traced.graph.erase_node(node)
2167
Zachary DeVitofc1d6bf2020-10-14 15:49:30 -07002168 def test_copy_it(self):
2169 d = immutable_dict([(3, 4), (5, 6)])
2170 l = immutable_list([(3, 4), (5, 6)])
2171
2172 self.assertEqual(d, deepcopy(d))
2173 self.assertEqual(l, deepcopy(l))
2174
James Reed255b1032021-03-17 20:39:16 -07002175 def test_get_torch_func_signature(self):
2176 for key in dir(torch):
2177 obj = getattr(torch, key)
2178 if callable(obj):
2179 schemas = get_signature_for_torch_op(obj)
2180
James Reedb04ae952020-10-05 17:05:07 -07002181 def test_find_uses(self):
2182 graph = torch.fx.Graph()
Shen Li10224432021-08-12 11:39:31 -07002183 x = torch.fx.Proxy(graph.placeholder('x'))
James Reedb04ae952020-10-05 17:05:07 -07002184
2185 y = torch.relu(x)
2186 z = x + x
2187 u = torch.neg(x)
2188 graph.output((y + z + u).node)
2189 graph.lint()
2190
James Reed8cdb6382020-10-07 00:13:34 -07002191 users_of_x = x.node.users
2192 self.assertEqual(len(users_of_x), 3)
Shen Li10224432021-08-12 11:39:31 -07002193 expected_ops = set(['relu', 'add', 'neg'])
James Reed8cdb6382020-10-07 00:13:34 -07002194 for use in users_of_x:
2195 assert any(use.name.startswith(prefix) for prefix in expected_ops)
2196
James Reedc73af602020-10-09 16:33:21 -07002197 def test_inline_graph(self):
2198 class InlineInto(torch.nn.Module):
2199 def forward(self, x):
2200 return torch.relu(x)
2201
2202 class ToInline(torch.nn.Module):
2203 def forward(self, x):
2204 return torch.neg(x)
2205
2206 inline_into = symbolic_trace(InlineInto())
2207 to_inline = symbolic_trace(ToInline())
2208
2209 combined_graph = torch.fx.Graph()
2210 output_node = combined_graph.graph_copy(inline_into.graph, {})
2211
Zachary DeVito88dcb952020-10-12 18:18:06 -07002212 input_node = list(to_inline.graph.nodes)[0]
Shen Li10224432021-08-12 11:39:31 -07002213 assert input_node and input_node.op == 'placeholder'
James Reedc73af602020-10-09 16:33:21 -07002214
Shen Li10224432021-08-12 11:39:31 -07002215 val_map = {input_node : output_node}
James Reedc73af602020-10-09 16:33:21 -07002216 output = combined_graph.graph_copy(to_inline.graph, val_map)
2217 combined_graph.output(output)
2218
2219 combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)
2220
2221 input = torch.rand(3, 4)
2222 self.assertEqual(combined_module(input), input.relu().neg())
James Reedb04ae952020-10-05 17:05:07 -07002223
2224 def test_multi_insert_point(self):
2225 graph = torch.fx.Graph()
Shen Li10224432021-08-12 11:39:31 -07002226 x = torch.fx.Proxy(graph.placeholder('x'))
James Reedb04ae952020-10-05 17:05:07 -07002227 relu = torch.relu(x)
2228
Zachary DeVito88dcb952020-10-12 18:18:06 -07002229 with graph.inserting_before(relu.node):
James Reedb04ae952020-10-05 17:05:07 -07002230 y = torch.neg(x)
2231 z = torch.tanh(y)
2232
2233 graph.output((relu.node, z.node))
2234 graph.lint()
2235
Shen Li10224432021-08-12 11:39:31 -07002236 expected_ops = ['x', 'neg', 'tanh', 'relu']
James Reedb04ae952020-10-05 17:05:07 -07002237 for node, expected in zip(graph.nodes, expected_ops):
2238 assert expected in node.name
2239
James Reed8cdb6382020-10-07 00:13:34 -07002240 def test_reassign_args_kwargs_uses(self):
2241 graph = torch.fx.Graph()
Shen Li10224432021-08-12 11:39:31 -07002242 x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
James Reed8cdb6382020-10-07 00:13:34 -07002243 z = x + y
2244 zed = z + z + z
2245 graph.output(zed.node)
2246 graph.lint()
2247
2248 # zed = z + z + z -> zed = z + z + x
2249 zed.node.args = (zed.node.args[0], x.node)
Philip Meierd4d0ab72022-01-26 23:35:24 -08002250 self.assertEqual(list(x.node.users.keys()), [z.node, zed.node])
James Reed8cdb6382020-10-07 00:13:34 -07002251
2252 # z = x + y -> z = y + y
2253 z.node.args = (y.node, y.node)
Philip Meierd4d0ab72022-01-26 23:35:24 -08002254 self.assertEqual(list(x.node.users.keys()), [zed.node])
James Reed8cdb6382020-10-07 00:13:34 -07002255
James Reed09842a42020-10-13 19:09:57 -07002256 def test_trace_function(self):
2257 def foo(x, y):
2258 return torch.relu(x) + y
2259
2260 x, y = torch.randn(3, 4), torch.randn(3, 4)
2261 self.checkGraphModule(foo, (x, y))
2262
James Reedd23cb942021-02-09 21:49:30 -08002263 def test_trace_dict_int_keys(self):
2264 class ModWithDictArg(torch.nn.Module):
Shen Li10224432021-08-12 11:39:31 -07002265 def forward(self, d : Dict[int, torch.Tensor]):
James Reedd23cb942021-02-09 21:49:30 -08002266 return d[42]
2267
2268 class CallsModWithDict(torch.nn.Module):
2269 def __init__(self):
2270 super().__init__()
2271 self.m = ModWithDictArg()
2272
2273 def forward(self, x):
2274 return self.m({42: x})
2275
2276 class MyTracer(torch.fx.Tracer):
Shen Li10224432021-08-12 11:39:31 -07002277 def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
James Reedd23cb942021-02-09 21:49:30 -08002278 return isinstance(m, ModWithDictArg)
2279
2280 traced_graph = MyTracer().trace(CallsModWithDict())
2281
2282 def test_trace_dict_proxy_keys(self):
2283 class ModWithDictArg(torch.nn.Module):
Shen Li10224432021-08-12 11:39:31 -07002284 def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
James Reedd23cb942021-02-09 21:49:30 -08002285 return d[42]
2286
2287 class CallsModWithDict(torch.nn.Module):
2288 def __init__(self):
2289 super().__init__()
2290 self.m = ModWithDictArg()
2291
2292 def forward(self, x):
2293 return self.m({x: x})
2294
2295 class MyTracer(torch.fx.Tracer):
Shen Li10224432021-08-12 11:39:31 -07002296 def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
James Reedd23cb942021-02-09 21:49:30 -08002297 return isinstance(m, ModWithDictArg)
2298
Shen Li10224432021-08-12 11:39:31 -07002299 with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
James Reedd23cb942021-02-09 21:49:30 -08002300 traced_graph = MyTracer().trace(CallsModWithDict())
2301
James Reedd661e642021-08-18 13:16:01 -07002302 def test_module_deepcopy_edit_nodes(self):
2303 class Foo(torch.nn.Module):
2304 def forward(self, x):
2305 return torch.relu(x)
2306
2307 traced1 = symbolic_trace(Foo())
2308 copied = copy.deepcopy(traced1)
2309
2310 for node in copied.graph.nodes:
2311 if node.target == torch.relu:
2312 node.target = torch.neg
2313
2314 copied.recompile()
2315 traced1.recompile()
2316
2317 x = torch.randn(15, 15)
2318 torch.testing.assert_allclose(traced1(x), torch.relu(x))
2319 torch.testing.assert_allclose(copied(x), torch.neg(x))
2320
Horace Hecb4b6332020-10-30 17:06:55 -07002321 def test_direct_param_use(self):
2322 class TransposeTest(torch.nn.Module):
2323 def __init__(self):
2324 super().__init__()
2325 self.b = torch.nn.Parameter(torch.rand(4, 3))
2326
2327 def forward(self, x):
2328 return self.b
2329
2330 class Foo(torch.nn.Module):
2331 def __init__(self):
2332 super().__init__()
2333 self.a = TransposeTest()
2334
2335 def forward(self, x):
2336 return self.a.b, self.a.b.t(), self.a.b.view(12)
2337
2338 traced = torch.fx.symbolic_trace(Foo())
Shen Li10224432021-08-12 11:39:31 -07002339 assert(all('constant' not in node.target for node in traced.graph.nodes))
Horace Hecb4b6332020-10-30 17:06:55 -07002340
Ansley Usserye914a1b2020-11-10 18:55:22 -08002341 def test_single_default_arg(self):
2342 class M(torch.nn.Module):
2343 def __init__(self):
2344 super().__init__()
2345
2346 def forward(self, y=1):
2347 return y
2348
2349 m = M()
2350 self.checkGraphModule(m, ())
2351 self.checkGraphModule(m, (3,))
2352
2353 def test_multiple_default_args(self):
2354 class M(torch.nn.Module):
2355 def __init__(self):
2356 super().__init__()
2357
2358 def forward(self, y=1, z=2):
2359 return y + z
2360
2361 m = M()
2362 self.checkGraphModule(m, ())
2363 self.checkGraphModule(m, (3,))
2364 self.checkGraphModule(m, (3, 4))
2365
2366 def test_regular_and_default_args(self):
2367 class M(torch.nn.Module):
2368 def __init__(self):
2369 super().__init__()
2370
2371 def forward(self, x, y=1):
2372 return x + y
2373
2374 m = M()
2375 self.checkGraphModule(m, (2,))
2376 self.checkGraphModule(m, (2, 3))
2377
Ansley Ussery4cb73f52020-11-11 08:52:06 -08002378 def test_string_literal_return(self):
2379 class M(torch.nn.Module):
2380 def __init__(self):
2381 super().__init__()
2382
2383 def forward(self):
2384 return "foo"
2385
2386 m = M()
2387 self.checkGraphModule(m, ())
2388
James Reedfb755ad2020-12-18 14:08:28 -08002389 def test_namedtuple_return_qualname(self):
2390 class NamedTupReturn(torch.nn.Module):
2391 def forward(self, x):
2392 return MyNamedTup(x, x)
2393
2394 traced = symbolic_trace(NamedTupReturn())
2395 input = torch.rand(3, 4)
2396 self.assertEqual(traced(input), MyNamedTup(input, input))
2397
James Reed67d0c182020-12-22 15:18:16 -08002398 def test_update_args_kwargs_yells_at_you(self):
2399 symtraced = symbolic_trace(SimpleTest())
2400 node = next(iter(symtraced.graph.nodes))
Shen Li10224432021-08-12 11:39:31 -07002401 with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
James Reed67d0c182020-12-22 15:18:16 -08002402 node.__update_args_kwargs((), {})
2403
Lu Fang212ec072020-12-04 16:19:24 -08002404 def test_torchbind_class_attribute_in_fx(self):
Jeff Daily340ae3c2022-07-14 00:42:16 +00002405 if IS_FBCODE or IS_WINDOWS or IS_MACOS:
Shen Li10224432021-08-12 11:39:31 -07002406 self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")
Lu Fang212ec072020-12-04 16:19:24 -08002407
2408 class FooBar1234(torch.nn.Module):
2409 def __init__(self):
2410 super(FooBar1234, self).__init__()
2411 self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
2412
2413 def forward(self):
2414 return self.f.top()
2415
2416 m = FooBar1234()
2417 self.checkGraphModule(m, ())
2418
James Reed1fe6a652021-03-05 23:37:54 -08002419 def test_torchbind_class_attribute_in_fx_tensor_arg(self):
Jeff Daily340ae3c2022-07-14 00:42:16 +00002420 if IS_FBCODE or IS_WINDOWS or IS_MACOS:
Shen Li10224432021-08-12 11:39:31 -07002421 self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")
James Reed1fe6a652021-03-05 23:37:54 -08002422
2423 class FooBar2341(torch.nn.Module):
2424 def __init__(self):
2425 super(FooBar2341, self).__init__()
2426 self.f = torch.classes._TorchScriptTesting._ReLUClass()
2427
2428 def forward(self, x):
2429 return self.f.run(x)
2430
2431 m = FooBar2341()
2432
2433 traced = symbolic_trace(m)
2434 input = torch.randn(3, 4)
2435 self.assertEqual(traced(input), m(input))
2436
Shen Li10224432021-08-12 11:39:31 -07002437 self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
James Reed1fe6a652021-03-05 23:37:54 -08002438
2439 def test_script_method_trace(self):
2440 class Scripted(torch.nn.Module):
2441 def forward(self, x):
2442 return torch.relu(x)
2443
2444 class Holder(torch.nn.Module):
2445 def __init__(self):
2446 super().__init__()
2447 self.s = torch.jit.script(Scripted())
2448
2449 def forward(self, x):
2450 return self.s(x)
2451
2452 h = Holder()
2453 traced = symbolic_trace(h)
2454 input = torch.randn(3, 4)
2455 self.assertEqual(traced(input), h(input))
2456
Shen Li10224432021-08-12 11:39:31 -07002457 self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
James Reed1fe6a652021-03-05 23:37:54 -08002458
James Reed80f75102020-12-10 15:28:15 -08002459 def test_namedtuple_return_trace(self):
2460 class NamedTupReturn(torch.nn.Module):
2461 def forward(self, x):
2462 return Pair(x, x)
2463
2464 traced = symbolic_trace(NamedTupReturn())
James Reedfb755ad2020-12-18 14:08:28 -08002465 input = torch.rand(3, 4)
2466 self.assertEqual(traced(input), Pair(input, input))
James Reed09842a42020-10-13 19:09:57 -07002467
Jordan Fix987f1462022-02-23 02:38:29 -08002468 def test_named_tuple_inlined(self):
2469 class NamedTupMod(torch.nn.Module):
2470 def forward(self, inp):
2471 return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))
2472
2473 m = NamedTupMod()
2474 input = torch.rand(3, 4)
2475 ref = m(input)
2476 traced = symbolic_trace(m)
2477
2478 res = traced(input)
2479 self.assertEqual(ref, res)
2480
2481 # Check Pair NamedTuple works when inlined into the function call.
2482 ph = call_func = None
2483 for node in traced.graph.nodes:
2484 if node.op == "placeholder":
2485 ph = node
2486 elif node.op == "call_function" and node.target == wrapped_named_tup:
2487 node.update_arg(0, Pair(ph, 1.2))
2488 node.update_kwarg("p2", Pair(3.4, ph))
2489 call_func = node
2490 break
2491 self.assertTrue(call_func is not None)
2492 self.assertTrue(isinstance(call_func.args[0], Pair))
2493 self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
2494 self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
2495 self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")
2496
2497 traced.graph.eliminate_dead_code()
2498 traced.recompile()
2499 res = traced(input)
2500 self.assertEqual(ref, res)
2501
Brandon Linc51455a2021-01-04 19:54:01 -08002502 def test_return_type_exists(self):
2503 class ReturnTypeModule(torch.nn.Module):
2504 def other(self, x: List[str]) -> List[str]:
2505 return x
2506
2507 def forward(self, x: List[str]) -> List[str]:
2508 return self.other(x)
2509
2510 traced = symbolic_trace(ReturnTypeModule())
Michael Suoecf3ca02021-02-23 13:33:22 -08002511 self.assertIn("-> typing_List[str]", traced._code)
Brandon Linc51455a2021-01-04 19:54:01 -08002512 scripted = torch.jit.script(traced)
2513 self.assertIn("-> List[str]", scripted.code)
2514
Jason Ansela66851a2021-01-22 15:03:09 -08002515 def getitem_inner(self):
2516 class GetItemBase(torch.nn.Module):
2517 def __init__(self):
2518 super().__init__()
Shen Li10224432021-08-12 11:39:31 -07002519 self.register_buffer('pe', torch.randn(8, 8))
Jason Ansela66851a2021-01-22 15:03:09 -08002520
2521 class GetItem1(GetItemBase):
2522 def forward(self, x):
Shen Li10224432021-08-12 11:39:31 -07002523 return self.pe[:, :x.size(0)]
Jason Ansela66851a2021-01-22 15:03:09 -08002524
2525 class GetItem2(GetItemBase):
2526 def forward(self, x):
2527 return self.pe[x.size(0)]
2528
2529 class GetItem3(GetItemBase):
2530 def forward(self, x):
2531 return self.pe[4] # fx creates `self._tensor_constant0` here
2532
2533 self.checkGraphModule(GetItem1(), [torch.zeros(4)])
2534 self.checkGraphModule(GetItem2(), [torch.zeros(4)])
2535 self.checkGraphModule(GetItem3(), [torch.zeros(4)])
2536
Shen Li10224432021-08-12 11:39:31 -07002537 @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
2538 "Will be checked in test_getitem_subproc")
Jason Ansela66851a2021-01-22 15:03:09 -08002539 def test_getitem(self):
2540 self.getitem_inner()
2541
2542 def test_getitem_subproc(self):
2543 # need to run this test in a subproc to work around:
2544 # https://github.com/pytorch/pytorch/issues/50710
2545 proc = Process(target=run_getitem_target)
2546 proc.start()
2547 proc.join()
2548 self.assertEqual(proc.exitcode, 0)
2549
Shen Li10224432021-08-12 11:39:31 -07002550
Ansley Ussery4ac48902021-01-21 12:00:43 -08002551 def test_user_friendly_call_provenance_with_function(self):
2552 def fn(x):
2553 return wrapper_fn(x)
2554
2555 traced = torch.fx.symbolic_trace(fn)
2556
Shen Li10224432021-08-12 11:39:31 -07002557 with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2558 "being compiled since it was called"
2559 " from 'fn.forward'"):
Ansley Ussery4ac48902021-01-21 12:00:43 -08002560 scripted = torch.jit.script(traced)
2561
2562 def test_user_friendly_call_provenance_with_module(self):
2563 class M(torch.nn.Module):
2564 def forward(self, x):
2565 return wrapper_fn(x)
2566
2567 traced = torch.fx.symbolic_trace(M())
2568
Shen Li10224432021-08-12 11:39:31 -07002569 with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
2570 "being compiled since it was called"
2571 " from 'M.forward'"):
Ansley Ussery4ac48902021-01-21 12:00:43 -08002572 scripted = torch.jit.script(traced)
2573
Ansley Ussery7494f022021-01-21 22:16:04 -08002574 def test_snake_case(self):
2575 class M(torch.nn.Module):
2576 def __init__(self):
2577 super(M, self).__init__()
Shen Li10224432021-08-12 11:39:31 -07002578 self.activations = torch.nn.ModuleDict([
2579 ["snake_case", torch.nn.ReLU()],
2580 ["PascalCase", torch.nn.LeakyReLU()],
2581 ["ALL_CAPS", torch.nn.PReLU()]
2582 ])
Ansley Ussery7494f022021-01-21 22:16:04 -08002583
2584 def forward(self, x):
2585 a = self.activations["snake_case"](x)
2586 b = self.activations["PascalCase"](x)
2587 c = self.activations["ALL_CAPS"](x)
2588 return a, b, c
2589
2590 traced = symbolic_trace(M())
2591
2592 check = [
2593 ("activations_snake_case", "activations.snake_case"),
2594 ("activations_pascal_case", "activations.PascalCase"),
Shen Li10224432021-08-12 11:39:31 -07002595 ("activations_all_caps", "activations.ALL_CAPS")
Ansley Ussery7494f022021-01-21 22:16:04 -08002596 ]
2597
2598 i = 0
2599 for node in traced.graph.nodes:
2600 if node.op == "placeholder" or node.op == "output":
2601 continue
2602 name = check[i][0]
2603 target = check[i][1]
2604 self.assertEqual(name, node.name)
2605 self.assertEqual(target, node.target)
2606 i += 1
2607 self.assertEqual(i, 3)
2608
Zachary DeVito33d51802021-01-28 10:17:19 -08002609 def test_no_mutation(self):
2610 from torch.fx.immutable_collections import immutable_list
2611 x = immutable_list([3, 4])
2612 with self.assertRaisesRegex(NotImplementedError, "new_args"):
2613 x[0] = 4
2614
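    # concrete_args partially specializes a trace: the supplied values are
    # baked in rather than becoming placeholders, and FX inserts
    # torch._assert guards so the specialized module rejects inputs that
    # contradict the burned-in values. Roughly:
    #
    #     mod_true = symbolic_trace(mod, concrete_args={'y': True})
    #     mod_true(3, True)    # ok
    #     mod_true(3, False)   # AssertionError from the inserted guard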
Horace He2d305b92021-02-04 11:49:45 -08002615 def test_partial_trace(self):
2616 class Foo(torch.nn.Module):
2617 def forward(self, x, y):
2618 if y:
2619 return 2 * x
2620 else:
2621 return x
2622 mod = Foo()
Shen Li10224432021-08-12 11:39:31 -07002623 mod_true = symbolic_trace(mod, concrete_args={'y': True})
2624 mod_false = symbolic_trace(mod, concrete_args={'y': False})
Horace He8d363d32021-05-07 04:46:50 -07002625 self.assertEqual(mod_true(3, True), 6)
2626 print(mod_true.code)
Shen Li10224432021-08-12 11:39:31 -07002627 assert(any([i.target == torch._assert for i in mod_true.graph.nodes]))
Horace He8d363d32021-05-07 04:46:50 -07002628 with self.assertRaises(AssertionError):
2629 mod_true(3, False)
2630 self.assertEqual(mod_false(3, False), 3)
2631 with self.assertRaises(AssertionError):
2632 mod_false(3, True)
2633
2634 def f_higher(a, f):
2635 return f(a)
2636
Shen Li10224432021-08-12 11:39:31 -07002637 nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
Horace He8d363d32021-05-07 04:46:50 -07002638 self.assertEqual(nf(3, lambda x: x * 2), 6)
Jason Ansela66851a2021-01-22 15:03:09 -08002639
Ansley Ussery4cc10562021-02-12 18:28:13 -08002640 def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
2641 class M(torch.nn.Module):
2642 def __init__(self):
2643 super(M, self).__init__()
2644 self.W = torch.nn.Parameter(torch.randn(5))
2645
2646 def forward(self, x):
2647 return torch.dot(self.W, x)
2648
2649 traced = torch.fx.symbolic_trace(M())
2650
2651 out = [n for n in traced.graph.nodes if n.op == "output"][-1]
2652 with traced.graph.inserting_before(out):
Shen Li10224432021-08-12 11:39:31 -07002653 relu_out = traced.graph.call_method(method_name='relu',
2654 args=(out.args[0],))
Ansley Ussery4cc10562021-02-12 18:28:13 -08002655 out.args = (relu_out,)
2656
2657 traced.recompile()
2658
2659 with self.capture_stderr() as captured:
2660 with self.assertRaises(TypeError):
2661 traced(5)
2662
Shen Li10224432021-08-12 11:39:31 -07002663 self.assertRegex(captured[0],
2664 r"Call using an FX-traced Module, line .* of the "
2665 r"traced Module's generated forward function:")
Ansley Ussery4cc10562021-02-12 18:28:13 -08002666
    def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 4)

            def forward(self, x):
                return self.linear(x)

        traced = torch.fx.symbolic_trace(M())

        # Do not change this to `capture_stderr` or another context
        # manager without ensuring that the output is as expected
        try:
            traced(torch.rand(5, 5))
        except RuntimeError:
            captured = traceback.format_exc()

        self.assertNotRegex(captured,
                            r"Call using an FX-traced Module, line .* of the "
                            r"traced Module's generated forward function:")

    def test_graph_module_replicate_for_dp(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        gm = torch.fx.symbolic_trace(Foo())

        x = torch.randn(5, 3)
        out = gm(x)

        replica = gm._replicate_for_data_parallel()
        out_replica = replica(x)

        torch.testing.assert_allclose(out_replica, out)

    def test_ast_rewriter_rewrites_assert(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int, z: int):
                assert y == z
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        traced.graph.lint()

    def test_ast_rewriter_rewrites_assert_with_message(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int, z: int):
                assert y == z, "msg"
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        traced.graph.lint()

    def test_throw_out_variant(self):
        def foo(x):
            y = torch.rand_like(x)
            torch.sigmoid(x, out=y)
            return y

        class MyTracer(torch.fx.Tracer):
            check_mutable_operations = True

        tracer = MyTracer()
        with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
            traced_graph = tracer.trace(foo)

    def test_ast_rewriter_reassigns_submodules(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(100)

            def forward(self, x: torch.Tensor):
                return torch.add(x, x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        traced.graph.lint()

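    # `a_lifted_leaf` / `a_lifted_leaf2` are defined at module scope earlier in
    # this file and registered with `torch.fx.wrap`, so the tracer should emit
    # a single call_function node for them rather than tracing through their
    # bodies (hence the `assertIn(..., traced.code)` checks below).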
    def test_ast_rewriter_wrap(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

        def to_trace(y):
            return (
                a_lifted_leaf((4, y), 3)
                + a_lifted_leaf((3, 4), 5)
                + a_lifted_leaf((y, y), y)
            )

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("a_lifted_leaf", traced.code)
        self.assertEqual(27, traced(2))
        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)

    def test_ast_rewriter_wrap_fn_directly(self):
        self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

        def to_trace(y):
            return (
                a_lifted_leaf2((4, y), 3)
                + a_lifted_leaf2((3, 4), 5)
                + a_lifted_leaf2((y, y), y)
            )

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("a_lifted_leaf2", traced.code)
        self.assertEqual(27, traced(2))
        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)

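    # The profiler's `_record_function_enter`/`_record_function_exit` ops are
    # side-effectful, so dead-code elimination must leave them in the graph
    # even though their results are otherwise unused.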
    def test_profiler_ranges_side_effect(self):
        g = torch.fx.Graph()
        handle = g.call_function(torch.ops.profiler._record_function_enter, ('test_range',))
        g.call_function(torch.ops.profiler._record_function_exit, (handle,))
        g.output(None)

        found_targets = {}
        for node in g.nodes:
            if node.op == 'call_function':
                found_targets.setdefault(node.target)
        self.assertEqual(
            list(found_targets.keys()),
            [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit]
        )

        g.eliminate_dead_code()
        found_targets = {}
        for node in g.nodes:
            if node.op == 'call_function':
                found_targets.setdefault(node.target)
        self.assertEqual(
            list(found_targets.keys()),
            [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit]
        )

    def test_ast_rewriter_wrapped_via_decorator(self):
        class F(torch.nn.Module):
            def forward(self, x):
                return wrapped_via_decorator(x)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(F())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_via_decorator", traced.code)
        self.assertEqual(traced(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
        self.assertEqual(wrapped_via_decorator(0), 1)

        def to_trace(y):
            return wrapped_via_decorator(y)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(to_trace)
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_via_decorator", traced.code)
        self.assertEqual(traced(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

        transformed = torch.fx.Transformer(traced).transform()
        self.assertIn("wrapped_via_decorator", transformed.code)
        self.assertEqual(transformed(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    def test_ast_rewriter_wrap_with_submodule(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

            def forward(self, x: torch.Tensor):
                return wrapped_with_submodule(x, self.batchnorm1d)

        ast_rewriter = RewritingTracer()
        graph = ast_rewriter.trace(M())
        traced = GraphModule(ast_rewriter.root, graph, "gm")

        self.assertIn("wrapped_with_submodule", traced.code)

        input = torch.rand(3, 2)
        ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
        self.assertEqual(ref_batchnorm1d(input), traced(input))

    def test_submodule_manipulation_API(self):
        class C(torch.nn.Module):
            def __init__(self):
                super(C, self).__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.conv(torch.cat([self.param, x]))

        class B(torch.nn.Module):
            def __init__(self):
                super(B, self).__init__()
                self.linear = torch.nn.Linear(100, 200)
                self.register_buffer("buf", torch.randn(2, 3))
                self.net_c = C()

            def forward(self, x):
                return self.linear(torch.cat([self.buf, self.net_c(x)]))

        class A(torch.nn.Module):
            def __init__(self):
                super(A, self).__init__()
                self.net_b = B()
                self.param = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return self.net_b(x) + self.param

        a = symbolic_trace(A())

        a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))

        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
        with a.graph.inserting_before(conv):
            with warnings.catch_warnings(record=True) as w:
                dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
                                              args=conv.args)
            self.assertEqual(len(w), 0)

        conv.replace_all_uses_with(dropout)
        a.graph.erase_node(conv)
        a.recompile()

        def module_exists(gm: GraphModule, path: str) -> bool:
            return any(path == name for name, _ in gm.named_modules())

        def parameter_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_parameters())
                    and any(path == name for name in gm.state_dict().keys()))

        def buffer_exists(gm: GraphModule, path: str) -> bool:
            return (any(path == name for name, _ in gm.named_buffers())
                    and any(path == name for name in gm.state_dict().keys()))

        # Test that we added the "dropout" submodule
        self.assertTrue(module_exists(a, "net_b.net_c.dropout"))

        # Test `get_submodule` with an added submodule
        self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))

        # Test that the "conv" submodule is still there
        self.assertTrue(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with an original module
        self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))

        # Test that the "conv" node is NOT still there
        conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
        self.assertEqual(conv, [])

        a.delete_submodule("net_b.net_c.conv")

        # Test that the "conv" submodule is now gone
        self.assertFalse(module_exists(a, "net_b.net_c.conv"))

        # Test `get_submodule` with a deleted submodule
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`conv`"):
            self.assertIsNone(a.get_submodule("net_b.net_c.conv"))

        # Test `get_attr` warnings
        cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]

        with a.graph.inserting_before(cat):

            with warnings.catch_warnings(record=True) as w:
                param = a.graph.get_attr(qualified_name="net_b.net_c.param")
            self.assertEqual(len(w), 0)

            with self.assertWarnsRegex(UserWarning, "Attempted to "
                                       "insert a get_attr Node with no "
                                       "underlying reference in the "
                                       "owning GraphModule"):
                bad_param = a.graph.get_attr(qualified_name="net_b.param")
                a.graph.erase_node(bad_param)

        cat.args = (*cat.args, param)

        a.recompile()

        a.graph.lint()

        # Test `get_parameter`
        a.get_parameter("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "is not an "
                                    "nn.Parameter"):
            a.get_parameter("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`param`"):
            a.get_parameter("net_b.param")

        # Test `get_buffer`
        a.get_buffer("net_b.buf")
        with self.assertRaisesRegex(AttributeError, "is not a "
                                    "buffer"):
            a.get_buffer("net_b.net_c.param")
        with self.assertRaisesRegex(AttributeError, "has no attribute "
                                    "`buf`"):
            a.get_buffer("net_b.net_c.buf")

        # Test non-nested attributes
        a.get_submodule("")
        a.get_parameter("param")

        # Insert some unused submodules
        a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
        a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
        a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))

        # Garbage collection
        a.delete_all_unused_submodules()

        # Test that all the unused submodules are gone
        self.assertFalse(module_exists(a, "net_b.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
        self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
        self.assertFalse(module_exists(a, "batch_norm_2d"))

        # Test that we didn't delete any unused Parameters or buffers
        self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
        self.assertTrue(buffer_exists(a, "net_b.buf"))

        a.graph.lint()

    def test_delete_unused_submodules_leaf(self):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.linear(x)
                x = self.relu(x)
                return x

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submod = SubModule()

            def forward(self, x):
                x = self.submod(x)
                return x

        model = Model()

        class MyCustomTracer(torch.fx.Tracer):
            def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
                return module_qualified_name == "submod"

        inputs = torch.randn(1, 10)
        traced_graph = MyCustomTracer().trace(model)
        gm2 = torch.fx.GraphModule(model, traced_graph)
        gm2.delete_all_unused_submodules()
        torch.testing.assert_allclose(gm2(inputs), model(inputs))

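    # `_stateless.functional_call` runs the module with the caller-supplied
    # parameter/buffer tensors substituted in, so gradients should accumulate
    # on those replacement tensors and not on the module's own state.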
    def test_fx_stateless(self):
        class MockModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(1, 1)
                self.register_buffer('buffer', torch.ones(1))

            def forward(self, x):
                return self.l1(x) + self.buffer

        module = MockModule()
        x = torch.rand((1, 1))
        weight = torch.tensor([[1.0]], requires_grad=True)
        bias = torch.tensor([0.0], requires_grad=True)
        buffer = torch.tensor([0.0])
        parameters = {'l1.weight': weight,
                      'l1.bias': bias,
                      'buffer': buffer}
        fx_module = torch.fx.symbolic_trace(module)
        res = _stateless.functional_call(fx_module, parameters, x)
        res.backward()
        self.assertIsNotNone(weight.grad)
        self.assertIsNotNone(bias.grad)
        self.assertIsNone(buffer.grad)
        # Gradients were not calculated for the module's own parameters and buffers
        self.assertIsNone(module.l1.weight.grad)
        self.assertIsNone(module.l1.bias.grad)
        self.assertIsNone(module.buffer.grad)

    def test_tracing_graphmodules_as_leaf_submodules(self):
        class A(torch.nn.Module):
            def forward(self, t):
                return t + t

        class B(torch.nn.Module):
            def __init__(self):
                super(type(self), self).__init__()
                self.calling = False
                self.called = False

            def forward(self, t):
                if self.calling:
                    return t - t
                else:
                    return t + t

            def __call__(self, *args):
                self.called = True
                self.calling = True
                return super(type(self), self).__call__(*args)
                self.calling = False

        class M(torch.nn.Module):
            def __init__(self, a, b):
                super().__init__()
                self.a = a
                self.b = b

            def forward(self, t):
                x = self.a(t)
                y = self.b(t)
                return x + y

        class LeafTracer(Tracer):
            def is_leaf_module(self, module, name):
                return True

        class LeafTracerNotB(Tracer):
            def is_leaf_module(self, module, name):
                return False if "b" in name else True

        # Recompile calls added "for fun", since they
        # chain __call__ wrappers.

        #
        # Test: B as a regular, non-leaf module
        #
        a = symbolic_trace(A())
        a.recompile()
        m = M(a, B())
        graph = LeafTracerNotB().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        # Test graphmodule/submodule a is not inlined.
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        # Test submodule b is not treated as leaf.
        self.assertFalse(hasattr(gm, "b"))

        # Test assert custom __call__ on submodule b was honored.
        match = [
            n
            for n in gm.graph.nodes
            if n.op == "call_function" and n.target == operator.sub
        ]
        self.assertTrue(len(match) == 1)

        #
        # Test: B as a regular, leaf module
        # symbolic_trace should only patch torch.nn.Module.__call__,
        # which means B.__call__ should still execute
        #
        a = symbolic_trace(A())
        a.recompile()
        b = B()
        m = M(a, b)
        graph = LeafTracer().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        # Test graphmodule/submodule a is not inlined.
        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        # Test submodule b is leaf:
        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
        self.assertTrue(len(match) == 1)

        # Test b.__call__ was run
        self.assertTrue(b.called)
        self.assertTrue(gm.get_submodule("b").called)

        #
        # Test: B as GraphModule leaf
        # __call__ not honored since symbolic_trace directly invokes forward()
        #
        a = symbolic_trace(A())
        a.recompile()
        b = symbolic_trace(B())
        b.recompile()
        m = M(a, b)
        graph = LeafTracer().trace(m)
        gm = GraphModule(m, graph)
        gm.recompile()

        self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
        self.assertTrue(len(match) == 1)

        self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
        match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
        self.assertTrue(len(match) == 1)

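    # `GraphModule.__init__` accepts either a root `nn.Module` or a plain dict
    # mapping qualified names to attributes; both paths should copy the
    # buffers and parameters referenced by the graph into the new module.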
    def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("my_buff", torch.rand(3, 4))
                self.register_parameter(
                    "my_param", torch.nn.Parameter(torch.rand(3, 4))
                )

            def forward(self, x):
                return x + self.my_buff + self.my_param

        mod = MyModule()
        mod_traced = symbolic_trace(mod)

        # Create new GraphModule based on original, either w/ dict or root module.
        orig_buff = mod_traced.get_buffer("my_buff")
        orig_param = mod_traced.get_parameter("my_param")
        mod_traced_new = GraphModule(
            {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
            mod_traced.graph,
        )

        # Check that both my_buff and my_param are found and the same.
        try:
            new_buff = mod_traced_new.get_buffer("my_buff")
        except Exception:
            self.fail("Did not find my_buff")
        self.assertEqual(orig_buff, new_buff)

        try:
            new_param = mod_traced_new.get_parameter("my_param")
        except Exception:
            self.fail("Did not find my_param")
        self.assertEqual(orig_param, new_param)

        x = torch.rand(3, 4)
        orig_out = mod_traced(x)
        submodules_out = mod_traced_new(x)

        self.assertEqual(orig_out, submodules_out)

    def test_graph_module_init_buffer_param_copied_dict_init(self):
        self._test_graph_module_init_buffer_param_copied(use_dict_init=True)

    def test_graph_module_init_buffer_param_copied_mod_init(self):
        self._test_graph_module_init_buffer_param_copied(use_dict_init=False)

    def test_annotations_with_no_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
                return a(x)

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
                return a(x[0])

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
        class A:
            def __call__(self, x: torch.Tensor):
                return torch.add(x, x)

        class M(torch.nn.Module):
            def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
                return a(x)[0]

        self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)

    @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
                     "`annotations` is not defined in Python <3.7")
    def test_annotation_with_future(self):
        try:
            import fx.test_future  # noqa: F401
        finally:
            del sys.modules["__future__"]

    def test_annotations_empty_tuple(self):
        class Foo(torch.nn.Module):
            def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
                return "foo"

        traced = torch.fx.symbolic_trace(Foo())

        x = ()
        y = ("bar", ())

        traced(x, y)

        FileCheck().check("_Tuple[()]") \
            .check("typing_Tuple[str,typing_Tuple[()]]") \
            .run(traced.code)

        scripted = torch.jit.script(traced)

        scripted(x, y)

        FileCheck().check("Tuple[()]") \
            .check("Tuple[str, Tuple[()]]") \
            .run(scripted.code)

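    # With `TracerBase.trace_asserts` enabled, a Python `assert` in the traced
    # function is captured in the graph instead of being evaluated (and
    # discarded) at trace time, so it still fires when the GraphModule runs.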
    @unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
    @unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
    def test_assert(self):
        def f(x):
            assert x > 1
            return x + 1
        try:
            torch.fx.proxy.TracerBase.trace_asserts = True
            traced = symbolic_trace(f)
        finally:
            torch.fx.proxy.TracerBase.trace_asserts = False

        self.assertEqual(f(2), traced(2))
        with self.assertRaises(AssertionError):
            traced(0)

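    # `PH` is a placeholder sentinel for pytree-structured `concrete_args`:
    # leaves marked `PH` stay symbolic inputs to the traced function, while
    # every other leaf is treated as a concrete value it may branch on.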
    def test_pytree(self):
        def f_sum(x):
            return sum(x)

        def f_sum_dict(x):
            out = 0
            for k, v in x.items():
                out += v
            return out

        def f_dict_list_map(x):
            new_dict = {}
            for k, v in x.items():
                new_dict[k] = [i + 1 for i in v]
            return new_dict

        def f_dict_add(x):
            return x['a'] + sum(x['z'])

        def f_namedtuple_add(x):
            return x.x + x.y

        pytree._register_pytree_node(
            Foo,
            lambda x: ([x.a, x.b], None),
            lambda x, _: Foo(x[0], x[1]),
        )
        fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])

        def f_custom(x):
            return x.a + x.b

        def f_custom_dict(x):
            return f_sum_dict(x.a) + x.b

        def f_return_custom(x):
            return Foo(x.b, x.a)

        tests = [
            (f_sum, [PH, PH, PH]),
            (f_sum, []),
            (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
            (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
            (f_dict_list_map, {5: (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
            (f_dict_add, {'a': PH, 'z': []}),
            (f_custom, Foo(PH, PH)),
            (f_custom, Foo(PH, 3)),
            (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
            # (f_return_custom, Foo(PH, PH)),  # Don't currently support output pytrees
            (f_namedtuple_add, Point(PH, PH)),
        ]

        def verify_pytree(f, inp):
            val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
            num_flat_args = len([i == PH for i in pytree.tree_flatten(inp)[0]])
            orig_out = f(val)
            nf = symbolic_trace(f, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)

            bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
            bare_fx.graph.set_codegen(CodeGen())
            bare_fx.recompile()
            self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(val))), orig_out)

            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            nf = symbolic_trace(nf)
            self.assertEqual(nf(val), orig_out)
            assert "tree_flatten_spec" not in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == 1

            nf = symbolic_trace(nf, concrete_args={'x': inp})
            self.assertEqual(nf(val), orig_out)
            assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
            assert sum(i.op == 'placeholder' for i in nf.graph.nodes) == num_flat_args

            pickled = pickle.dumps(nf)
            nf = pickle.loads(pickled)
            self.assertEqual(nf(val), orig_out)

        for f, inp in tests:
            verify_pytree(f, inp)

    def test_pytree_concrete(self):
        def f(b, a):
            if b:
                return a['a']
            else:
                return a['z']

        inp = {'a': {'a': PH, 'z': PH}, 'b': True}
        nf = symbolic_trace(f, concrete_args=inp)
        val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
        self.assertEqual(nf(**val), f(**val))

        nf = symbolic_trace(nf)
        self.assertEqual(nf(**val), f(**val))

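    # Overriding `CodeGen` customizes the generated `forward` and its calling
    # convention: here `gen_fn_def`/`process_inputs` make the module take a
    # single list of Tensors rather than individual positional arguments.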
    def test_custom_codegen(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()

        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
        bare_fx.graph.set_codegen(CodeGen())
        bare_fx.recompile()

        self.assertEqual(nf(vals), f(*vals))
        self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))

        ts_f = torch.jit.script(nf)
        self.assertEqual(nf(vals), ts_f(vals))

    def test_custom_codegen_with_transformer(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

        def f(a, b):
            return a + b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        self.assertEqual(nf(*vals), f(*vals))

        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(nf(vals), f(*vals))

        transformed_gm = Transformer(nf).transform()
        self.assertEqual(nf(vals), transformed_gm(vals))

    def test_interpreter_with_codegen(self):
        class ListCodeGen(CodeGen):
            def gen_fn_def(self, free_vars, maybe_return_annotation):
                lst_unpack = f"""
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
    {', '.join(free_vars)} = args_list"""
                return lst_unpack

            def additional_globals(self):
                return [('List', typing.List)]

            def process_inputs(self, *inputs):
                assert len(inputs) == 1
                return inputs[0]

            def generate_output(self, output_args):
                return f'return list({repr(output_args)})'

            def process_outputs(self, outputs):
                return list(outputs)

        def f(a, b):
            a = a + b
            b = a + b
            return a, b

        nf = symbolic_trace(f)
        vals = [torch.randn(3), torch.randn(3)]
        nf.graph.set_codegen(ListCodeGen())
        nf.recompile()
        self.assertEqual(Interpreter(nf).run(vals), nf(vals))

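    # In-place operators such as `operator.imul` should be printed as
    # augmented assignments (`a *= b`) in the generated code.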
    def test_imul_code_print(self):
        graph = torch.fx.Graph()
        a = graph.placeholder("a")
        b = graph.placeholder("b")
        graph.call_function(operator.imul, (a, b), {})
        graph.output(a)
        gm = torch.fx.GraphModule({}, graph)
        gm.recompile()
        self.assertEqual(gm(2, 3), 6)
        self.assertIn("a *= b", gm.code)

    def test_deepcopy_tracer(self):
        def fn(x, y):
            return (x + y).relu().sin()

        tracer = Tracer()
        tracer_before = copy.deepcopy(tracer)
        tracer.trace(fn)
        tracer_after = copy.deepcopy(tracer)

        self.assertEqual(str(tracer.graph), str(tracer_after.graph))
        self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))

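# Helper that appears to be intended to run in a separate process (patching
# the wrapped-method list mutates process-global tracer state); it temporarily
# makes the tracer treat `torch.Tensor.__getitem__` as a wrapped method.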
def run_getitem_target():
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
    try:
        TestFX().getitem_inner()
    finally:
        _wrapped_methods_to_patch.pop()


class TestOperatorSignatures(JitTestCase):
    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
        if not isinstance(op.op, types.BuiltinFunctionType):
            raise unittest.SkipTest("This path doesn't work on Python functions")
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        schemas = get_signature_for_torch_op(op.op)
        if not schemas:
            raise RuntimeError('No Schemas Returned')
        for sample_input in sample_inputs_itr:
            # Iterate through overloads until we hit a match. If we exit this
            # loop via `else`, we haven't found a match
            for schema in schemas:
                try:
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
                    bound_args.apply_defaults()
                    op(*bound_args.args, **bound_args.kwargs)
                    break
                except TypeError:
                    pass
            else:
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')


class TestFXAPIBackwardCompatibility(JitTestCase):
    def setUp(self):
        super().setUp()
        self.maxDiff = None

        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag


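    # This produces strings like (roughly, illustrative only):
    #   "torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool"
    # which stay stable across Python versions, unlike repr(inspect.Signature).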
    def _fn_to_stable_annotation_str(self, obj):
        """
        Unfortunately we have to serialize function signatures manually since
        serialization for `inspect.Signature` objects is not stable across
        python versions
        """
        fn_name = torch.typename(obj)

        signature = inspect.signature(obj)

        sig_str = f'{fn_name}{signature}'

        arg_strs = []
        for k, v in signature.parameters.items():
            maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
                if v.annotation is not inspect.Signature.empty else ''

            def default_val_str(val):
                if isinstance(val, (tuple, list)):
                    str_pieces = ['(' if isinstance(val, tuple) else '[']
                    str_pieces.append(', '.join(default_val_str(v) for v in val))
                    if isinstance(val, tuple) and len(str_pieces) == 2:
                        str_pieces.append(',')
                    str_pieces.append(')' if isinstance(val, tuple) else ']')
                    return ''.join(str_pieces)

                # Need to fix up some default value strings.
                # First case: modules. Default module `repr` contains the FS path of the module.
                # Don't leak that
                if isinstance(val, types.ModuleType):
                    return f'<module {val.__name__}>'

                # Second case: callables. Callables (such as lambdas) encode their address in
                # their string repr. Don't do that
                if callable(val):
                    return f'<function {val.__name__}>'

                return str(val)

            if v.default is not inspect.Signature.empty:
                default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
                maybe_default = f' = {default_val_str}'
            else:
                maybe_default = ''
            maybe_stars = ''
            if v.kind == inspect.Parameter.VAR_POSITIONAL:
                maybe_stars = '*'
            elif v.kind == inspect.Parameter.VAR_KEYWORD:
                maybe_stars = '**'
            arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')

        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
            if signature.return_annotation is not inspect.Signature.empty else ''

        return f'{fn_name}({", ".join(arg_strs)}){return_annot}'

    def _annotation_type_to_stable_str(self, t, sig_str):
        if t is inspect.Signature.empty:
            return ''

        # Forward ref
        if isinstance(t, str):
            return f"'{t}'"
        if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
            return t.__forward_arg__
        if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
            return t.__forward_arg__

        trivial_mappings = {
            str : 'str',
            int : 'int',
            float: 'float',
            bool: 'bool',
            torch.dtype: 'torch.dtype',
            torch.Tensor: 'torch.Tensor',
            torch.device: 'torch.device',
            torch.memory_format: 'torch.memory_format',
            slice: 'slice',
            torch.nn.Module: 'torch.nn.modules.module.Module',
            torch.fx.Graph : 'torch.fx.graph.Graph',
            torch.fx.Node : 'torch.fx.node.Node',
            torch.fx.Proxy : 'torch.fx.proxy.Proxy',
            torch.fx.node.Target : 'torch.fx.node.Target',
            torch.fx.node.Argument : 'torch.fx.node.Argument',
            torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
            torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
            torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
            Ellipsis : '...',
            typing.Any: 'Any',
            type(None): 'NoneType',
            None: 'None',
            typing.Iterator: 'Iterator',
        }

        mapping = trivial_mappings.get(t, None)
        if mapping:
            return mapping

        # Handle types with contained types
        contained = getattr(t, '__args__', None) or []

        # Callables contain a bare List for arguments
        contained = t if isinstance(t, list) else contained

        # Python 3.8 puts type vars into __args__ for unbound types such as Dict
        if all(isinstance(ct, typing.TypeVar) for ct in contained):
            contained = []

        contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
        contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''


        origin = getattr(t, '__origin__', None)
        if origin is None:
            # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
            origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin

        if origin in {tuple, typing.Tuple}:
            return f'Tuple{contained_type_str}'
        if origin in {typing.Union}:
            # Annoying hack to detect Optional
            if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
                not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
                return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
            return f'Union{contained_type_str}'
        if origin in {dict, typing.Dict}:
            return f'Dict{contained_type_str}'
        if origin in {list, typing.List}:
            return f'List{contained_type_str}'
        if origin in {type, typing.Type}:
            return f'Type{contained_type_str}'
        if isinstance(t, typing.Callable):
            if len(contained) > 0 and contained[0] is not Ellipsis:
                return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
            else:
                return f'Callable{contained_type_str}'

        raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}. '
                           f'Please add support for this type and confirm with the '
                           f'FX team that your signature change is valid.')


    def test_function_back_compat(self):
        """
        Test backward compatibility for function signatures with
        @compatibility(is_backward_compatible=True). Currently this checks for
        exact signature matches, which may lead to false positives. If this
        becomes too annoying, we can refine this check to actually parse out
        the saved schema strings and check if the change is truly backward-
        incompatible.
        """
        signature_strs = []

        for obj in _BACK_COMPAT_OBJECTS:
            if not isinstance(obj, type):
                signature_strs.append(self._fn_to_stable_annotation_str(obj))

        signature_strs.sort()

        try:
            self.assertExpected('\n'.join(signature_strs), 'fx_backcompat_function_signatures')
        except AssertionError as e:
            msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
                  f"as backwards-compatible has experienced a signature change. See the " \
                  f"above exception context for more information. If this change was " \
                  f"unintended, please revert it. If it was intended, check with the FX " \
                  f"team to ensure that the proper deprecation protocols have been followed " \
                  f"and subsequently --accept the change."
            raise AssertionError(msg)

    def test_class_member_back_compat(self):
        """
        Test backward compatibility for members of classes with
        @compatibility(is_backward_compatible=True). Currently this checks for
        exact matches on the publicly visible members of the class.
        """
        class_method_strs = []

        for obj in _BACK_COMPAT_OBJECTS:
            if isinstance(obj, type):
                public_members = [name for name in obj.__dict__ if not name.startswith('_')]
                class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')

        class_method_strs.sort()

        try:
            self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
        except AssertionError as e:
            msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
                  f"as backwards-compatible has experienced change in its public members. See the " \
                  f"above exception context for more information. If this change was " \
                  f"unintended, please revert it. If it was intended, check with the FX " \
                  f"team to ensure that the proper deprecation protocols have been followed " \
                  f"and subsequently --accept the change."
            raise AssertionError(msg)

    def test_public_api_surface(self):
        non_back_compat_objects = {}

        def check_symbols_have_bc_designation(m, prefix):
            if not m.__name__.startswith('torch.fx'):
                return
            if m.__name__.startswith('torch.fx.experimental'):
                return
            for k, v in m.__dict__.items():
                if v is m:
                    continue
                if k.startswith('_'):
                    continue
                if isinstance(v, types.ModuleType):
                    check_symbols_have_bc_designation(v, prefix + [k])
                elif isinstance(v, type) or isinstance(v, types.FunctionType):
                    if v not in _MARKED_WITH_COMATIBLITY:
                        non_back_compat_objects.setdefault(v)

        check_symbols_have_bc_designation(torch.fx, ['torch', 'fx'])
        check_symbols_have_bc_designation(torch.fx.passes, ['torch', 'fx', 'passes'])

        non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
        # Only want objects in torch.fx
        non_back_compat_strs = [
            s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
        # Only want objects in public namespaces
        non_back_compat_strs = [
            s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
        non_back_compat_strs.sort()

        if len(non_back_compat_strs) != 0:
            raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
                                 f"backwards-compatibility classification! Please decorate these "
                                 f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
                                 f"BC guarantees.")

class TestFunctionalTracing(JitTestCase):
    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        super().tearDown()
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

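    # `generate_tests` (bottom of this class) creates one test per public
    # torch.nn.functional; the tables below record the (exception type,
    # message regex) each known-untraceable functional is expected to raise.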
    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
                    "has_torch_function_variadic", "handle_torch_function",
                    "boolean_dispatch")
    TO_PATCH = {"has_torch_function": None,
                "has_torch_function_unary": None,
                "has_torch_function_variadic": None}

    BUILT_IN_FUNC = (AssertionError, "")
    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")

    UNTRACEABLE_FUNCTIONALS = {
        "adaptive_avg_pool1d": BUILT_IN_FUNC,
        "avg_pool1d": BUILT_IN_FUNC,
        "avg_pool2d": BUILT_IN_FUNC,
        "avg_pool3d": BUILT_IN_FUNC,
        "bilinear": BUILT_IN_FUNC,
        "celu_": BUILT_IN_FUNC,
        "channel_shuffle": BUILT_IN_FUNC,
        "native_channel_shuffle": BUILT_IN_FUNC,
        "conv1d": BUILT_IN_FUNC,
        "conv2d": BUILT_IN_FUNC,
        "conv3d": BUILT_IN_FUNC,
        "conv_tbc": BUILT_IN_FUNC,
        "conv_transpose1d": BUILT_IN_FUNC,
        "conv_transpose2d": BUILT_IN_FUNC,
        "conv_transpose3d": BUILT_IN_FUNC,
        "cosine_similarity": BUILT_IN_FUNC,
        "elu_": BUILT_IN_FUNC,
        "gelu": BUILT_IN_FUNC,
        "hardshrink": BUILT_IN_FUNC,
        "hardtanh_": BUILT_IN_FUNC,
        "leaky_relu_": BUILT_IN_FUNC,
        "linear": BUILT_IN_FUNC,
        "logsigmoid": BUILT_IN_FUNC,
        "one_hot": BUILT_IN_FUNC,
        "pad": BUILT_IN_FUNC,
        "pairwise_distance": BUILT_IN_FUNC,
        "pdist": BUILT_IN_FUNC,
        "pixel_shuffle": BUILT_IN_FUNC,
        "pixel_unshuffle": BUILT_IN_FUNC,
        "prelu": BUILT_IN_FUNC,
        "relu_": BUILT_IN_FUNC,
        "rrelu_": BUILT_IN_FUNC,
        "selu_": BUILT_IN_FUNC,
        "softplus": BUILT_IN_FUNC,
        "softshrink": BUILT_IN_FUNC,
        "threshold_": BUILT_IN_FUNC,

        "adaptive_avg_pool2d": LEN_ERROR,
        "adaptive_avg_pool3d": LEN_ERROR,
        "adaptive_max_pool2d_with_indices": LEN_ERROR,
        "adaptive_max_pool3d_with_indices": LEN_ERROR,
        "instance_norm": CONTROL_FLOW,

        "adaptive_max_pool1d": PROXY_ITERABLE,
        "adaptive_max_pool2d": PROXY_ITERABLE,
        "adaptive_max_pool3d": PROXY_ITERABLE,
        "fractional_max_pool2d": PROXY_ITERABLE,
        "fractional_max_pool3d": PROXY_ITERABLE,
        "max_pool1d": PROXY_ITERABLE,
        "max_pool2d": PROXY_ITERABLE,
        "max_pool3d": PROXY_ITERABLE,

        "group_norm": PROXY_ITERATED,
        "lp_pool2d": PROXY_ITERATED,
        "max_unpool1d": PROXY_ITERATED,
        "max_unpool2d": PROXY_ITERATED,
        "max_unpool3d": PROXY_ITERATED,

        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "layer_norm": ARG_TYPE_MISMATCH,
        "lp_pool1d": ARG_TYPE_MISMATCH,

        "affine_grid": CONTROL_FLOW,
        "alpha_dropout": CONTROL_FLOW,
        "batch_norm": CONTROL_FLOW,
        "binary_cross_entropy": CONTROL_FLOW,
        "binary_cross_entropy_with_logits": CONTROL_FLOW,
        "celu": CONTROL_FLOW,
        "cosine_embedding_loss": CONTROL_FLOW,
        "cross_entropy": CONTROL_FLOW,
        "ctc_loss": CONTROL_FLOW,
        "dropout": CONTROL_FLOW,
        "dropout1d": CONTROL_FLOW,
        "dropout2d": CONTROL_FLOW,
        "dropout3d": CONTROL_FLOW,
        "elu": CONTROL_FLOW,
        "embedding": CONTROL_FLOW,
        "embedding_bag": CONTROL_FLOW,
        "feature_alpha_dropout": CONTROL_FLOW,
        "fold": CONTROL_FLOW,
        "gaussian_nll_loss": CONTROL_FLOW,
        "glu": CONTROL_FLOW,
        "grid_sample": CONTROL_FLOW,
        "gumbel_softmax": CONTROL_FLOW,
        "hardsigmoid": CONTROL_FLOW,
        "hardswish": CONTROL_FLOW,
        "hardtanh": CONTROL_FLOW,
        "hinge_embedding_loss": CONTROL_FLOW,
        "huber_loss": CONTROL_FLOW,
        "interpolate": CONTROL_FLOW,
        "kl_div": CONTROL_FLOW,
        "l1_loss": CONTROL_FLOW,
        "leaky_relu": CONTROL_FLOW,
        "local_response_norm": CONTROL_FLOW,
        "margin_ranking_loss": CONTROL_FLOW,
        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "mse_loss": CONTROL_FLOW,
        "multi_head_attention_forward": CONTROL_FLOW,
        "multi_margin_loss": CONTROL_FLOW,
        "multilabel_margin_loss": CONTROL_FLOW,
        "multilabel_soft_margin_loss": CONTROL_FLOW,
        "nll_loss": CONTROL_FLOW,
        "poisson_nll_loss": CONTROL_FLOW,
        "relu": CONTROL_FLOW,
        "relu6": CONTROL_FLOW,
        "rrelu": CONTROL_FLOW,
        "selu": CONTROL_FLOW,
        "silu": CONTROL_FLOW,
        "mish": CONTROL_FLOW,
        "smooth_l1_loss": CONTROL_FLOW,
        "soft_margin_loss": CONTROL_FLOW,
        "threshold": CONTROL_FLOW,
        "triplet_margin_loss": CONTROL_FLOW,
        "triplet_margin_with_distance_loss": CONTROL_FLOW,
        "unfold": CONTROL_FLOW,
        "upsample": CONTROL_FLOW,

        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
    }

    # List of nn.functionals with Tensor inputs but without type annotations
    FUNCTIONALS_WITHOUT_ANNOTATION = (
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "gaussian_nll_loss",
        "upsample",
        "upsample_bilinear",
        "upsample_nearest",
    )

    # Inconsistent behavior between Python 3.8 and other Python versions:
    # - Python 3.8+: Re-raise internal exceptions like `PROXY_ITERATED`
    # - Other Python: Raise `argument of type 'Proxy' is not iterable` due to the
    #   same internal exception above
    # Use the following map to override the expected exception for Python 3.8
    UNTRACEABLE_FUNCTIONALS_PY38 = {
        "adaptive_max_pool1d": PROXY_ITERATED,
        "adaptive_max_pool2d": PROXY_ITERATED,
        "adaptive_max_pool3d": PROXY_ITERATED,
        "fractional_max_pool2d": PROXY_ITERATED,
        "fractional_max_pool3d": PROXY_ITERATED,
        "max_pool1d": PROXY_ITERATED,
        "max_pool2d": PROXY_ITERATED,
        "max_pool3d": PROXY_ITERATED,

        "group_norm": LEN_ERROR
    }

    @classmethod
    def _get_functional(cls):
        functional_list = []
        for f in dir(torch.nn.functional):
            if not f.islower():
                continue
            # Ignore internal functions
            if f.startswith('_'):
                continue
            # Ignore supporting functions
            if f in cls.IGNORE_FUNCS:
                continue
            fn = getattr(torch.nn.functional, f)
            # Ignore non-callable objects like modules
            if not isinstance(fn, Callable):
                continue
            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
                try:
                    sig = inspect.signature(fn)
                    has_tensor_arg = False
                    for arg, param in sig.parameters.items():
                        if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
                            has_tensor_arg = True
                    if not has_tensor_arg:
                        continue
                # No signature or Object is not supported
                except ValueError:
                    pass
            functional_list.append((f, fn))
        return functional_list

    @classmethod
    def generate_test_func(cls, func_name, fn):

        def functional_test(self):
            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
                    sys.version_info >= (3, 8) and sys.version_info < (3, 11):
                exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
                exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
            else:
                symbolic_trace(fn)
        return functional_test

    @classmethod
    def generate_tests(cls):
        functional_list = cls._get_functional()
        for func_name, fn in functional_list:
            test_name = "test_nn_functional_" + func_name
            functional_test = cls.generate_test_func(func_name, fn)
            setattr(cls, test_name, functional_test)

    @classmethod
    def setUpClass(cls):

        def no(*args, **kwargs):
            return False

        for name in cls.TO_PATCH.keys():
            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
            setattr(torch.nn.functional, name, no)

    @classmethod
    def tearDownClass(cls):
        for name in cls.TO_PATCH.keys():
            setattr(torch.nn.functional, name, cls.TO_PATCH[name])

TestFunctionalTracing.generate_tests()


instantiate_device_type_tests(TestOperatorSignatures, globals())

@skipIfNoTorchVision
@skipIfSlowGradcheckEnv
Suraj Subramanian78022aa2021-04-22 08:52:45 -07004114class TestVisionTracing(JitTestCase):
James Reede1c3e5f2021-09-02 21:11:57 -07004115 def setUp(self):
Nikita Shulga80bf2ea2022-07-08 22:53:44 +00004116 # Checking for mutable operations while tracing is feature flagged
James Reede1c3e5f2021-09-02 21:11:57 -07004117 # Enable it in testing but not by default
4118 self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
4119 torch.fx.proxy.TracerBase.check_mutable_operations = True
4120
4121 def tearDown(self):
4122 torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
4123
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
    )

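    # The detection models below take a list of images and iterate over it
    # during preprocessing; a Proxy has no concrete length under symbolic
    # tracing, so FX raises the TraceError matched by PROXY_ITERATED.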
    UNTRACEABLE_MODELS = {
        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
        "fasterrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn_v2": PROXY_ITERATED,
        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn_v2": PROXY_ITERATED,
        "ssd300_vgg16": PROXY_ITERATED,
        "fcos_resnet50_fpn": PROXY_ITERATED,
        "ssdlite320_mobilenet_v3_large": PROXY_ITERATED,
    }
    UNSCRIPTABLE_MODELS = {
        "googlenet": INCONSISTENT_TYPE,
        "inception_v3": INCONSISTENT_TYPE,
    }

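    # output_transform normalizes heterogeneous model outputs to something
    # directly comparable: segmentation models return a dict whose "out" entry
    # holds the prediction, and the detection entries pick the second element
    # of the (losses, detections) pair those models return.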
    output_transform = {
        "fcn_resnet50": lambda x: x["out"],
        "fcn_resnet101": lambda x: x["out"],
        "deeplabv3_resnet50": lambda x: x["out"],
        "deeplabv3_resnet101": lambda x: x["out"],
        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
        "lraspp_mobilenet_v3_large": lambda x: x["out"],
        "fasterrcnn_resnet50_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
        "maskrcnn_resnet50_fpn": lambda x: x[1],
        "keypointrcnn_resnet50_fpn": lambda x: x[1],
        "retinanet_resnet50_fpn": lambda x: x[1],
    }

    @classmethod
    def generate_test_fn(cls, name, x, kwargs):
        def run_test(self):
            model = torchvision_models.get_model(name, **kwargs)
            model = model.eval()
            if name in self.UNTRACEABLE_MODELS:
                exc, err = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(exc, err):
                    graph = symbolic_trace(model)
            else:
                out_transform = self.output_transform.get(name, lambda x: x)
                graph: torch.fx.GraphModule = symbolic_trace(model)
                a = out_transform(model(x))
                b = out_transform(graph(x))
                self.assertEqual(a, b)

                if name in self.UNSCRIPTABLE_MODELS:
                    exc, err = self.UNSCRIPTABLE_MODELS[name]
                    with self.assertRaisesRegex(exc, err):
                        script = torch.jit.script(graph)
                else:
                    script = torch.jit.script(graph)
                    c = out_transform(script(x))
                    self.assertEqual(a, c)

        return run_test

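    # Where a model is both traceable and scriptable, run_test above checks
    # three-way parity: eager output, traced GraphModule output, and the
    # TorchScript output of the traced graph must all agree on the same input.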
    @classmethod
    def generate_classification_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models):
            test_name = 'test_torchvision_models_' + k
            x = torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224)
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

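    # inception_v3 is special-cased above because it requires 299x299 inputs;
    # the other classification models accept 224x224.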
    @classmethod
    def generate_segmentation_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.segmentation):
            test_name = 'test_torchvision_models_segmentation_' + k
            x = torch.rand(1, 3, 32, 32)
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_detection_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.detection):
            test_name = 'test_torchvision_models_detection_' + k
            x = [torch.rand(3, 300, 300)]
            kwargs = dict(num_classes=10, pretrained_backbone=False)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

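    # Detection models take a list of 3xHxW image tensors rather than a single
    # batched tensor, hence the single-element list input above.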
    @classmethod
    def generate_video_tests(cls):
        for k in torchvision_models.list_models(module=torchvision_models.video):
            test_name = 'test_torchvision_models_video_' + k
            x = (
                torch.rand(1, 3, 4, 112, 112)
                if k not in {"mvit_v1_b", "mvit_v2_s", "s3d"}
                else torch.rand(1, 3, 16, 224, 224)
            )
            kwargs = dict(num_classes=50)
            model_test = cls.generate_test_fn(k, x, kwargs)
            setattr(cls, test_name, model_test)

    @classmethod
    def generate_tests(cls):
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()

if HAS_TORCHVISION:
    TestVisionTracing.generate_tests()

if __name__ == '__main__':
    run_tests()