blob: 5bc609e5d65cb2165de85dc78d9a7016426e46b9 [file] [log] [blame]
# Owner(s): ["module: fx"]
import torch
from torch.fx.passes.split_utils import split_by_tags
from torch.testing._internal.common_utils import TestCase
class TestFXSplit(TestCase):
def test_split_preserve_node_meta(self):
class TestModule(torch.nn.Module):
def forward(self, x, y):
x = x + x
y = y * y
return x - y
gm = torch.fx.symbolic_trace(TestModule())
for node in gm.graph.nodes:
node.meta["name"] = node.name
if node.name == "add":
node.tag = "a"
elif node.name == "mul":
node.tag = "b"
elif node.name == "sub":
node.tag = "c"
split_gm = split_by_tags(gm, ["a", "b", "c"])
for m in split_gm.children():
for n in m.graph.nodes:
if n.op != "output":
self.assertIn("name", n.meta)
self.assertEqual(n.meta["name"], n.name)