blob: 3e70a48d11f29a4f7453b72ffff8d0bda4a0972e [file] [log] [blame]
import unittest
import torch
import torch.backends.xnnpack
from torch.utils import mobile_optimizer
from torch.nn import functional as F
FileCheck = torch._C.FileCheck
class TestOptimizer(unittest.TestCase):
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")
def test_optimize_for_mobile(self):
batch_size = 2
input_channels_per_group = 6
height = 16
width = 16
output_channels_per_group = 6
groups = 4
kernel_h = kernel_w = 3
stride_h = stride_w = 1
pad_h = pad_w = 1
dilation = 1
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
kernels = (kernel_h, kernel_w)
strides = (stride_h, stride_w)
paddings = (pad_h, pad_w)
dilations = (dilation, dilation)
conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
conv_bias_shape = (output_channels)
input_data = torch.rand((batch_size, input_channels, height, width))
conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
conv_bias = torch.rand((output_channels))
result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups)
weight_output_dim = 24
linear_input_shape = result.shape[1]
linear_weight_shape = (weight_output_dim, linear_input_shape)
class MyTestModule(torch.nn.Module):
def __init__(self):
super(MyTestModule, self).__init__()
self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)))
self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand((conv_bias_shape))))
self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
self.strides = strides
self.paddings = paddings
self.dilations = dilations
self.groups = groups
def forward(self, x):
o = F.conv2d(x, self.conv_weight, self.conv_bias,
self.strides, self.paddings, self.dilations, self.groups)
o = F.relu(o)
o = o.permute([0, 2, 3, 1])
o = F.linear(o, self.linear_weight, self.linear_bias)
return F.relu(o)
data_shape = (batch_size, input_channels, height, width)
input_data = torch.normal(1, 20, size=data_shape)
scripted_model = torch.jit.script(MyTestModule())
scripted_model.eval()
initial_result = scripted_model(input_data)
optimized_scripted_model = mobile_optimizer.optimize_for_mobile(scripted_model)
optimized_result = optimized_scripted_model(input_data)
FileCheck().check_not("Tensor = aten::conv2d") \
.check_not("Tensor = prim::CallFunction") \
.check_not("prepacked::conv2d_clamp_prepack") \
.check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
.check_not("prepacked::linear_clamp_prepack") \
.check_count("prepacked::linear_clamp_run", 1, exactly=True) \
.run(optimized_scripted_model.graph)
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
if __name__ == '__main__':
unittest.main()