Fix splitter_base and add unit test for trt splitter (#67569)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67569
Splitter_base assumes that the first subgraph after the split must be a CPU subgraph whenever there are CPU nodes. This is wrong: the starting subgraph should be determined by which subgraph contains a 0-dep node (a node with no unmet dependencies).
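Concretely, the fix below starts the split on the CPU side only when some CPU starter node already has all of its dependencies satisfied; since visited_nodes is still empty at that point, this simply means the node has no dependencies at all. A roughly equivalent one-line sketch of the new decision (not the verbatim diff):

    acc_subgraph = not any(self.deps[n] <= visited_nodes for n in current_cpu_nodes)

The new test_start_with_acc_module_ test covers a graph whose split is expected to begin with _run_on_acc_0.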
Also add unit tests for the splitter.
Reviewed By: yinghai
Differential Revision: D32012549
fbshipit-source-id: e2639ccd7774b4295ca05c2ddbefff9726702b3f
diff --git a/test/fx/passes/trt_splitter_test.py b/test/fx/passes/trt_splitter_test.py
new file mode 100644
index 0000000..dc42756
--- /dev/null
+++ b/test/fx/passes/trt_splitter_test.py
@@ -0,0 +1,1147 @@
+import operator
+import unittest
+
+import torch # isort:skip
+import torch.fx # isort:skip
+
+import torch.fx.experimental.fx_acc.acc_ops as acc_ops
+import torch.fx.passes.operator_support as op_support
+import torch.fx.passes.shape_prop as shape_prop
+from torch.fx.experimental.fx2trt.tools.trt_splitter import TRTSplitter
+from torch.fx.passes import splitter_base
+from torch.fx.experimental.fx_acc import acc_tracer
+
+
+ERROR_MSG_NO_ACC_MODULE = "FX split failed: Did not find any ACC submodule!"
+ERROR_MSG_MULTI_ACC_MODULES = "FX split failed: Found more than one ACC submodules!"
+ACC_SUBMODULE_PREFIX = "_run_on_acc_"
+
+# Check if the split result has the expected number of ACC submodules. If not, raise a RuntimeError.
+def verify_split_model(
+ mod: torch.fx.GraphModule, acc_submodule_keyword: str = ACC_SUBMODULE_PREFIX, expected_number: int = 1,
+) -> None:
+ acc_submodule_num = 0
+ for name, _ in mod.named_children():
+ if name.startswith(acc_submodule_keyword):
+ acc_submodule_num = acc_submodule_num + 1
+
+ if acc_submodule_num < expected_number:
+ raise RuntimeError(ERROR_MSG_NO_ACC_MODULE)
+ elif acc_submodule_num > expected_number:
+ raise RuntimeError(ERROR_MSG_MULTI_ACC_MODULES)
+
+def find_inputs(module):
+ return [n for n in module.graph.nodes if n.op == "placeholder"]
+
+
+def find_fun_calls(module, target):
+ return [
+ n for n in module.graph.nodes if n.op == "call_function" and n.target == target
+ ]
+
+
+def find_output(module):
+ return next(n for n in module.graph.nodes if n.op == "output")
+
+
+TENSOR_SIZE_DUMMY = "tensor_size_dummy"
+
+
+def find_call_targets(module: torch.fx.GraphModule):
+ result = set()
+ for n in module.graph.nodes:
+ n: torch.fx.Node
+ if n.op in {"call_module", "call_function", "call_method"}:
+ result.add(n.target)
+ return result
+
+
+# We test both FxNetSplitOnly and FxNetSplitter here, since they share most of
+# their functionality. The only difference is that FxNetSplitOnly does not
+# implement the split_preview() related functions, while FxNetSplitter does.
+class TestSplit(unittest.TestCase):
+ def test_demo(self):
+ """
+ ==> b ==>
+ // \\
+ a d
+ \\ //
+ ==> c ==>
+ """
+
+ class SimpleModule(torch.nn.Module):
+ def forward(self, a):
+ b = torch.sin(a)
+ c = torch.cos(a)
+ d = b + c
+ return d
+
+ mod = acc_tracer.trace(SimpleModule(), torch.randn(2, 3))
+
+ # Making b and c run on ACC
+ splitter = TRTSplitter(
+ mod,
+ (torch.randn(2, 3),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.sin": None,
+ "acc_ops.cos": None,
+ }
+ ),
+ )
+
+ st_split = splitter()
+
+ [arg] = find_inputs(st_split)
+
+ # First subgraph calculates b = sin(a) and c = cos(a) on ACC
+ [sin] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sin)
+ self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+ [cos] = find_fun_calls(st_split._run_on_acc_0, acc_ops.cos)
+ self.assertEqual(arg.name, cos.kwargs["input"].name)
+
+ # Second subgraph calculates d = b + c on CPU
+ [add] = find_fun_calls(st_split._run_on_cpu_1, acc_ops.add)
+ self.assertEqual(sin.name, add.kwargs["input"].name)
+ self.assertEqual(cos.name, add.kwargs["other"].name)
+
+ def test_mod_with_getattr(self):
+ """
+ CPU subgraph should have get_attr for self.a while ACC subgraph
+ should have get_attr for self.b.
+ """
+
+ class SimpleModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.a = torch.randn(1, 1, 1, 1)
+ self.b = torch.randn(1, 1, 1, 1)
+ self.conv = torch.nn.Conv2d(1, 1, 1)
+ self.linear = torch.nn.Linear(1, 1)
+
+ def forward(self, x):
+ x = x + self.a
+ x = self.conv(x)
+ return self.linear(x - self.b)
+
+ mod = acc_tracer.trace(SimpleModule(), torch.randn(1, 1, 1, 1))
+ mod.eval()
+
+ splitter = TRTSplitter(
+ mod,
+ (torch.randn(1, 1, 1, 1),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.linear": None,
+ "acc_ops.sub": None,
+ }
+ ),
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ verify_split_model(st_split)
+ # Should be "a", "conv.weight", "conv.bias".
+ get_attr_nodes = [
+ node.target
+ for node in st_split._run_on_cpu_0.graph.nodes
+ if node.op == "get_attr"
+ ]
+ assert len(get_attr_nodes) == 3 and "a" in get_attr_nodes
+
+ # Should be "b", "conv.weight", "conv.bias".
+ get_attr_nodes = [
+ node.target
+ for node in st_split._run_on_acc_1.graph.nodes
+ if node.op == "get_attr"
+ ]
+ assert len(get_attr_nodes) == 3 and "b" in get_attr_nodes
+
+ test_splitter(splitter)
+
+ def test_nothing_to_split(self):
+ class SimpleModule(torch.nn.Module):
+ def forward(self, a):
+ return a
+
+ mod = acc_tracer.trace(SimpleModule(), torch.randn(2, 3))
+
+ # Mark any operation as runnable on ACC
+ class CustomOpSupport(op_support.OperatorSupportBase):
+ def is_node_supported(self, submodules, node):
+ return True
+
+ splitter = TRTSplitter(
+ mod, (torch.randn(2, 3),), CustomOpSupport()
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ try:
+ verify_split_model(st_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_NO_ACC_MODULE
+ )
+ self.assertEqual(splitter.module.__dict__.keys(), st_split.__dict__.keys())
+
+ test_splitter(splitter)
+
+ def test_multi_output(self):
+ class MultiOutputModule(torch.nn.Module):
+ def forward(self, x):
+ res, ind = torch.topk(x, 3)
+ return torch.sigmoid(res), ind
+
+ mod = acc_tracer.trace(MultiOutputModule(), torch.randn(2, 3))
+
+ # Mark any operation as runnable on ACC
+ class CustomOpSupport(op_support.OperatorSupportBase):
+ def is_node_supported(self, submodules, node):
+ return True
+
+ splitter = TRTSplitter(
+ mod, (torch.randn(2, 3),), CustomOpSupport()
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ verify_split_model(st_split)
+ [arg] = find_inputs(st_split)
+
+ # There is only one subgraph that executes topk and sigmoid on ACC
+ [topk] = find_fun_calls(st_split._run_on_acc_0, acc_ops.topk)
+ self.assertEqual(arg.name, topk.kwargs["input"].name)
+ self.assertEqual(3, topk.kwargs["k"])
+
+ [topk_res1, topk_res2] = find_fun_calls(
+ st_split._run_on_acc_0, acc_ops.getitem
+ )
+
+ [sigmoid] = find_fun_calls(st_split._run_on_acc_0, acc_ops.sigmoid)
+ self.assertIn(
+ sigmoid.kwargs["input"].name, {topk_res1.name, topk_res2.name}
+ )
+
+ # Main graph returns a tuple
+ output = find_output(st_split._run_on_acc_0)
+ self.assertLess(
+ {output.args[0][0].name, output.args[0][1].name},
+ {topk_res1.name, topk_res2.name, sigmoid.name},
+ )
+
+ test_splitter(splitter)
+
+ def test_nested_modules(self):
+ """
+ x
+ // \\
+ // \\
+ relu(x) sin(x)
+ \\ //
+ \\ //
+ relu(x) + sin(x)
+ """
+
+ class ReluModule(torch.nn.Module):
+ def forward(self, x):
+ return torch.relu(x)
+
+ class SinModule(torch.nn.Module):
+ def forward(self, x):
+ return torch.sin(x)
+
+ class TestModule3(torch.nn.Module):
+ def __init__(self, relu_module, sin_module):
+ super().__init__()
+ self.relu_module = relu_module
+ self.sin_module = sin_module
+
+ def forward(self, x):
+ return self.relu_module(x) + self.sin_module(x)
+
+ mod = acc_tracer.trace(TestModule3(ReluModule(), SinModule()), torch.randn(2, 3))
+
+ # Making sin(x) run on ACC
+ splitter = TRTSplitter(
+ mod,
+ (torch.randn(2, 3),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.sin": None,
+ }
+ ),
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ verify_split_model(st_split)
+ [arg] = find_inputs(st_split)
+
+ # First subgraph calculates relu(x) on CPU
+ [relu] = find_fun_calls(st_split._run_on_cpu_0, acc_ops.relu)
+ self.assertEqual(arg.name, relu.kwargs["input"].name)
+
+ # Second subgraph calculates sin(x) on ACC
+ [sin] = find_fun_calls(st_split._run_on_acc_1, acc_ops.sin)
+ self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+ # Third subgraph calculates sum on CPU
+ [add] = find_fun_calls(st_split._run_on_cpu_2, acc_ops.add)
+ self.assertEqual(relu.name, add.kwargs["input"].name)
+ self.assertEqual(sin.name, add.kwargs["other"].name)
+
+ # Checking that results of applying split module will be the same
+ tensor = torch.randn(5)
+ self.assertTrue(torch.equal(mod(tensor), st_split(tensor)))
+
+ test_splitter(splitter)
+
+ def test_longer_chain(self):
+ """
+ sin relu cos sigmoid tanh
+ a ====> b =====> c ====> d ========> e =====> f
+ """
+
+ class TestModule(torch.nn.Module):
+ def forward(self, a):
+ b = torch.sin(a)
+ c = torch.relu(b)
+ d = torch.cos(c)
+ e = torch.sigmoid(d)
+ f = torch.tanh(e)
+ return f
+
+ mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))
+
+ # Making relu and sigmoid execute on ACC
+ splitter = TRTSplitter(
+ mod,
+ (torch.randn(2, 3),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.relu": None,
+ "acc_ops.sigmoid": None,
+ }
+ ),
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ try:
+ verify_split_model(st_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_MULTI_ACC_MODULES
+ )
+ [arg] = find_inputs(st_split)
+
+ # First subgraph calculates b = sin(a) on CPU
+ [sin] = find_fun_calls(st_split._run_on_cpu_0, acc_ops.sin)
+ self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+ # Second subgraph calculates c = relu(b) on ACC
+ [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu)
+ self.assertEqual(sin.name, relu.kwargs["input"].name)
+
+ # Third subgraph calculates d = cos(c) on CPU
+ [cos] = find_fun_calls(st_split._run_on_cpu_2, acc_ops.cos)
+ self.assertEqual(relu.name, cos.kwargs["input"].name)
+
+ # Fourth subgraph calculates e = sigmoid(d) on ACC
+ [sigmoid] = find_fun_calls(st_split._run_on_acc_3, acc_ops.sigmoid)
+ self.assertEqual(cos.name, sigmoid.kwargs["input"].name)
+
+ # Fifth subgraph calculates f = tanh(e) on CPU
+ [tanh] = find_fun_calls(st_split._run_on_cpu_4, acc_ops.tanh)
+ self.assertEqual(sigmoid.name, tanh.kwargs["input"].name)
+
+ test_splitter(splitter)
+
+ def test_min_acc_module_size(self):
+ """
+ sin relu cos sigmoid tanh
+ a ====> b =====> c ====> d ========> e =====> f
+
+ We set sin, cos and tanh as acc nodes but also set min_acc_module_size to 2,
+ and expect the whole module to stay on CPU.
+ """
+
+ class TestModule(torch.nn.Module):
+ def forward(self, a):
+ b = torch.sin(a)
+ c = torch.relu(b)
+ d = torch.cos(c)
+ e = torch.sigmoid(d)
+ f = torch.tanh(e)
+ return f
+
+ mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))
+
+ # Set sin, cos and tanh as acc nodes and split with settings
+ class CustomOpSupport(op_support.OperatorSupport):
+ _support_dict = {
+ "acc_ops.sin": None,
+ "acc_ops.cos": None,
+ "acc_ops.tanh": None,
+ }
+
+ # Create splitter setting and set min_acc_module_size to 2
+ settings = splitter_base._SplitterSettingBase()
+ settings.min_acc_module_size = 2
+ splitter = TRTSplitter(
+ mod,
+ (torch.randn(2, 3),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.sin": None,
+ "acc_ops.cos": None,
+ "acc_ops.tanh": None,
+ }
+ ),
+ settings,
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ try:
+ verify_split_model(st_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_NO_ACC_MODULE
+ )
+ modules = list(st_split.named_modules())
+ # Main module and a submodule
+ assert len(modules) == 2
+
+ assert modules[1][0] == "_run_on_cpu_0"
+
+ test_splitter(splitter)
+
+ def test_extend_acc_subgraph_after_split(self):
+ class TestModule(torch.nn.Module):
+ r""" a (input)
+ |
+ b
+ / \
+ c d
+ \ /
+ e
+ / \
+ | (g1, g2, g3, g4)
+ \ / |
+ f |
+ \ |
+ h
+
+ c and f are not runnable on acc while all other nodes are supported by acc.
+ g1, g2, g3 and g4 should be in a fusion group, let's call it g.
+
+ After the split we have 2 cpu subgraphs, (c) and (f), and 3 acc subgraphs, (b, d), (e, g) and (h).
+ We expect 3 acc subgraphs, (b), (d, e, g) and (h), after extending the second acc subgraph,
+ and expect the acc subgraphs to stay the same after extending the third acc subgraph because
+ of the unbreakable fusion group.
+ """
+
+ def forward(self, a: torch.Tensor):
+ b = a + a
+ c = b - b
+ d = b + b
+ e = c + d
+
+ # These four nodes should be in a fusion group
+ g1 = e.size()
+ g2 = g1[0]
+ g3 = e + g2
+ g4 = g3 + g2
+
+ f = e - g3
+ h = f + g4
+ return h
+
+ a = torch.randn(2)
+ mod = acc_tracer.trace(TestModule(), (a,))
+
+ # Allow all nodes except subtract to run on the accelerator
+ class CustomOpSupport(op_support.OperatorSupportBase):
+ def is_node_supported(self, submodules, node):
+ return op_support.get_node_target(submodules, node) != "acc_ops.sub"
+
+ splitter = TRTSplitter(mod, (a,), CustomOpSupport())
+
+ def test_splitter(splitter):
+ # Manually tag nodes first in case the split algorithm changes in the future
+ nodes = list(splitter.module.graph.nodes)
+ # b and d
+ nodes[1].tag = "acc_0"
+ nodes[3].tag = "acc_0"
+ # c
+ nodes[2].tag = "cpu_1"
+ # e and g
+ nodes[4].tag = "acc_2"
+ nodes[5].tag = "acc_2"
+ nodes[6].tag = "acc_2"
+ nodes[7].tag = "acc_2"
+ nodes[8].tag = "acc_2"
+ # f
+ nodes[9].tag = "cpu_3"
+ # h
+ nodes[10].tag = "acc_4"
+
+ splitter.tags = ["acc_0", "cpu_1", "acc_2", "cpu_3", "acc_4"]
+ split_module = splitter.split()
+ try:
+ verify_split_model(split_module, "acc_")
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_MULTI_ACC_MODULES
+ )
+ try:
+ verify_split_model(split_module)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_NO_ACC_MODULE
+ )
+
+ module_names = [name for name, _ in split_module.named_modules()]
+ # Main module, 2 cpu submodules and 3 acc submodules
+ assert len(module_names) == 6
+
+ # 1 Placeholder, 2 Adds and 1 Output
+ assert len(split_module.acc_0.graph.nodes) == 4
+ # 2 Placeholders, 3 Adds, 1 Size, 1 GetItem and 1 Output
+ assert len(split_module.acc_2.graph.nodes) == 8
+
+ # Extend the second acc subgraph
+ splitter.extend_acc_subgraph("acc_2")
+ extend_module = splitter.split()
+ try:
+ verify_split_model(extend_module, "acc_")
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_MULTI_ACC_MODULES
+ )
+
+ # 1 Placeholder, 1 Add and 1 Output
+ assert len(extend_module.acc_0.graph.nodes) == 3
+ # 2 Placeholders, 4 Adds, 1 Size, 1 GetItem and 1 Output
+ assert len(extend_module.acc_2.graph.nodes) == 9
+
+ # Extend the third acc subgraph
+ splitter.extend_acc_subgraph("acc_4")
+ extend_module = splitter.split()
+ try:
+ verify_split_model(extend_module, "acc_")
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_MULTI_ACC_MODULES
+ )
+
+ assert len(extend_module.acc_2.graph.nodes) == 9
+ # 2 Placeholders, 1 Add and 1 Output
+ assert len(extend_module.acc_4.graph.nodes) == 4
+
+ test_splitter(splitter)
+
+ def test_get_attr_into_output(self):
+ """
+ Here we verify the case when a get_attr node is consumed directly by the
+ output. We don't expect any split to happen in this test, just want to
+ make sure that the splitter code doesn't break.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.a = torch.randn(2, 3)
+
+ def forward(self, x):
+ return (x, self.a)
+
+ # No need to put anything on ACC.
+ class TestOperatorSupport:
+ def is_node_supported(self, submodules, node):
+ return False
+
+ module_original = acc_tracer.trace(TestModule(), torch.randn(4, 5))
+
+ splitter = TRTSplitter(
+ module=module_original,
+ sample_input=torch.randn(4, 5),
+ operator_support=TestOperatorSupport(),
+ )
+
+ def test_splitter(splitter):
+ module_split = splitter()
+ try:
+ verify_split_model(module_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_NO_ACC_MODULE
+ )
+
+ output = find_output(module_split)
+ # Second argument of the output should be get_attr.
+ self.assertEqual("get_attr", output.args[0][1].op)
+
+ # Check if modules are equivalent.
+ tensor = torch.randn(10, 20)
+ result_original = module_original(tensor)
+ result_split = module_split(tensor)
+ self.assertTrue(torch.equal(result_original[0], result_split[0]))
+ self.assertTrue(torch.equal(result_original[1], result_split[1]))
+
+ test_splitter(splitter)
+
+ def test_get_attr_into_starter_node(self):
+ """
+ Here we verify the case when starter nodes depend on get_attr node only.
+ We don't expect any split to happen in this test, just want to make sure
+ that the splitter code doesn't break.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.a = torch.randn(2, 3)
+
+ def forward(self):
+ m = self.a + self.a
+ o = m + m
+ return o
+
+ # No need to put anything on ACC.
+ class TestOperatorSupport:
+ def is_node_supported(self, submodules, node):
+ return False
+
+ module_original = acc_tracer.trace(TestModule(), torch.randn(2, 3))
+
+ splitter = TRTSplitter(
+ module=module_original,
+ sample_input=torch.randn(2, 3),
+ operator_support=TestOperatorSupport(),
+ )
+
+ def test_splitter(splitter):
+ module_split = splitter()
+ try:
+ verify_split_model(module_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_NO_ACC_MODULE
+ )
+
+ # Check if modules are equivalent.
+ result_original = module_original()
+ result_split = module_split()
+ self.assertTrue(torch.equal(result_original, result_split))
+
+ test_splitter(splitter)
+
+
+class TestSplitComplexGraph(unittest.TestCase):
+ """
+ a ======
+ // \\ \\
+ b c d
+ \\ // //
+ e //
+ \\ //
+ \\ //
+ f
+ """
+
+ class TestModule(torch.nn.Module):
+ def forward(self, a):
+ b = torch.sin(a)
+ c = torch.relu(a)
+ d = torch.cos(a)
+ e = b + c
+ f = e - d
+ return f
+
+ def test_split_complex_graph_1(self):
+ mod = acc_tracer.trace(self.TestModule(), torch.randn(2, 3))
+
+ # Making 'c' and 'd' run on ACC
+ splitter = TRTSplitter(
+ mod,
+ (torch.randn(2, 3),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.cos": None,
+ "acc_ops.relu": None,
+ }
+ ),
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ verify_split_model(st_split)
+
+ [arg] = find_inputs(st_split)
+
+ # First subgraph calculates b = sin(a) on CPU
+ [sin] = find_fun_calls(st_split._run_on_cpu_0, acc_ops.sin)
+ self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+ # Second subgraph calculates c = relu(a) and d = cos(a) on ACC
+ [relu] = find_fun_calls(st_split._run_on_acc_1, acc_ops.relu)
+ self.assertEqual(arg.name, relu.kwargs["input"].name)
+
+ [cos] = find_fun_calls(st_split._run_on_acc_1, acc_ops.cos)
+ self.assertEqual(arg.name, cos.kwargs["input"].name)
+
+ # Third subgraph calculates e = b + c and f = e - d on CPU
+ [add] = find_fun_calls(st_split._run_on_cpu_2, acc_ops.add)
+ self.assertEqual(sin.name, add.kwargs["input"].name)
+ self.assertEqual(relu.name, add.kwargs["other"].name)
+
+ [sub] = find_fun_calls(st_split._run_on_cpu_2, acc_ops.sub)
+ self.assertEqual(add.name, sub.kwargs["input"].name)
+ self.assertEqual(cos.name, sub.kwargs["other"].name)
+
+ test_splitter(splitter)
+
+ def test_split_complex_graph_2(self):
+ module_nn = self.TestModule()
+ module = acc_tracer.trace(module_nn, (torch.randn(2, 3),))
+
+ # Making 'c', 'd' and 'e' run on ACC
+ splitter = TRTSplitter(
+ module,
+ (torch.randn(2, 3),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.cos": None,
+ "acc_ops.relu": None,
+ "acc_ops.add": None,
+ }
+ ),
+ )
+
+ def test_splitter(splitter):
+ module_fx_split = splitter()
+ verify_split_model(module_fx_split)
+
+ [arg] = find_inputs(module)
+
+ # First subgraph calculates b = sin(a) on CPU
+ [sin] = find_fun_calls(module_fx_split._run_on_cpu_0, acc_ops.sin)
+ self.assertEqual(arg.name, sin.kwargs["input"].name)
+
+ # Second subgraph calculates c = relu(a), d = cos(a) and e = b + c on ACC
+ [relu] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.relu)
+ self.assertEqual(arg.name, relu.kwargs["input"].name)
+
+ [cos] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.cos)
+ self.assertEqual(arg.name, cos.kwargs["input"].name)
+
+ [add] = find_fun_calls(module_fx_split._run_on_acc_1, acc_ops.add)
+ self.assertEqual(sin.name, add.kwargs["input"].name)
+ self.assertEqual(relu.name, add.kwargs["other"].name)
+
+ # Third subgraph calculates f = e - d on CPU
+ [sub] = find_fun_calls(module_fx_split._run_on_cpu_2, acc_ops.sub)
+ self.assertEqual(add.name, sub.kwargs["input"].name)
+ self.assertEqual(cos.name, sub.kwargs["other"].name)
+
+ test_splitter(splitter)
+
+
+class TestSplitNonTensorEdges(unittest.TestCase):
+ """
+ a (relu)
+ // \\
+ (b1,b2) c (cos)
+ \\ //
+ d (add)
+ ||
+ e (sigmoid)
+ """
+
+ # Note non-tensor edge between b2 and d
+ class TestModule(torch.nn.Module):
+ def forward(self, x):
+ a = torch.relu(x)
+
+ b1 = a.size()
+ b2 = b1[0]
+
+ c = torch.cos(a)
+
+ d = b2 + c
+ e = torch.sigmoid(d)
+ return e
+
+ def test_split_non_tensor_edges_1(self):
+ test_data = torch.randn(2, 3)
+
+ module_nn = acc_tracer.trace(self.TestModule(), (test_data,))
+
+ # Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC
+ splitter = TRTSplitter(
+ module_nn,
+ (test_data,),
+ op_support_with_support_dict(
+ {
+ "acc_ops.relu": None,
+ "acc_ops.sigmoid": None,
+ "acc_ops.add": None,
+ "acc_ops.getitem": None,
+ "acc_ops.size": None,
+ }
+ ),
+ )
+
+ def test_splitter(splitter):
+ module_fx_split = splitter()
+ try:
+ verify_split_model(module_fx_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_MULTI_ACC_MODULES
+ )
+
+ self.assertEqual(
+ {acc_ops.relu}, find_call_targets(module_fx_split._run_on_acc_0)
+ )
+
+ self.assertEqual(
+ {acc_ops.cos}, find_call_targets(module_fx_split._run_on_cpu_1)
+ )
+
+ self.assertEqual(
+ {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
+ find_call_targets(module_fx_split._run_on_acc_2),
+ )
+
+ # Make sure we can compile to TorchScript
+ module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data})
+ self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data)))
+
+ test_splitter(splitter)
+
+ def test_split_non_tensor_edges_2(self):
+ test_data = torch.randn(2, 3)
+
+ module_nn = acc_tracer.trace(self.TestModule(), (test_data,))
+
+ # Making 'a', 'b1', 'b2', 'd' and 'e' run on ACC with a limit on ACC
+ # subgraph size
+ settings = splitter_base._SplitterSettingBase()
+ settings.min_acc_module_size = 2
+ splitter = TRTSplitter(
+ module_nn,
+ (test_data,),
+ op_support_with_support_dict(
+ {
+ "acc_ops.relu": None,
+ "acc_ops.sigmoid": None,
+ "acc_ops.add": None,
+ "acc_ops.getitem": None,
+ "acc_ops.size": None,
+ }
+ ),
+ settings,
+ )
+
+ def test_splitter(splitter):
+ module_fx_split = splitter()
+ verify_split_model(module_fx_split)
+
+ self.assertEqual(
+ {acc_ops.relu, acc_ops.cos},
+ find_call_targets(module_fx_split._run_on_cpu_0),
+ )
+
+ self.assertEqual(
+ {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
+ find_call_targets(module_fx_split._run_on_acc_1),
+ )
+
+ # Make sure we can compile to TorchScript
+ module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data})
+ self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data)))
+
+ test_splitter(splitter)
+
+ def test_split_non_tensor_edges_3(self):
+ test_data = torch.randn(2, 3)
+
+ module_nn = acc_tracer.trace(self.TestModule(), (test_data,),)
+
+ # Making 'a', 'c', 'd' and 'e' run on ACC
+ splitter = TRTSplitter(
+ module_nn,
+ (test_data,),
+ op_support_with_support_dict(
+ {
+ "acc_ops.relu": None,
+ "acc_ops.sigmoid": None,
+ "acc_ops.cos": None,
+ "acc_ops.add": None,
+ }
+ ),
+ )
+
+ def test_splitter(splitter):
+ module_fx_split = splitter()
+ try:
+ verify_split_model(module_fx_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_MULTI_ACC_MODULES
+ )
+
+ self.assertEqual(
+ {acc_ops.relu, acc_ops.cos},
+ find_call_targets(module_fx_split._run_on_acc_0),
+ )
+
+ self.assertEqual(
+ {acc_ops.size, acc_ops.getitem, acc_ops.add},
+ find_call_targets(module_fx_split._run_on_cpu_1),
+ )
+
+ self.assertEqual(
+ {acc_ops.sigmoid},
+ find_call_targets(module_fx_split._run_on_acc_2),
+ )
+
+ # Make sure we can compile to TorchScript
+ module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data})
+ self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data)))
+
+ test_splitter(splitter)
+
+ def test_split_non_tensor_edges_4(self):
+ test_data = torch.randn(2, 3)
+
+ module_nn = acc_tracer.trace(self.TestModule(), (test_data,),)
+
+ # Making 'a', 'c', 'd' and 'e' run on ACC with a limit on ACC
+ # subgraph size
+ settings = splitter_base._SplitterSettingBase()
+ settings.min_acc_module_size = 2
+ splitter = TRTSplitter(
+ module_nn,
+ (test_data,),
+ op_support_with_support_dict(
+ {
+ "acc_ops.relu": None,
+ "acc_ops.sigmoid": None,
+ "acc_ops.cos": None,
+ "acc_ops.add": None,
+ }
+ ),
+ settings,
+ )
+
+ def test_splitter(splitter):
+ module_fx_split = splitter()
+ verify_split_model(module_fx_split)
+
+ self.assertEqual(
+ {acc_ops.relu, acc_ops.cos},
+ find_call_targets(module_fx_split._run_on_acc_0),
+ )
+
+ self.assertEqual(
+ {acc_ops.size, acc_ops.getitem, acc_ops.add, acc_ops.sigmoid},
+ find_call_targets(module_fx_split._run_on_cpu_1),
+ )
+
+ # Make sure we can compile to TorchScript
+ module_jit = torch.jit.trace_module(module_fx_split, {"forward": test_data})
+ self.assertTrue(torch.allclose(module_nn(test_data), module_jit(test_data)))
+
+ test_splitter(splitter)
+
+
+class TestAccNodesFinder(unittest.TestCase):
+ def test_acc_nodes_finder_1(self):
+ """
+ y ------------->
+ |
+ ----> b ---->
+ x ----> a d
+ ----> c ---->
+ |
+ z ------------->
+ """
+
+ # Make a return non-tensor data
+ class TestModule(torch.nn.Module):
+ def forward(self, x, y, z):
+ a1 = x.size()
+ a1 = a1[0]
+
+ b = y + a1
+ c = z - a1
+
+ d = b + c
+
+ return d
+
+ module_nn = TestModule()
+ module_fx = torch.fx.symbolic_trace(module_nn)
+
+ # Make a and c lowerable to ACC
+ finder = torch.fx.passes.splitter_base.FxNetAccNodesFinder(
+ module_fx,
+ op_support_with_support_dict(
+ {
+ "acc_ops.sub": None,
+ "acc_ops.getitem": None,
+ "acc_ops.size": None,
+ }
+ ),
+ False,
+ )
+ acc_nodes = finder()
+ self.assertEqual(set(), acc_nodes, "Shouldn't have ACC nodes")
+
+
+class TestAccFusionsFinder(unittest.TestCase):
+ """
+ x
+ / \\
+ a b
+ / | \\
+ / | a2
+ a0 a1 |
+ | / |
+ c |
+ | |
+ d |
+ \\ /
+ e
+ """
+
+ class TestModule(torch.nn.Module):
+ def forward(self, x):
+ a = x.size()
+ b = x + x
+
+ a0 = a[0]
+ a1 = a[1]
+ a2 = a[2]
+ c = x.view(a1, a0, -1)
+
+ d = c + c
+ e = d + a2
+ return b, e
+
+ def test_acc_fusions_finder_1(self):
+ """
+ Assume every node is an acc node. We should have one fusion group
+ (a, a0, a1, a2, c, d, e).
+ """
+ module_nn = self.TestModule()
+ module_fx = torch.fx.symbolic_trace(module_nn)
+ shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))
+
+ acc_node = {
+ node
+ for node in module_fx.graph.nodes
+ if node.op in torch.fx.passes.tools_common.CALLABLE_NODE_OPS
+ }
+
+ fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
+ module_fx,
+ acc_node,
+ )
+ fusion_map = fusions_finder()
+
+ self.assertEqual(len(fusion_map), 7)
+ for _, v in fusion_map.items():
+ self.assertEqual(len(v), 7)
+
+ def test_acc_fusions_finder_2(self):
+ """
+ Let only the add nodes (b, d and e) be acc nodes. The fusion group formed around
+ the non-tensor values also contains cpu nodes, which forces its acc members onto
+ CPU too, so no acc fusion groups should be found.
+ """
+ module_nn = self.TestModule()
+ module_fx = torch.fx.symbolic_trace(module_nn)
+ shape_prop.ShapeProp(module_fx).propagate(torch.randn(1, 1, 1))
+
+ acc_node = {
+ node for node in module_fx.graph.nodes if node.target == operator.add
+ }
+ fusions_finder = torch.fx.passes.splitter_base.FxNetAccFusionsFinder(
+ module_fx,
+ acc_node,
+ )
+ fusion_map = fusions_finder()
+ self.assertEqual(len(fusion_map), 0)
+
+
+ def test_start_with_acc_module_(self):
+ """
+ sin relu cos sigmoid tanh
+ a ====> b =====> c ====> d ========> e =====> f
+
+ We set sin, relu and cos as acc nodes and set min_acc_module_size to 2. The acc
+ subgraph (sin, relu, cos) meets the size limit, so the split should start with it.
+ """
+
+ class TestModule(torch.nn.Module):
+ def forward(self, a):
+ b = torch.sin(a)
+ c = torch.relu(b)
+ d = torch.cos(c)
+ e = torch.sigmoid(d)
+ f = torch.tanh(e)
+ return f
+
+ mod = acc_tracer.trace(TestModule(), torch.randn(2, 3))
+
+ # Set sin, cos and relu as acc nodes and split with settings
+ class CustomOpSupport(op_support.OperatorSupport):
+ _support_dict = {
+ "acc_ops.sin": None,
+ "acc_ops.cos": None,
+ "acc_ops.relu": None,
+ }
+
+ # Create splitter setting and set min_acc_module_size to 2
+ settings = splitter_base._SplitterSettingBase()
+ settings.min_acc_module_size = 2
+ splitter = TRTSplitter(
+ mod,
+ (torch.randn(2, 3),),
+ op_support_with_support_dict(
+ {
+ "acc_ops.sin": None,
+ "acc_ops.cos": None,
+ "acc_ops.relu": None,
+ }
+ ),
+ settings,
+ )
+
+ def test_splitter(splitter):
+ st_split = splitter()
+ try:
+ verify_split_model(st_split)
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err), ERROR_MSG_NO_ACC_MODULE
+ )
+ modules = list(st_split.named_modules())
+ # Main module, an acc submodule and a cpu submodule
+ assert len(modules) == 3
+
+ assert modules[1][0] == "_run_on_acc_0"
+ assert modules[2][0] == "_run_on_cpu_1"
+
+ test_splitter(splitter)
+
+
+def op_support_with_support_dict(support_dict: dict) -> op_support.OperatorSupportBase:
+ return op_support.OperatorSupport(support_dict)
diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py
index 2c3372f..831d115 100644
--- a/torch/fx/passes/splitter_base.py
+++ b/torch/fx/passes/splitter_base.py
@@ -655,8 +655,13 @@
current_cpu_nodes, current_acc_nodes = self.starter_nodes()
visited_nodes: NodeSet = set()
- # If there are CPU nodes, start with them
- acc_subgraph: bool = not current_cpu_nodes
+ # Determine which subgraph to start from based on node dependencies
+ acc_subgraph: bool = True
+ for n in current_cpu_nodes:
+ if self.deps[n] <= visited_nodes:
+ acc_subgraph = False
+ break
+
current_subgraph_nodes: NodeList = []
# Result accumulator