# blob: 8a6fbe4a3a90877d61b70dd044fb70d92ae614b7 [file] [log] [blame]
import torch
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.experimental import GraphManipulation
from torch.fx.experimental.Partitioner import Partitioner, Device
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
class TestFXExperimental(JitTestCase):
    """Tests for ``torch.fx.experimental.Partitioner``.

    Every test traces a small ``torch.nn.Module``, records node sizes with
    ``GraphManipulation.get_size_of_all_nodes``, partitions the traced graph
    over three identical ``Device`` entries, and checks that the partitioned
    module computes the same result as the traced one.
    """

    # All tests partition over the same devices, named 'dev_0' .. 'dev_2',
    # each constructed with 125 as its second ``Device`` argument
    # (presumably the per-device memory budget -- confirm against Device).
    _NUM_DEVICES = 3
    _DEVICE_SIZE = 125

    def _partition_and_check(self, traced, module, inputs):
        """Partition *traced* and assert it still matches the traced module.

        Args:
            traced: the result of ``symbolic_trace(module)``.
            module: the original ``torch.nn.Module``; passed to the
                partitioner alongside the traced graph.
            inputs: list of example tensors, used both to record node sizes
                and as the arguments for the output comparison.

        Returns:
            The partitioner's ``module_with_submodules``, so callers can make
            further assertions about its graph.
        """
        # Record node sizes from the example inputs before partitioning.
        GraphManipulation.get_size_of_all_nodes(traced, inputs)
        partitioner = Partitioner()
        devices = [
            Device(f'dev_{i}', self._DEVICE_SIZE)
            for i in range(self._NUM_DEVICES)
        ]
        ret = partitioner.partition_graph(traced, module, devices)
        module_with_submodules = ret.module_with_submodules
        # Partitioning must not change what the graph computes.
        self.assertEqual(traced(*inputs), module_with_submodules(*inputs))
        return module_with_submodules

    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        self._partition_and_check(traced, m, [a, b])

    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                # NOTE(review): torch.rand takes no traced inputs here, so it
                # presumably runs once at trace time and its result is baked
                # into the graph as a constant -- confirm.
                e = torch.rand(4)
                add_2 = linear + e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        module_with_submodules = self._partition_and_check(traced, m, [a, b])
        # Pin the node count of the partitioned top-level graph; a change here
        # means the partitioner split the graph differently. assertEqual (not
        # a bare assert) so the check survives `python -O` and reports values.
        self.assertEqual(len(module_with_submodules.graph.nodes), 7)

    def test_partition_combining(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_0 = torch.nn.Linear(4, 4)

            def forward(self, a, b):
                add_1 = a + b
                c = self.linear_0(a)
                add_2 = c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        module_with_submodules = self._partition_and_check(traced, m, [a, b])
        # Pin the node count of the partitioned top-level graph (see
        # test_size_based_partition for why assertEqual is used).
        self.assertEqual(len(module_with_submodules.graph.nodes), 5)
if __name__ == "__main__":
    # Running this file directly hands control to PyTorch's shared runner.
    run_tests()