|  | # Owner(s): ["oncall: mobile"] | 
|  |  | 
|  | import unittest | 
|  | import torch | 
|  | from torch.nn import functional as F | 
|  |  | 
|  | from torch.testing._internal.common_utils import TestCase, run_tests | 
|  | from torch.testing import FileCheck | 
|  | import io | 
|  |  | 
|  | @unittest.skipUnless(torch.is_vulkan_available(), | 
|  | "Vulkan backend must be available for these tests.") | 
|  | class TestVulkanRewritePass(TestCase): | 
|  | @staticmethod | 
|  | def validate_transformed_module( | 
|  | # To please flake | 
|  | self, | 
|  | pattern_count_map, | 
|  | data_shape, | 
|  | prepack_removal=False, | 
|  | fuse_clamping_ops=False): | 
|  | module_instance = self | 
|  | scripted_model = torch.jit.script(module_instance) | 
|  | scripted_model.eval() | 
|  | input_data = torch.normal(1, 20, size=data_shape) | 
|  | ref_result = scripted_model(input_data) | 
|  | torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c) | 
|  | if fuse_clamping_ops or prepack_removal: | 
|  | scripted_model._c = torch._C._freeze_module(scripted_model._c) | 
|  | if fuse_clamping_ops: | 
|  | torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c) | 
|  | if prepack_removal: | 
|  | torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c) | 
|  |  | 
|  | buffer = io.BytesIO() | 
|  | torch.jit.save(scripted_model, buffer) | 
|  | buffer.seek(0) | 
|  | deserialized_scripted_model = torch.jit.load(buffer) | 
|  | for pattern, v in pattern_count_map.items(): | 
|  | if (v == 0): | 
|  | FileCheck().check(pattern).run(deserialized_scripted_model.graph) | 
|  | elif (v == -1): | 
|  | FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) | 
|  | else: | 
|  | FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) | 
|  |  | 
|  | def test_conv(self): | 
|  | # Conv params | 
|  | batch_size = 2 | 
|  | input_channels_per_group = 6 | 
|  | height = 16 | 
|  | width = 16 | 
|  | output_channels_per_group = 6 | 
|  | groups = 4 | 
|  | kernel_h = kernel_w = 3 | 
|  | stride_h = stride_w = 1 | 
|  | pad_h = pad_w = 1 | 
|  | dilation = 1 | 
|  | input_channels = input_channels_per_group * groups | 
|  | output_channels = output_channels_per_group * groups | 
|  | kernels = (kernel_h, kernel_w) | 
|  | strides = (stride_h, stride_w) | 
|  | paddings = (pad_h, pad_w) | 
|  | dilations = (dilation, dilation) | 
|  | conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) | 
|  | conv_bias_shape = (output_channels) | 
|  |  | 
|  | class Conv2D(torch.nn.Module): | 
|  | def __init__(self): | 
|  | super().__init__() | 
|  | self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) | 
|  | self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) | 
|  | self.strides = strides | 
|  | self.paddings = paddings | 
|  | self.dilations = dilations | 
|  | self.groups = groups | 
|  |  | 
|  | def forward(self, x): | 
|  | return F.conv2d(x, self.weight, self.bias, | 
|  | self.strides, self.paddings, self.dilations, self.groups) | 
|  |  | 
|  | data_shape = (batch_size, input_channels, height, width) | 
|  | pattern_count_map = {"Tensor = aten::conv2d": -1, | 
|  | "vulkan_prepack::conv2d_clamp_prepack": 1, | 
|  | "vulkan_prepack::conv2d_clamp_run": 1} | 
|  | TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) | 
|  |  | 
|  | class Conv2DRelu(torch.nn.Module): | 
|  | def __init__(self): | 
|  | super().__init__() | 
|  | self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) | 
|  | self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) | 
|  | self.strides = strides | 
|  | self.paddings = paddings | 
|  | self.dilations = dilations | 
|  | self.groups = groups | 
|  |  | 
|  | def forward(self, x): | 
|  | o = F.conv2d(x, self.weight, self.bias, | 
|  | self.strides, self.paddings, self.dilations, self.groups) | 
|  | o = F.relu(o) | 
|  | return o | 
|  |  | 
|  | data_shape = (batch_size, input_channels, height, width) | 
|  | pattern_count_map = {"Tensor = aten::conv2d": -1, | 
|  | "vulkan_prepack::conv2d_clamp_prepack": 1, | 
|  | "vulkan_prepack::conv2d_clamp_run": 1} | 
|  | TestVulkanRewritePass.validate_transformed_module( | 
|  | Conv2DRelu(), pattern_count_map, data_shape) | 
|  |  | 
|  | pattern_count_map["aten::relu"] = 1 | 
|  | pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1 | 
|  | TestVulkanRewritePass.validate_transformed_module( | 
|  | Conv2DRelu(), | 
|  | pattern_count_map, | 
|  | data_shape, | 
|  | prepack_removal=True) | 
|  | pattern_count_map["aten::relu"] = -1 | 
|  | TestVulkanRewritePass.validate_transformed_module( | 
|  | Conv2DRelu(), | 
|  | pattern_count_map, | 
|  | data_shape, | 
|  | prepack_removal=True, | 
|  | fuse_clamping_ops=True) | 
|  |  | 
|  |  | 
|  | class Conv2DHardtanh(torch.nn.Module): | 
|  | def __init__(self): | 
|  | super().__init__() | 
|  | self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) | 
|  | self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) | 
|  | self.strides = strides | 
|  | self.paddings = paddings | 
|  | self.dilations = dilations | 
|  | self.groups = groups | 
|  |  | 
|  | def forward(self, x): | 
|  | o = F.conv2d(x, self.weight, self.bias, | 
|  | self.strides, self.paddings, self.dilations, self.groups) | 
|  | o = F.hardtanh(o) | 
|  | return o | 
|  |  | 
|  | data_shape = (batch_size, input_channels, height, width) | 
|  | pattern_count_map = {"Tensor = aten::conv2d": -1, | 
|  | "vulkan_prepack::conv2d_clamp_prepack": 1, | 
|  | "vulkan_prepack::conv2d_clamp_run": 1} | 
|  | TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) | 
|  | pattern_count_map["aten::hardtanh"] = 1 | 
|  | pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1 | 
|  | TestVulkanRewritePass.validate_transformed_module( | 
|  | Conv2DHardtanh(), | 
|  | pattern_count_map, | 
|  | data_shape, | 
|  | prepack_removal=True) | 
|  | pattern_count_map["aten::hardtanh"] = -1 | 
|  | TestVulkanRewritePass.validate_transformed_module( | 
|  | Conv2DRelu(), | 
|  | pattern_count_map, | 
|  | data_shape, | 
|  | prepack_removal=True, | 
|  | fuse_clamping_ops=True) | 
|  |  | 
# When executed as a script, hand control to the shared test runner
# imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()