| # Owner(s): ["oncall: mobile"] | 
 |  | 
 | import torch | 
 | from torch.nn import functional as F | 
 |  | 
 | from torch.testing._internal.common_utils import TestCase, run_tests | 
 | from torch.testing import FileCheck | 
 | import io | 
 |  | 
 | class TestMetalRewritePass(TestCase): | 
 |     @staticmethod | 
 |     def validate_transformed_module( | 
 |             # To please flake | 
 |             self, | 
 |             pattern_count_map, | 
 |             data_shape, | 
 |             prepack_removal=False, | 
 |             fuse_clamping_ops=False): | 
 |         module_instance = self | 
 |         scripted_model = torch.jit.script(module_instance) | 
 |         scripted_model.eval() | 
 |         input_data = torch.normal(1, 20, size=data_shape) | 
 |         ref_result = scripted_model(input_data) | 
 |         torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c) | 
 |         if fuse_clamping_ops or prepack_removal: | 
 |             scripted_model._c = torch._C._freeze_module(scripted_model._c) | 
 |         if fuse_clamping_ops: | 
 |             torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c) | 
 |         if prepack_removal: | 
 |             torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c) | 
 |  | 
 |         buffer = io.BytesIO() | 
 |         torch.jit.save(scripted_model, buffer) | 
 |         buffer.seek(0) | 
 |         deserialized_scripted_model = torch.jit.load(buffer) | 
 |         for pattern, v in pattern_count_map.items(): | 
 |             if (v == 0): | 
 |                 FileCheck().check(pattern).run(deserialized_scripted_model.graph) | 
 |             elif (v == -1): | 
 |                 FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) | 
 |             else: | 
 |                 FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) | 
 |  | 
 |     def test_conv(self): | 
 |         # Conv params | 
 |         batch_size = 2 | 
 |         input_channels_per_group = 6 | 
 |         height = 16 | 
 |         width = 16 | 
 |         output_channels_per_group = 6 | 
 |         groups = 4 | 
 |         kernel_h = kernel_w = 3 | 
 |         stride_h = stride_w = 1 | 
 |         pad_h = pad_w = 1 | 
 |         dilation = 1 | 
 |         input_channels = input_channels_per_group * groups | 
 |         output_channels = output_channels_per_group * groups | 
 |         kernels = (kernel_h, kernel_w) | 
 |         strides = (stride_h, stride_w) | 
 |         paddings = (pad_h, pad_w) | 
 |         dilations = (dilation, dilation) | 
 |         conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) | 
 |         conv_bias_shape = (output_channels) | 
 |  | 
 |         class Conv2D(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Conv2D, self).__init__() | 
 |                 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) | 
 |                 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) | 
 |                 self.strides = strides | 
 |                 self.paddings = paddings | 
 |                 self.dilations = dilations | 
 |                 self.groups = groups | 
 |  | 
 |             def forward(self, x): | 
 |                 return F.conv2d(x, self.weight, self.bias, | 
 |                                 self.strides, self.paddings, self.dilations, self.groups) | 
 |  | 
 |         data_shape = (batch_size, input_channels, height, width) | 
 |         pattern_count_map = {"Tensor = aten::conv2d": -1, | 
 |                              "metal_prepack::conv2d_prepack": 1, | 
 |                              "metal_prepack::conv2d_run": 1} | 
 |         TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) | 
 |  | 
 |         class Conv2DRelu(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Conv2DRelu, self).__init__() | 
 |                 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) | 
 |                 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) | 
 |                 self.strides = strides | 
 |                 self.paddings = paddings | 
 |                 self.dilations = dilations | 
 |                 self.groups = groups | 
 |  | 
 |             def forward(self, x): | 
 |                 o = F.conv2d(x, self.weight, self.bias, | 
 |                              self.strides, self.paddings, self.dilations, self.groups) | 
 |                 o = F.relu(o) | 
 |                 return o | 
 |  | 
 |         data_shape = (batch_size, input_channels, height, width) | 
 |         pattern_count_map = {"Tensor = aten::conv2d": -1, | 
 |                              "metal_prepack::conv2d_prepack": 1, | 
 |                              "metal_prepack::conv2d_run": 1} | 
 |         TestMetalRewritePass.validate_transformed_module( | 
 |             Conv2DRelu(), pattern_count_map, data_shape) | 
 |  | 
 |         pattern_count_map["aten::relu"] = 1 | 
 |         pattern_count_map["metal_prepack::conv2d_prepack"] = -1 | 
 |         TestMetalRewritePass.validate_transformed_module( | 
 |             Conv2DRelu(), | 
 |             pattern_count_map, | 
 |             data_shape, | 
 |             prepack_removal=True) | 
 |         pattern_count_map["aten::relu"] = -1 | 
 |         TestMetalRewritePass.validate_transformed_module( | 
 |             Conv2DRelu(), | 
 |             pattern_count_map, | 
 |             data_shape, | 
 |             prepack_removal=True, | 
 |             fuse_clamping_ops=True) | 
 |  | 
 |  | 
 |         class Conv2DHardtanh(torch.nn.Module): | 
 |             def __init__(self): | 
 |                 super(Conv2DHardtanh, self).__init__() | 
 |                 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) | 
 |                 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) | 
 |                 self.strides = strides | 
 |                 self.paddings = paddings | 
 |                 self.dilations = dilations | 
 |                 self.groups = groups | 
 |  | 
 |             def forward(self, x): | 
 |                 o = F.conv2d(x, self.weight, self.bias, | 
 |                              self.strides, self.paddings, self.dilations, self.groups) | 
 |                 o = F.hardtanh(o) | 
 |                 return o | 
 |  | 
 |         data_shape = (batch_size, input_channels, height, width) | 
 |         pattern_count_map = {"Tensor = aten::conv2d": -1, | 
 |                              "metal_prepack::conv2d_prepack": 1, | 
 |                              "metal_prepack::conv2d_run": 1} | 
 |         TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) | 
 |         pattern_count_map["aten::hardtanh"] = 1 | 
 |         pattern_count_map["metal_prepack::conv2d_prepack"] = -1 | 
 |         TestMetalRewritePass.validate_transformed_module( | 
 |             Conv2DHardtanh(), | 
 |             pattern_count_map, | 
 |             data_shape, | 
 |             prepack_removal=True) | 
 |         pattern_count_map["aten::hardtanh"] = -1 | 
 |         TestMetalRewritePass.validate_transformed_module( | 
 |             Conv2DRelu(), | 
 |             pattern_count_map, | 
 |             data_shape, | 
 |             prepack_removal=True, | 
 |             fuse_clamping_ops=True) | 
 |  | 
if __name__ == "__main__":
    # Run this file's tests through PyTorch's shared test runner.
    run_tests()