# blob: 5a47c729f7ebe9433c69b1be15fc864e32df1c90 [file] [log] [blame]
import contextlib
import copy
import functools
import numbers
import operator
import pickle
import sys
import unittest
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union

import torch
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph
from torch.fx.experimental import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.proxy import TraceError
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS
from torch.testing._internal.jit_utils import JitTestCase

from fx.quantization import Quantizer
# torchvision is optional: tests that need `resnet18` are skipped without it.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
class SimpleTest(torch.nn.Module):
    """Minimal module computing ``relu(x + 3.0)``; reused as a fixture by several tests."""

    def forward(self, x):
        shifted = x + 3.0
        return torch.relu(shifted)
def a_non_torch_leaf(a, b):
    """Return ``a + b`` from a plain Python function (not a torch op).

    Used as a ``call_function`` target in a hand-built FX graph, so the
    generated code has to reference a function outside the ``torch`` namespace.
    """
    result = a + b
    return result
class Pair(NamedTuple):
    # NamedTuple of two tensors, used as an annotated argument type (`p : Pair`)
    # in a module that is both TorchScript-scripted and FX-traced.
    x : torch.Tensor
    y : torch.Tensor
class TestFX(JitTestCase):
    """Tests for torch.fx: symbolic tracing, manual Graph construction and editing,
    GraphModule code generation, copying/pickling, and TorchScript interop."""

    def checkGraphModule(self, m: Union[torch.nn.Module, Callable[..., Any]], args, kwargs=None):
        """Check that an nn.Module's results match the GraphModule version
        for a given set of args/kwargs.

        ``m`` is traced with ``symbolic_trace``; the resulting graph is linted and
        eager vs. traced outputs are compared with ``assertEqual``. ``m`` may also
        be a plain function (see ``test_trace_function``).
        """
        kwargs = kwargs if kwargs else {}
        ref_outs = m(*args, **kwargs)
        gm = symbolic_trace(m)
        gm.graph.lint(gm)
        test_outs = gm(*args, **kwargs)
        self.assertEqual(ref_outs, test_outs)

    def test_graph_module(self):
        """Trace (and script) modules with params, submodules, tuple returns and *args/**kwargs."""
        class MySub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.w + x

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(4, 3)
                self.sub_mod = MySub()
                self.w = torch.nn.Parameter(torch.rand(3))

            def forward(self, A, B, c):
                t = torch.sigmoid(A) + self.lin(c)
                return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

        m = MyModule()
        gm = symbolic_trace(m)
        # Smoke test: scripting the traced module must not raise.
        torch.jit.script(gm)

        class M2(torch.nn.Module):
            def forward(self, A):
                m, idx = torch.max(A, 0)
                return m + 1, idx + 1

        # Smoke test: tracing a tuple-unpacking / tuple-returning forward must not raise.
        symbolic_trace(M2())

        class T(torch.nn.Module):
            def forward(self, A, b=4, *args, c=5, **kwargs):
                x = A + 1 + args[0] + kwargs['3']
                return x

        t = T()
        symbolic_trace(t)

    def test_custom_import(self):
        """A hand-built graph calling a non-torch Python function generates runnable code."""
        graph = torch.fx.Graph()
        a = graph.placeholder('x')
        b = graph.placeholder('y')
        c = graph.call_function(a_non_torch_leaf, (a, b))
        d = graph.call_function(torch.sin, (c,))
        graph.output(d)
        gm = GraphModule(torch.nn.Module(), graph)
        x, y = torch.rand(1), torch.rand(1)
        self.assertEqual(torch.sin(x + y), gm(x, y))

    def test_args_kwargs(self):
        """A forward taking only *args/**kwargs traces and matches eager output."""
        class T(torch.nn.Module):
            def forward(self, *args, **kwargs):
                x = args[0] + kwargs['foo']
                return x

        t = T()
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_args_kwargs_no_self(self):
        """A forward whose `self` arrives through *args is rejected with a clear error."""
        class T(torch.nn.Module):
            def forward(*args, **kwargs):  # noqa: B902
                self = args[0]
                return torch.relu(args[1])

        t = T()
        with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
            self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})

    def test_fx_shifts(self):
        """Bitwise shift operators trace correctly."""
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return x << 3, x >> 3

        input = torch.LongTensor(10).random_(0, 1024)

        m = MyModule()
        self.checkGraphModule(m, (input,))

    def test_dict(self):
        """Dict inputs can be indexed and dict outputs are produced correctly."""
        class MyDictMod(torch.nn.Module):
            def forward(self, d):
                return d['3'].relu(), {'4' : d['3'].neg()}

        input_dict = {'3': torch.rand(3, 4)}
        m = MyDictMod()

        self.checkGraphModule(m, (input_dict,))

    def test_disallow_override(self):
        """A Tracer overriding create_node can veto nodes (here: in-place ops)."""
        # Custom delegate to disallow in-place tensor operations
        class NoMutableCallTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                # Use a separate local so the caller-supplied node `name` is not clobbered.
                target_name = target if isinstance(target, str) else torch.typename(target)
                # A trailing underscore marks in-place ops (e.g. `add_`, `log_`).
                if target_name[-1] == '_':
                    raise RuntimeError('In-place operations are not supported')
                return super().create_node(kind, target, args, kwargs, name, type_expr)

        # Test method
        class MyInplaceMod(torch.nn.Module):
            def forward(self, x):
                x.add_(3.0)
                return x

        m = MyInplaceMod()

        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m)

        # Test free function
        class MyInplaceMod2(torch.nn.Module):
            def forward(self, x):
                torch.log_(x)
                return x

        m2 = MyInplaceMod2()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m2)

        # Test symbolic node as an arg
        class MyInplaceMod3(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(3, 4)
                y.add_(x)
                return x

        m3 = MyInplaceMod3()
        with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
            NoMutableCallTracer().trace(m3)

    def test_leaf_module(self):
        """A Tracer with no leaf modules traces through submodules (no call_module nodes)."""
        # Custom delegate to make it so that there are no leaf modules, everything
        # should get traced through
        class NoLeafModulesTracer(Tracer):
            def is_leaf_module(self, m, qualname):
                return False

        class MyReluMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        mrm = MyReluMod()
        sym = NoLeafModulesTracer().trace(mrm)
        for node in sym.nodes:
            self.assertNotEqual(node.op, 'call_module')
        sym.lint(sym)

    def test_graph_edit_with_proxy(self):
        """Proxies wrapping copied nodes can emit new nodes into the new graph."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        gm.graph.lint(gm)
        self.assertEqual(gm(3, 4), 14)

    def test_graph_unique_names(self):
        """Node names stay unique after graph_copy plus proxy-generated nodes."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = symbolic_trace(m).graph
        new_g = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        output_val = new_g.graph_copy(g, val_map)
        t = Proxy(output_val)
        # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
        new_g.output((t + t).node)
        gm = GraphModule(m, new_g)
        seen_names : Set[str] = set()
        for node in gm.graph.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    def test_graph_unique_names_manual(self):
        """graph_copy de-duplicates manually chosen names that could collide (foo_1 / foo_1_1)."""
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
        c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        graph2 = torch.fx.Graph()
        val_map : Dict[Node, Node] = {}
        graph2.graph_copy(graph, val_map)
        seen_names : Set[str] = set()
        for node in graph2.nodes:
            assert node.name not in seen_names
            seen_names.add(node.name)

    @skipIfNoTorchVision
    def test_resnet(self):
        """ResNet-18 traces, scripts and quantizes with outputs matching eager mode."""
        resnet = resnet18()
        resnet.train()

        res_graph = symbolic_trace(resnet)
        res_script = torch.jit.script(res_graph)

        ip = torch.rand(1, 3, 224, 224)

        a = resnet(ip)
        b = res_graph(ip)
        c = res_script(ip)
        self.assertEqual(a, b)
        self.assertEqual(a, c)

        quantizer = Quantizer(res_graph)

        for _ in range(10):
            quantizer.observe((torch.rand(1, 3, 224, 224),))

        qgraph = quantizer.quantize()
        qgraph.graph.lint(qgraph)
        qgraph_script = torch.jit.script(qgraph)

        d = qgraph(ip)
        e = qgraph_script(ip)

        assert (a - d).abs().max() < 2
        self.assertEqual(d, e)

    def test_unpack(self):
        """Tuple-unpacking of an input traces and matches eager output."""
        class M(torch.nn.Module):
            def forward(self, a, b):
                c, d = a
                return c + d + b

        a = (torch.rand(1), torch.rand(1))
        b = torch.rand(1)
        m = M()
        self.checkGraphModule(m, (a, b))

    def test_native_callable(self):
        """Lower FX IR to a TorchBind-backed native interpreter and wrap it in a GraphModule."""
        if TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        torch_root = Path(__file__).resolve().parent.parent
        p = torch_root / 'build' / 'lib' / 'libtorchbind_test.so'
        torch.ops.load_library(str(p))
        # This test exercises the case where we use FX to translate from Python
        # code to some native callable object
        #
        # For the purposes of testing, we use ElementwiseInterpreter defined
        # in test_custom_class.cpp.
        #
        # We test that we can
        # 1) Construct a native callable from FX IR
        # 2) Construct a drop-in replacement module that delegates to the
        #    native callable rather than the original code
        # 3) Run both the original code and native callable wrapper with
        #    equivalent results
        # 4) TorchScript compile the native callable wrapper and confirm
        #    equivalent results with the reference
        # 5) TorchScript serialize and deserialize the native callable
        #    and confirm equivalent results with the reference

        # We use this simple Module as a reference computation
        class MySimpleMod(torch.nn.Module):
            def forward(self, x):
                return 3.0 * x + x

        msm = MySimpleMod()

        # This is what a lowering pass might look like: a function that takes
        # a valid nn.Module, symbolically traces it, lowers the Module to some
        # representation, and wraps that representation up into another
        # nn.Module instance that handles dispatch to the compiled/lowered code.
        def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
            # ===== Stage 1: Symbolic trace the module =====
            mod = symbolic_trace(orig_mod)

            # ===== Stage 2: Lower GraphModule representation to the C++
            #       interpreter's instruction format ======
            instructions = []
            constant_idx = 0
            constants = {}
            fn_input_names = []

            target_to_name = {
                operator.add : "add",
                operator.mul : "mul"
            }

            output_node : Optional[Node] = None
            # For each instruction, create a triple
            # (instruction_name : str, inputs : List[str], output : str)
            # to feed into the C++ interpreter
            for n in mod.graph.nodes:
                target, args, out_name = n.target, n.args, n.name
                assert len(n.kwargs) == 0, "kwargs currently not supported"

                if n.op == 'placeholder':
                    # Placeholders specify function argument names. Save these
                    # for later when we generate the wrapper GraphModule
                    fn_input_names.append(target)
                elif n.op == 'call_function':
                    # `target` is a callable here; an f-string keeps the failure an
                    # AssertionError instead of a TypeError from `str + callable`.
                    assert target in target_to_name, f"Unsupported call target {target}"
                    arg_names = []
                    for arg in args:
                        if not isinstance(arg, Node):
                            # Pull out constants. These constants will later be
                            # fed to the interpreter C++ object via add_constant()
                            arg_name = f'constant_{constant_idx}'
                            constants[arg_name] = torch.Tensor(
                                [arg] if isinstance(arg, numbers.Number) else arg)
                            arg_names.append(arg_name)
                            constant_idx += 1
                        else:
                            arg_names.append(arg.name)
                    instructions.append((target_to_name[target], arg_names, out_name))
                elif n.op == 'output':
                    if output_node is not None:
                        raise RuntimeError('Multiple output nodes!')
                    output_node = n
                else:
                    raise RuntimeError('Unsupported opcode ' + n.op)

            interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
            # Load constants
            for k, v in constants.items():
                interpreter.add_constant(k, v)
            # Specify names for positional input arguments
            interpreter.set_input_names(fn_input_names)
            # Load instructions
            interpreter.set_instructions(instructions)
            # Specify name for single output
            assert output_node is not None, "traced graph has no output node"
            assert isinstance(output_node.args[0], torch.fx.Node)
            interpreter.set_output_name(output_node.args[0].name)

            # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
            class WrapperModule(torch.nn.Module):
                def __init__(self, interpreter):
                    super().__init__()
                    self.interpreter = interpreter

            wrapper = WrapperModule(interpreter)

            # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
            # 3) Returns the specified return value

            # FIXME: The following code could be greatly simplified by symbolic_trace'ing
            # the wrapper with a Tracer that considers the Wrapper instance a root
            # module, however, I can't get `__call__` exposed on TorchBind classes
            # without it messing up Python `hasattr` for some reason. More digging
            # into CPython's implementation of hasattr is probably in order...

            graph = torch.fx.Graph()
            # Add placeholders for fn inputs
            placeholder_nodes = []
            for name in fn_input_names:
                placeholder_nodes.append(graph.create_node('placeholder', name))

            # Get the interpreter object
            interpreter_node = graph.create_node('get_attr', 'interpreter')

            # Add a node to call the interpreter instance
            call_node = graph.create_node(
                op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

            # Register output
            graph.output(call_node)

            graph.lint(wrapper)

            # Return final GraphModule!!!
            return GraphModule(wrapper, graph)

        # Lower GraphModule to C++ interpreter
        lowered = lower_to_elementwise_interpreter(msm)

        # Compare correctness with original module
        x = torch.rand(3, 4)
        ref_out = msm(x)
        test_out = lowered(x)
        torch.testing.assert_allclose(test_out, ref_out)

        # Test TorchScript compilation
        scripted_lowered = torch.jit.script(lowered)
        script_out = scripted_lowered(x)
        torch.testing.assert_allclose(script_out, ref_out)

        # Test TorchScript ser/de
        import_copy = self.getExportImportCopy(scripted_lowered)
        imported_out = import_copy(x)
        torch.testing.assert_allclose(imported_out, ref_out)

    def test_reserved_getattr(self):
        """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
        class M(torch.nn.Module):
            def forward(self, a):
                return a.foo.bar.baz

        m = M()
        m_g = symbolic_trace(m)
        m_g.graph.lint(m_g)
        for node in m_g.graph.nodes:
            self.assertNotEqual(node.name, "getattr")

    def test_node_tagging(self):
        """A Tracer can attach custom attributes to every node it creates."""
        class TaggingTracer(Tracer):
            def create_node(self, kind : str, target : Union[str, Callable],
                            args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None,
                            type_expr : Optional[Any] = None) -> Node:
                # Forward every argument (including type_expr) to the default implementation.
                n = super().create_node(kind, target, args, kwargs, name, type_expr)
                n.tag = 'foo'
                return n

        class M(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = M()
        g = TaggingTracer().trace(m)
        g.lint(m)
        for n in g.nodes:
            self.assertTrue(hasattr(n, 'tag'))
            self.assertEqual(n.tag, 'foo')

    def test_tensor_attribute(self):
        """Plain (non-Parameter) tensor attributes, including nested ones, are traceable."""
        class TensorAttribute(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tensor = torch.rand(3, 4)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.tensor)

        ta = TensorAttribute()
        traced = symbolic_trace(ta)
        traced(torch.rand(4, 4))

        class WrapperForQualname(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.ta = TensorAttribute()

            def forward(self, x):
                return torch.nn.functional.linear(x, self.ta.tensor)

        wfq = WrapperForQualname()
        traced2 = symbolic_trace(wfq)
        traced2.graph.lint(traced2)
        traced2(torch.rand(4, 4))

    def test_symbolic_trace_sequential(self):
        """nn.Sequential of custom modules traces and matches eager output."""
        class Simple(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        seq = torch.nn.Sequential(
            Simple(),
            Simple(),
            Simple()
        )
        traced = symbolic_trace(seq)
        traced.graph.lint(traced)
        x = torch.rand(3, 4)
        self.assertEqual(traced(x), seq(x))

    def test_tensor_constant(self):
        """Tensors created inside forward become constants in the traced module."""
        class ConstTensor(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.linear(x, torch.zeros(3, 4))

        ct = ConstTensor()
        traced = symbolic_trace(ct)
        traced.graph.lint(traced)
        traced(torch.rand(4, 4))

    def test_pickle_graphmodule(self):
        """A traced GraphModule survives a pickle round trip."""
        class Nested(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.st = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.st(x)

        n = Nested()
        traced = symbolic_trace(n)
        traced.graph.lint(traced)
        pickled = pickle.dumps(traced)
        loaded = pickle.loads(pickled)
        loaded.graph.lint(loaded)
        x = torch.rand(3, 4)
        self.assertEqual(loaded(x), traced(x))

    def test_all_input_nodes(self):
        """Node.all_input_nodes lists the Node arguments of each node."""
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.placeholder('x')
        b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
        c : torch.fx.Node = graph.get_attr('y_attr')
        d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
        e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
        graph.output(e)
        graph.lint()

        self.assertEqual(b.all_input_nodes, [a])
        self.assertEqual(c.all_input_nodes, [])
        self.assertEqual(d.all_input_nodes, [b, c])
        self.assertEqual(e.all_input_nodes, [d])

    def test_deepcopy_graphmodule_with_transform(self):
        """Deepcopy of a transformed GraphModule gets its own class and same outputs."""
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint(traced)

        def transform(traced):
            new_graph = torch.fx.Graph()
            val_map : Dict[Node, Node] = {}
            output_value = new_graph.graph_copy(traced.graph, val_map)
            neg_out = new_graph.create_node(
                op='call_method', target='neg', args=(output_value,), kwargs={})
            new_graph.output(neg_out)
            return GraphModule(traced, new_graph)

        transformed = transform(traced)
        transformed.graph.lint(transformed)
        copied = copy.deepcopy(transformed)
        self.assertNotEqual(id(type(transformed)), id(type(copied)))
        x = torch.randn(3, 4)
        self.assertEqual(copied(x), transformed(x))

    def test_deepcopy_with_submods_params(self):
        """Deepcopy works for a traced module with nested submodules and parameters."""
        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))

            def forward(self, x):
                return torch.relu(x) + self.param

        class Baz(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.bar = Bar()

            def forward(self, x):
                return self.bar(x) - self.param

        baz = Baz()
        traced = symbolic_trace(baz)
        traced.graph.lint(traced)
        copied = copy.deepcopy(traced)
        copied.graph.lint(copied)

    def test_unpack_list_better_error(self):
        """`*x` on a traced input raises a descriptive TraceError."""
        class SomeArgs(torch.nn.Module):
            def forward(self, a, b):
                return torch.rand(3, 4)

        class UnpacksList(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sa = SomeArgs()

            def forward(self, x : list):
                return self.sa(*x)

        ul = UnpacksList()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ul)

    def test_unpack_dict_better_error(self):
        """`**x` on a traced input raises a descriptive TraceError."""
        class SomeKwargs(torch.nn.Module):
            def forward(self, x=3, y=4):
                return torch.rand(3, 4)

        class UnpacksDict(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sk = SomeKwargs()

            def forward(self, x : dict):
                return self.sk(**x)

        ud = UnpacksDict()
        with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
            symbolic_trace(ud)

    def test_script_tensor_constant(self):
        """Traced modules holding anonymous tensor constants remain scriptable."""
        # TorchScript seems to ignore attributes that start with `__`.
        # We used to call anonymous Tensor values `__tensor_constant*`, but
        # they were getting ignored by script. Now they're called
        # `_tensor_constant*`
        class IHaveATensorConstant(torch.nn.Module):
            def forward(self, x):
                return x + torch.rand(3, 4)

        traced = torch.fx.symbolic_trace(IHaveATensorConstant())
        torch.jit.script(traced)

    def test_torch_custom_ops(self):
        """Calls through torch.ops.aten trace and match eager output."""
        class M(torch.nn.Module):
            def forward(self, a):
                b = torch.ops.aten.sigmoid(a)
                c = torch.ops.aten.cat([a, b])
                return torch.ops.aten.cat((c, c))

        m = M()
        input = torch.randn(3)
        ref_out = m(input)
        gm = symbolic_trace(m)
        gm.graph.lint(gm)
        out = gm(input)
        self.assertEqual(out, ref_out)

    def test_pretty_print(self):
        """str(GraphModule) shows the module and the generated code."""
        st = SimpleTest()
        traced = symbolic_trace(st)
        traced.graph.lint(traced)
        printed = str(traced)
        assert 'GraphModuleImpl()' in printed
        assert 'torch.relu' in printed

    def test_pretty_print_graph(self):
        """str(Graph) includes args, kwargs and user counts."""
        class KwargPrintTest(torch.nn.Module):
            def forward(self, x):
                return torch.squeeze(x + 3.0, dim=2)

        st = KwargPrintTest()
        traced = symbolic_trace(st)
        traced.graph.lint(traced)
        stringed = str(traced.graph)
        for s in ['args', 'kwargs', '#users']:
            assert s in stringed

    def test_graph_fns(self):
        """Graph helper constructors (placeholder/call_module/get_attr/...) build a runnable graph."""
        g = Graph()
        a = g.placeholder('a')
        b = g.call_module('linear', (a,))
        c = g.get_attr('bias')
        d = g.call_method('add', (b, c))
        e = g.call_function(torch.sin, (d,))
        g.output(e)
        mod = torch.nn.Module()
        mod.linear = torch.nn.Linear(3, 4)
        mod.bias = torch.rand(4)
        gm = GraphModule(mod, g)
        gm.graph.lint(gm)
        input = torch.rand(3)
        r = gm(input)
        ref = torch.sin(mod.linear(input) + mod.bias)
        self.assertEqual(r, ref)

    def test_remove_uses(self):
        """Erasing a node removes it from its inputs' users."""
        g : torch.fx.Graph = Graph()
        x : torch.fx.Node = g.placeholder('x')
        relu : torch.fx.Node = g.call_function(torch.relu, (x,))
        neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
        g.output(neg)

        neg.replace_all_uses_with(relu)
        g.erase_node(neg)

        self.assertTrue(neg not in relu.users)

    def test_construct_root_dict(self):
        """GraphModule can be built from a dict mapping dotted qualnames to modules/tensors."""
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)

        linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
        add_param : torch.Tensor = torch.rand(3, 4)
        gm : torch.fx.GraphModule = torch.fx.GraphModule(
            {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
        gm.graph.lint(gm)

        assert 'self.foo.bar.baz' in gm.code

        x : torch.Tensor = torch.rand(3, 3)
        out : torch.Tensor = gm(x)
        ref_out : torch.Tensor = linear_mod(x) + add_param
        self.assertEqual(out, ref_out)

    def test_symbolic_trace_assert(self):
        """torch._assert is traceable and fires at runtime."""
        class AssertsTensorShape(torch.nn.Module):
            def forward(self, x):
                torch._assert(x.shape[1] > 4, "assert_foobar")
                return x

        m = AssertsTensorShape()
        # verify traceability
        traced = symbolic_trace(m)
        # verify assertion on traced model works correctly at runtime
        traced(torch.rand(4, 5))
        with self.assertRaisesRegex(AssertionError, "assert_foobar"):
            traced(torch.rand(4, 3))
        # verify the original (untraced) module is scriptable and its assertion fires.
        # NOTE(review): the intent may have been to script `traced`; confirm before switching.
        ms = torch.jit.script(m)
        with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
            ms(torch.rand(4, 3))

    def test_copy_no_remap(self):
        """node_copy without an arg remap leaves references to foreign nodes, which lint rejects."""
        traced = symbolic_trace(SimpleTest())
        g = traced.graph
        copied = torch.fx.Graph()
        for node in g.nodes:
            copied.node_copy(node)
        with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
            copied.lint()

    def test_wrong_topo(self):
        """lint detects a node used before its definition."""
        graph : torch.fx.Graph = torch.fx.Graph()
        a : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
        c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
        d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
        graph.output(d)
        nodes = list(graph.nodes)
        # Move `c` after `d`, which uses it, breaking topological order.
        nodes[3].append(nodes[2])
        with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
            graph.lint()

    def test_example_shape_prop(self):
        """ShapeProp records shapes on nodes for every opcode kind."""
        class TestCase(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.randn(3, 4)
                self.submod = torch.nn.Linear(4, 4)

            def forward(self, x):
                return torch.neg(self.submod(x.relu() + self.attr))

        tc = TestCase()
        tc_traced = symbolic_trace(tc)
        ref_out = tc_traced(torch.rand(3, 4))
        shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))

        # Make sure we're testing all opcodes
        opcodes = set()
        output_shape : Optional[torch.Size] = None
        for node in tc_traced.graph.nodes:
            opcodes.add(node.op)
            if node.op == 'output':
                output_shape = node.args[0].shape
        self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
                                   'call_module', 'output'})

        # Test shape propogation and make sure results match actual
        self.assertEqual(output_shape, ref_out.shape)

    def test_fn_type_annotations(self):
        """Argument/return annotations (NamedTuple, Dict) survive tracing so the result scripts."""
        class Foo(torch.nn.Module):
            def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
                return {'a': p.x + p.y + z + i}

        foo_scripted = torch.jit.script(Foo())
        foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

        fxed = symbolic_trace(Foo())
        fxed_scripted = torch.jit.script(fxed)
        fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)

    def test_fn_type_annotation_empty(self):
        """A traced free function with a List[Tensor] annotation is scriptable."""
        def forward(a : List[torch.Tensor]):
            return a[0]
        torch.jit.script(symbolic_trace(forward))

    def test_wrapped_method(self):
        """A forward decorated via functools.wraps is traced through the wrapper."""
        def wrap_with_relu(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                return torch.relu(fn(*args, **kwargs))
            return wrapper

        class Foo(torch.nn.Module):
            @wrap_with_relu
            def forward(self, x, w):
                return torch.matmul(x, w)

        f = Foo()
        traced = symbolic_trace(f)
        x, w = torch.rand(3, 4), torch.rand(4, 4)
        self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
        # The traced module must also compute the same result as eager mode.
        self.assertEqual(traced(x, w), f(x, w))

    def test_sequential(self):
        """A traced nn.Sequential can be deep-copied."""
        m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
        gm = torch.fx.symbolic_trace(m)
        # Smoke test: deepcopy must not raise.
        copy.deepcopy(gm)

    def test_ctx_mgr(self):
        """A forward decorated with a context manager traces and matches eager output."""
        @contextlib.contextmanager
        def do_nothing():
            yield

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            @do_nothing()
            def forward(self, x):
                return torch.relu(x)

        m = M()
        self.checkGraphModule(m, (torch.rand(3, 4),))

    def test_typename_print(self):
        """A node's type_expr is rendered in str(graph)."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
                                              type_expr=List[float])
        graph.output(b)
        self.assertTrue('typing.List[float]' in str(graph))

    def test_inf_nan(self):
        """inf/-inf/nan float constants survive code generation."""
        class FooMod(torch.nn.Module):
            def forward(self, x):
                return x + float('inf'), x + float('-inf'), x + float('nan')

        fm = FooMod()
        self.checkGraphModule(fm, (torch.rand(3, 4),))

    def test_inf_nan_kwds(self):
        """Nodes explicitly named 'inf'/'nan' do not clash with the float constants."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
        c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
        graph.output((b, c))

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)
        x = torch.rand(3, 4)
        self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))

    def test_deepcopy_recursion_depth(self):
        """Deepcopy of a graph longer than the recursion limit preserves user relationships."""
        depth = sys.getrecursionlimit() + 20

        g = torch.fx.Graph()
        x = g.placeholder('x')
        for _ in range(depth):
            x = g.call_function(torch.relu, (x,))
        g.output(x)

        copied_graph = copy.deepcopy(g)

        # Pair each original node with its positional counterpart in the copy.
        val_map = dict(zip(g.nodes, copied_graph.nodes))

        for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
            orig_users = set(orig_node.users.keys())
            orig_users_equiv = {val_map[u] for u in orig_users}
            new_users = set(new_node.users.keys())
            self.assertEqual(orig_users_equiv, new_users)

    @skipIfNoTorchVision
    def test_replace_uses(self):
        """Replace relu nodes in ResNet-18 with neg via inserting_before + replace_all_uses_with."""
        rn18 = resnet18()

        class LowerReluTracer(torch.fx.Tracer):
            def is_leaf_module(self, m : torch.nn.Module, qualname : str):
                if isinstance(m, torch.nn.ReLU):
                    return False
                return super().is_leaf_module(m, qualname)

        rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))

        to_erase = []
        for node in rn18_traced.graph.nodes:
            if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
                kwargs = node.kwargs.copy()
                # Neg doesn't have in-place
                kwargs.pop('inplace', None)
                with rn18_traced.graph.inserting_before(node):
                    # Pass the filtered kwargs; node.kwargs would still carry 'inplace'.
                    new_node = rn18_traced.graph.call_function(
                        the_function=torch.neg, args=node.args, kwargs=kwargs)
                node.replace_all_uses_with(replace_with=new_node)
                to_erase.append(node)

        for node in to_erase:
            rn18_traced.graph.erase_node(node)

        rn18_traced.graph.lint(rn18_traced)

    def test_insertion_point(self):
        """inserting_before places new nodes ahead of an existing node."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        graph.output(b)

        with graph.inserting_before(b):
            neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
            _, *relu_args = b.args
            b.args = (neg, *relu_args)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

    def test_move_before(self):
        """Node.prepend moves an already-created node before another."""
        graph : torch.fx.Graph = torch.fx.Graph()
        x : torch.fx.Node = graph.create_node('placeholder', 'x')
        b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
        graph.output(b)

        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
        _, *relu_args = b.args
        b.args = (neg, *relu_args)
        b.prepend(neg)

        gm = torch.fx.GraphModule(torch.nn.Module(), graph)

        input = torch.randn(33, 44)
        self.assertEqual(gm(input), torch.relu(torch.neg(input)))

    def test_erase_node_error(self):
        """Erasing a node that still has users raises."""
        st = SimpleTest()
        traced = symbolic_trace(st)

        for node in traced.graph.nodes:
            # Test deleting with uses both in another Node and at the output
            if node.target in [operator.add, torch.relu]:
                with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
                    traced.graph.erase_node(node)

    def test_copy_it(self):
        """immutable_dict / immutable_list support deepcopy."""
        d = immutable_dict([(3, 4), (5, 6)])
        lst = immutable_list([(3, 4), (5, 6)])

        self.assertEqual(d, deepcopy(d))
        self.assertEqual(lst, deepcopy(lst))

    def test_find_uses(self):
        """Node.users lists every node consuming a value."""
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))

        y = torch.relu(x)
        z = x + x
        u = torch.neg(x)
        graph.output((y + z + u).node)
        graph.lint()

        users_of_x = x.node.users
        self.assertEqual(len(users_of_x), 3)
        expected_ops = {'relu', 'add', 'neg'}
        for use in users_of_x:
            assert any(use.name.startswith(prefix) for prefix in expected_ops)

    def test_inline_graph(self):
        """graph_copy with a seeded val_map inlines one graph's output into another's input."""
        class InlineInto(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        class ToInline(torch.nn.Module):
            def forward(self, x):
                return torch.neg(x)

        inline_into = symbolic_trace(InlineInto())
        to_inline = symbolic_trace(ToInline())

        combined_graph = torch.fx.Graph()
        output_node = combined_graph.graph_copy(inline_into.graph, {})

        input_node = list(to_inline.graph.nodes)[0]
        assert input_node and input_node.op == 'placeholder'

        val_map = {input_node : output_node}
        output = combined_graph.graph_copy(to_inline.graph, val_map)
        combined_graph.output(output)

        combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)

        input = torch.rand(3, 4)
        self.assertEqual(combined_module(input), input.relu().neg())

    def test_multi_insert_point(self):
        """All nodes created inside inserting_before land before the anchor, in order."""
        graph = torch.fx.Graph()
        x = torch.fx.Proxy(graph.placeholder('x'))
        relu = torch.relu(x)

        with graph.inserting_before(relu.node):
            y = torch.neg(x)
            z = torch.tanh(y)

        graph.output((relu.node, z.node))
        graph.lint()

        expected_ops = ['x', 'neg', 'tanh', 'relu']
        for node, expected in zip(graph.nodes, expected_ops):
            assert expected in node.name

    def test_reassign_args_kwargs_uses(self):
        """Reassigning Node.args updates the users of old and new inputs."""
        graph = torch.fx.Graph()
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
        z = x + y
        zed = z + z + z
        graph.output(zed.node)
        graph.lint()

        # zed = z + z + z -> zed = z + z + x
        zed.node.args = (zed.node.args[0], x.node)
        self.assertEqual(list(x.node.users), [z.node, zed.node])

        # z = x + y -> z = y + y
        z.node.args = (y.node, y.node)
        self.assertEqual(list(x.node.users), [zed.node])

    def test_trace_function(self):
        """symbolic_trace accepts a plain function."""
        def foo(x, y):
            return torch.relu(x) + y

        x, y = torch.randn(3, 4), torch.randn(3, 4)
        self.checkGraphModule(foo, (x, y))

    def test_direct_param_use(self):
        """Direct access to a submodule parameter is a get_attr, not a copied constant."""
        class TransposeTest(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = torch.nn.Parameter(torch.rand(4, 3))

            def forward(self, x):
                return self.b

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = TransposeTest()

            def forward(self, x):
                return self.a.b, self.a.b.t(), self.a.b.view(12)

        traced = torch.fx.symbolic_trace(Foo())
        self.assertTrue(all('constant' not in node.target for node in traced.graph.nodes))

    def test_single_default_arg(self):
        """A forward with one defaulted argument works with and without the argument."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, y=1):
                return y

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))

    def test_multiple_default_args(self):
        """A forward with several defaulted arguments works for each arity."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, y=1, z=2):
                return y + z

        m = M()
        self.checkGraphModule(m, ())
        self.checkGraphModule(m, (3,))
        self.checkGraphModule(m, (3, 4))

    def test_regular_and_default_args(self):
        """Mixed required and defaulted arguments are handled."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y=1):
                return x + y

        m = M()
        self.checkGraphModule(m, (2,))
        self.checkGraphModule(m, (2, 3))

    def test_string_literal_return(self):
        """A forward returning a string literal traces and matches eager output."""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self):
                return "foo"

        m = M()
        self.checkGraphModule(m, ())
if __name__ == '__main__':
    # When executed as a script, hand off to `run_tests` from
    # torch.testing._internal.common_utils.
    run_tests()