|  | # Owner(s): ["module: mkldnn"] | 
|  | import itertools | 
|  | import unittest | 
|  | from typing import NamedTuple, List | 
|  |  | 
|  | import torch | 
|  | from torch import nn | 
|  |  | 
|  | from torch.testing._internal.common_utils import run_tests | 
|  | from torch.testing._internal.jit_utils import JitTestCase | 
|  |  | 
|  | from test_tensorexpr import warmup_and_run_forward | 
|  |  | 
# Graph-node kind produced by the TensorExpr fuser when ops are merged into one group.
FUSION_GROUP = 'prim::TensorExprGroup'
|  |  | 
class PointwisePostOp(NamedTuple):
    """Description of a unary pointwise post-op to fuse onto a conv/linear call.

    `attr`, `scalars` and `algorithm` are forwarded verbatim to the
    `torch.ops.mkldnn._*_pointwise` ops; `pointwise_module` computes the
    eager-mode reference result.
    """
    # Post-op name understood by the mkldnn fused kernels (e.g. "relu", "gelu").
    attr : str
    # Eager nn.Module applied after conv/linear to build the reference output.
    pointwise_module : nn.Module
    # Extra scalar arguments for the post-op (e.g. leaky_relu's negative slope).
    # NOTE(review): this default list is shared by all instances that omit it;
    # it is only ever read in this file, never mutated, so that is safe here.
    scalars : List[float] = []
    # Variant selector for ops with multiple algorithms (GELU "none"/"tanh").
    algorithm : str = ""
|  |  | 
# Conv / transposed-conv module classes keyed by spatial dimensionality.
CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
CONV_TRANSPOSE_MODULES = {2: torch.nn.ConvTranspose2d}
|  |  | 
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMkldnnFusion(JitTestCase):
    def assertFused(self, graph, fused_patterns):
        """Assert that every pattern in `fused_patterns` has been consumed by
        fusion, i.e. occurs zero times in `graph`."""
        for pattern in fused_patterns:
            self.assertGraphContainsExactly(graph, pattern, 0)
|  |  | 
|  | def _check_model(self, m, x, trace=False): | 
|  | old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() | 
|  | torch._C._debug_set_fusion_group_inlining(False) | 
|  |  | 
|  | old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() | 
|  | torch._C._jit_override_can_fuse_on_cpu(True) | 
|  |  | 
|  | old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() | 
|  | torch._C._jit_set_te_must_use_llvm_cpu(False) | 
|  |  | 
|  | m.eval() | 
|  | with torch.no_grad(): | 
|  | if trace: | 
|  | script = torch.jit.trace(m, x) | 
|  | else: | 
|  | script = torch.jit.script(m) | 
|  | script = torch.jit.freeze(script) | 
|  |  | 
|  | with torch.no_grad(): | 
|  | y = warmup_and_run_forward(script, x) | 
|  | y = script(x) | 
|  | y_ref = m(x) | 
|  |  | 
|  | graph = script.graph_for(*x) | 
|  | self.assertEqual(y, y_ref) | 
|  |  | 
|  | torch._C._debug_set_fusion_group_inlining(old_fusion_inlining) | 
|  | torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state) | 
|  | torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu) | 
|  | return graph | 
|  |  | 
|  | def test_single_conv(self): | 
|  | class M(nn.Module): | 
|  | def __init__(self, in_channels, out_channels, bias, **kwargs): | 
|  | super().__init__() | 
|  | self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) | 
|  |  | 
|  | def forward(self, x): | 
|  | res = self.conv(x) | 
|  | return res | 
|  |  | 
|  | for memory_format, enabled in [ | 
|  | [torch.contiguous_format, False], | 
|  | [torch.channels_last, True], | 
|  | ]: | 
|  | for trace in [True, False]: | 
|  | input_size = 224 | 
|  | batch_size = 1 | 
|  | kernel_size = 3 | 
|  | options = itertools.product([True, False], [1, 2], [1, 4]) | 
|  | for bias, dilation, groups in options: | 
|  | iC = 3 * groups | 
|  | oC = 10 * groups | 
|  | m = M(iC, | 
|  | oC, | 
|  | bias, | 
|  | kernel_size=(kernel_size, kernel_size), | 
|  | stride=2, | 
|  | padding=1, | 
|  | dilation=dilation, | 
|  | groups=groups).to(memory_format=memory_format) | 
|  | x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format) | 
|  | graph = self._check_model(m, x, trace) | 
|  | conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d' | 
|  | if enabled: | 
|  | self.assertFused(graph, [conv_node_name]) | 
|  | self.assertGraphContainsExactly(graph, FUSION_GROUP, 1) | 
|  | else: | 
|  | self.assertGraphContains(graph, kind=conv_node_name) | 
|  |  | 
|  | def test_conv_unary_fusion_nnc(self): | 
|  | class M(nn.Module): | 
|  | def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs): | 
|  | super().__init__() | 
|  | self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) | 
|  | self.unary = unary_fn | 
|  |  | 
|  | def forward(self, x): | 
|  | x = self.conv(x) | 
|  | x = self.unary(x) | 
|  | return x | 
|  |  | 
|  | for memory_format, enabled in [ | 
|  | [torch.contiguous_format, False], | 
|  | [torch.channels_last, True], | 
|  | ]: | 
|  | for unary_fn in [torch.relu]: | 
|  | for bias in [True, False]: | 
|  | for oC in [1, 10]: | 
|  | m = M(unary_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format) | 
|  | x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format) | 
|  |  | 
|  | graph = self._check_model(m, x) | 
|  | if enabled: | 
|  | self.assertFused(graph, ['aten::conv2d', 'aten::' + unary_fn.__name__]) | 
|  | self.assertGraphContainsExactly(graph, FUSION_GROUP, 1) | 
|  | else: | 
|  | self.assertGraphContains(graph, kind='aten::conv2d') | 
|  |  | 
|  | def test_unsupported_conv(self): | 
|  | class M(nn.Module): | 
|  | def __init__(self, m, in_channels, out_channels, bias, **kwargs): | 
|  | super().__init__() | 
|  | self.conv = m(in_channels, out_channels, bias=bias, **kwargs) | 
|  |  | 
|  | def forward(self, x): | 
|  | res = self.conv(x) | 
|  | return res | 
|  |  | 
|  | for module, dim, memory_format in [ | 
|  | [nn.Conv3d, 3, torch.contiguous_format], | 
|  | [nn.Conv3d, 3, torch.channels_last_3d], | 
|  | [nn.ConvTranspose2d, 2, torch.contiguous_format], | 
|  | [nn.ConvTranspose2d, 2, torch.channels_last], | 
|  | ]: | 
|  | trace = True | 
|  | input_size = 224 | 
|  | batch_size = 1 | 
|  | kernel_size = 3 | 
|  | groups = 2 | 
|  | bias = True | 
|  | iC = 3 * groups | 
|  | oC = 10 * groups | 
|  | dilation = 2 | 
|  | m = M(module, | 
|  | iC, | 
|  | oC, | 
|  | bias, | 
|  | kernel_size=kernel_size, | 
|  | stride=2, | 
|  | padding=1, | 
|  | dilation=dilation, | 
|  | groups=groups).to(memory_format=memory_format) | 
|  | input_sizes = [batch_size, iC, input_size, input_size] | 
|  | if dim == 3: | 
|  | input_sizes.append(input_size) | 
|  | x = torch.randn(input_sizes).to(memory_format=memory_format) | 
|  | graph = self._check_model(m, x, trace) | 
|  | self.assertGraphContains(graph, kind='aten::_convolution') | 
|  |  | 
|  | def _unary_list(self): | 
|  | unary_list = { | 
|  | "relu": PointwisePostOp("relu", nn.ReLU()), | 
|  | "sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()), | 
|  | "tanh": PointwisePostOp("tanh", nn.Tanh()), | 
|  | "hardswish": PointwisePostOp("hardswish", nn.Hardswish()), | 
|  | "leaky_relu": PointwisePostOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]), | 
|  | "hardtanh": PointwisePostOp("hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4]), | 
|  | "gelu_none": PointwisePostOp("gelu", nn.GELU(approximate="none"), algorithm="none"), | 
|  | "gelu_tanh": PointwisePostOp("gelu", nn.GELU(approximate="tanh"), algorithm="tanh"), | 
|  | } | 
|  | return unary_list | 
|  |  | 
|  | def _binary_list(self): | 
|  | binary_list = { | 
|  | "add": torch.add, | 
|  | "sub": torch.sub, | 
|  | "mul": torch.mul, | 
|  | "div": torch.div, | 
|  | } | 
|  | return binary_list | 
|  |  | 
|  | def test_linear_unary_fusion_ops(self): | 
|  | class M(nn.Module): | 
|  | def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs): | 
|  | super().__init__() | 
|  | self.linear = torch.nn.Linear( | 
|  | in_channels, out_channels, bias=bias, **kwargs | 
|  | ) | 
|  | self.unary = unary_fn | 
|  |  | 
|  | def forward(self, x): | 
|  | x = self.linear(x) | 
|  | x = self.unary(x) | 
|  | return x | 
|  |  | 
|  | for pointwise_info in self._unary_list().values(): | 
|  | options = itertools.product([[2, 3, 10], [2, 10]], [True, False]) | 
|  | for input_shape, bias in options: | 
|  | with torch.no_grad(): | 
|  | mod = M(pointwise_info.pointwise_module, input_shape[-1], 10, bias).eval() | 
|  | v = torch.randn(input_shape) | 
|  | ref = mod(v) | 
|  | attr = pointwise_info.attr | 
|  | scalars = pointwise_info.scalars | 
|  | algorithm = pointwise_info.algorithm | 
|  | fused = torch.ops.mkldnn._linear_pointwise( | 
|  | v, mod.linear.weight, mod.linear.bias, attr, scalars, algorithm | 
|  | ) | 
|  | self.assertEqual(ref, fused) | 
|  |  | 
|  |  | 
|  | def test_conv_unary_fusion_ops(self): | 
|  | class M(nn.Module): | 
|  | def __init__(self, unary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs): | 
|  | super().__init__() | 
|  | self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs) | 
|  | self.unary = unary_fn | 
|  |  | 
|  | def forward(self, x): | 
|  | x = self.conv(x) | 
|  | x = self.unary(x) | 
|  | return x | 
|  |  | 
|  | input_shapes = {2: (112, 112), 3: (55, 55, 55)} | 
|  | for pointwise_info in self._unary_list().values(): | 
|  | for dim in [2, 3]: | 
|  | channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d | 
|  | options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last]) | 
|  | for bias, dilation, groups, memory_format in options: | 
|  | oC = 32 * groups | 
|  | iC = 3 * groups | 
|  | x_shape = (1, iC) + input_shapes[dim] | 
|  | x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format) | 
|  | mod = M(pointwise_info.pointwise_module, dim, iC, oC, dilation, groups, bias, kernel_size=3) | 
|  | mod = mod.to(memory_format=memory_format).eval() | 
|  | with torch.no_grad(): | 
|  | ref = mod(x) | 
|  | attr = pointwise_info.attr | 
|  | scalars = pointwise_info.scalars | 
|  | algorithm = pointwise_info.algorithm | 
|  | fused = torch.ops.mkldnn._convolution_pointwise( | 
|  | x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, | 
|  | mod.conv.groups, attr, scalars, algorithm | 
|  | ) | 
|  | self.assertEqual(ref, fused) | 
|  |  | 
|  |  | 
    def test_conv_binary_fusion_ops(self):
        """Check mkldnn._convolution_pointwise binary fusion (optionally with a
        trailing relu, and the in-place variant for "add") against eager
        conv + binary op."""
        class M(nn.Module):
            def __init__(self, binary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
                super().__init__()
                self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
                self.binary = binary_fn

            def forward(self, x, other):
                x = self.conv(x)
                x = self.binary(x, other)
                return x

        input_shapes = {2: (112, 112), 3: (55, 55, 55)}
        for pointwise_name, pointwise_fn in self._binary_list().items():
            for dim in [2, 3]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                # fuse_relu x bias x dilation x groups x memory format
                options = itertools.product([False, True], [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
                for fuse_relu, bias, dilation, groups, memory_format in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_fn, dim, iC, oC, dilation, groups, bias, kernel_size=3)
                    mod = mod.to(memory_format=memory_format).eval()
                    # `other` must match the conv output shape; run the conv once to get it.
                    other = torch.randn_like(mod.conv(x))
                    with torch.no_grad():
                        ref = mod(x, other)
                        unary_attr = None
                        if fuse_relu:
                            # Reference applies relu in place, mirroring the fused relu post-op.
                            ref.relu_()
                            unary_attr = "relu"
                        attr = pointwise_name
                        fused = torch.ops.mkldnn._convolution_pointwise(
                            x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
                            mod.conv.groups, attr, None, unary_attr, [], None
                        )
                        # for binary add, we support inplace version.
                        if attr == "add":
                            # The in-place op accumulates into its first argument,
                            # so `other` itself must hold the result afterwards.
                            fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
                                other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
                                mod.conv.groups, attr, None, unary_attr, [], None
                            )
                            self.assertEqual(ref, other)
                            self.assertEqual(ref, fused_inplace)

                        self.assertEqual(ref, fused)
|  |  | 
|  | def test_linear_binary_fusion_ops(self): | 
|  | class M(nn.Module): | 
|  | def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs): | 
|  | super().__init__() | 
|  | self.linear = torch.nn.Linear( | 
|  | in_channels, out_channels, bias=bias, **kwargs | 
|  | ) | 
|  | self.binary = binary_fn | 
|  |  | 
|  | def forward(self, x, other): | 
|  | x = self.linear(x) | 
|  | x = self.binary(x, other) | 
|  | return x | 
|  |  | 
|  | out_feature = 20 | 
|  | for pointwise_name, pointwise_fn in self._binary_list().items(): | 
|  | options = itertools.product([[2, 3, 10], [2, 10]], [True, False]) | 
|  | for input_shape, bias in options: | 
|  | with torch.no_grad(): | 
|  | mod = M(pointwise_fn, input_shape[-1], out_feature, bias).eval() | 
|  | v = torch.randn(input_shape) | 
|  | other = torch.randn(input_shape[:-1] + [out_feature]) | 
|  | ref = mod(v, other) | 
|  | attr = pointwise_name | 
|  | fused = torch.ops.mkldnn._linear_pointwise( | 
|  | v, other, mod.linear.weight, mod.linear.bias, attr | 
|  | ) | 
|  | self.assertEqual(ref, fused) | 
|  |  | 
    def test_conv_transpose_unary_fusion_ops(self):
        """Check mkldnn._convolution_transpose_pointwise with each unary
        post-op — optionally with the weight pre-reordered into mkldnn's
        blocked layout — against eager ConvTranspose2d + pointwise module."""
        class M(nn.Module):
            def __init__(self, unary_fn, dim, in_channels, out_channels, kernel_size, **kwargs):
                super().__init__()
                self.conv_transpose = CONV_TRANSPOSE_MODULES[dim](in_channels, out_channels, kernel_size, **kwargs)
                self.unary = unary_fn

            def forward(self, x):
                x = self.conv_transpose(x)
                x = self.unary(x)
                return x

        input_shapes = {2: (28, 28)}
        kernel_size = 3
        for pointwise_info in self._unary_list().values():
            for dim in [2]:
                channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
                # bias x dilation x groups x memory format x prepack_weight
                options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], [False, True])
                for bias, dilation, groups, memory_format, prepack_weight in options:
                    oC = 32 * groups
                    iC = 3 * groups
                    x_shape = (1, iC) + input_shapes[dim]
                    x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
                    mod = M(pointwise_info.pointwise_module, dim, iC, oC, kernel_size, dilation=dilation, groups=groups, bias=bias)
                    mod = mod.to(memory_format=memory_format).eval()
                    with torch.no_grad():
                        # Compute the eager reference BEFORE any weight repacking.
                        ref = mod(x)
                        attr = pointwise_info.attr
                        scalars = pointwise_info.scalars
                        algorithm = pointwise_info.algorithm

                        if prepack_weight:
                            # Reorder the weight ahead of time and swap the packed
                            # tensor in as the module's parameter; the fused op
                            # below must accept it transparently.
                            packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
                                mod.conv_transpose.weight,
                                mod.conv_transpose.padding,
                                mod.conv_transpose.output_padding,
                                mod.conv_transpose.stride,
                                mod.conv_transpose.dilation,
                                mod.conv_transpose.groups,
                                x.size())
                            mod.conv_transpose.weight = torch.nn.Parameter(
                                packed_weight,
                                requires_grad=mod.conv_transpose.weight.requires_grad,
                            )

                        fused = torch.ops.mkldnn._convolution_transpose_pointwise(
                            x,
                            mod.conv_transpose.weight,
                            mod.conv_transpose.bias,
                            mod.conv_transpose.padding,
                            mod.conv_transpose.output_padding,
                            mod.conv_transpose.stride,
                            mod.conv_transpose.dilation,
                            mod.conv_transpose.groups,
                            attr,
                            scalars,
                            algorithm)
                        self.assertEqual(ref, fused)
|  |  | 
# Standard PyTorch test entry point.
if __name__ == "__main__":
    run_tests()