| # Owner(s): ["oncall: quantization"] |
| import copy |
| import unittest |
| |
| import torch |
| import torch._dynamo as torchdynamo |
| from torch.ao.quantization.pt2e.graph_utils import ( |
| find_sequential_partitions, |
| get_equivalent_types, |
| update_equivalent_types_dict, |
| ) |
| from torch.testing._internal.common_utils import IS_WINDOWS, TestCase |
| |
| |
| class TestGraphUtils(TestCase): |
| @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") |
| def test_conv_bn_conv_relu(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| self.bn1 = torch.nn.BatchNorm2d(3) |
| self.conv2 = torch.nn.Conv2d(3, 3, 3) |
| self.relu2 = torch.nn.ReLU() |
| |
| def forward(self, x): |
| bn_out = self.bn1(self.conv1(x)) |
| relu_out = torch.nn.functional.relu(bn_out) |
| return self.relu2(self.conv2(relu_out)) |
| |
| m = M().eval() |
| example_inputs = (torch.randn(1, 3, 5, 5),) |
| |
| # program capture |
| m, guards = torchdynamo.export( |
| m, |
| *copy.deepcopy(example_inputs), |
| aten_graph=True, |
| ) |
| fused_partitions = find_sequential_partitions( |
| m, [torch.nn.Conv2d, torch.nn.BatchNorm2d] |
| ) |
| self.assertEqual(len(fused_partitions), 1) |
| fused_partitions = find_sequential_partitions( |
| m, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU] |
| ) |
| self.assertEqual(len(fused_partitions), 1) |
| |
| def x(): |
| find_sequential_partitions( |
| m, |
| [ |
| torch.nn.Conv2d, |
| torch.nn.BatchNorm2d, |
| torch.nn.ReLU, |
| torch.nn.functional.conv2d, |
| ], |
| ) |
| |
| self.assertRaises(ValueError, x) |
| |
| @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") |
| def test_conv_bn_relu(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.bn1 = torch.nn.BatchNorm2d(3) |
| self.conv2 = torch.nn.Conv2d(3, 3, 3) |
| self.relu2 = torch.nn.ReLU() |
| |
| def forward(self, x): |
| bn_out = self.bn1(x) |
| return self.relu2(self.conv2(bn_out)) |
| |
| m = M().eval() |
| example_inputs = (torch.randn(1, 3, 5, 5),) |
| |
| # program capture |
| m, guards = torchdynamo.export( |
| m, |
| *copy.deepcopy(example_inputs), |
| aten_graph=True, |
| ) |
| fused_partitions = find_sequential_partitions( |
| m, [torch.nn.Conv2d, torch.nn.BatchNorm2d] |
| ) |
| self.assertEqual(len(fused_partitions), 0) |
| fused_partitions = find_sequential_partitions( |
| m, [torch.nn.BatchNorm2d, torch.nn.Conv2d] |
| ) |
| self.assertEqual(len(fused_partitions), 1) |
| fused_partitions = find_sequential_partitions( |
| m, [torch.nn.BatchNorm2d, torch.nn.ReLU] |
| ) |
| self.assertEqual(len(fused_partitions), 0) |
| |
| @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows") |
| def test_customized_equivalet_types_dict(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| |
| def forward(self, x): |
| return torch.nn.functional.relu6(self.conv(x)) |
| |
| m = M().eval() |
| example_inputs = (torch.randn(1, 3, 5, 5),) |
| |
| # program capture |
| m, guards = torchdynamo.export( |
| m, |
| *copy.deepcopy(example_inputs), |
| aten_graph=True, |
| ) |
| customized_equivalent_types = get_equivalent_types() |
| customized_equivalent_types.append({torch.nn.ReLU6, torch.nn.functional.relu6}) |
| update_equivalent_types_dict(customized_equivalent_types) |
| fused_partitions = find_sequential_partitions( |
| m, |
| [torch.nn.Conv2d, torch.nn.ReLU6], |
| ) |
| self.assertEqual(len(fused_partitions), 1) |