from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import unittest

import torch
from torch.utils import mkldnn as mkldnn_utils
from common_utils import TestCase, run_tests
from torch.autograd.gradcheck import gradgradcheck, gradcheck


# Comment out the decorator below to find out which CI machines have the
# MKL-DNN build disabled.
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMkldnn(TestCase):
    def test_conversion(self):
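        # Round-trip dense -> MKL-DNN -> dense for contiguous and
        # non-contiguous (sliced) inputs: the reconstructed tensor must match
        # the source, metadata must be preserved, and opaque MKL-DNN tensors
        # must not expose a data pointer.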
        for cpu_tensor in [torch.randn((1, 2, 3, 4),
                                       dtype=torch.float, device=torch.device('cpu')),
                           torch.randn((1, 2, 3, 4, 5),
                                       dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
            cpu_tensor.requires_grad_()
            mkldnn_tensor = cpu_tensor.to_mkldnn()
            cpu_tensor_1 = mkldnn_tensor.to_dense()
            self.assertEqual(cpu_tensor, cpu_tensor_1)
            self.assertEqual(mkldnn_tensor.dtype, torch.float)
            self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
            self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
            self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
            self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
            self.assertRaisesRegex(RuntimeError,
                                   "Cannot access data pointer of Tensor that doesn't have storage",
                                   lambda: mkldnn_tensor.data_ptr() != 0)

    def test_unsupported(self):
        # unsupported dtypes, on CPU and (when available) on GPU
        for dtype in [torch.double, torch.half, torch.uint8, torch.int8,
                      torch.short, torch.int, torch.long]:
            with self.assertRaises(RuntimeError):
                torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn()
            if torch.cuda.is_available():
                with self.assertRaises(RuntimeError):
                    torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn()
        # the only supported dtype (float32), but on GPU
        if torch.cuda.is_available():
            with self.assertRaises(RuntimeError):
                torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn()
        # factory functions must reject the opaque MKL-DNN layout
        for creator in [torch.empty, torch.ones, torch.zeros, torch.randn, torch.rand]:
            with self.assertRaises(RuntimeError):
                creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)

    def test_autograd_to_mkldnn(self):
        # MKL-DNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def func(root):
            return root.to_mkldnn().to_dense()

        # Because MKL-DNN only supports float32, we need to loosen the
        # tolerances. These values are empirical and simply seem to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')
        self.assertWarnsRegex(lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')

    def test_autograd_from_mkldnn(self):
        # MKL-DNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        def func(root):
            return root.to_dense()

        # Because MKL-DNN only supports float32, we need to loosen the
        # tolerances. These values are empirical and simply seem to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')

    def test_detach(self):
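        # detach() must return a tensor that does not require grad while
        # leaving the source untouched; detach_() clears requires_grad on the
        # source itself.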
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        detach = root.detach()
        self.assertEqual((4, 5), detach.size())
        self.assertFalse(detach.requires_grad)
        self.assertTrue(root.requires_grad)

        detach_ = root.detach_()
        self.assertEqual((4, 5), detach_.size())
        self.assertFalse(detach_.requires_grad)
        self.assertFalse(root.requires_grad)

    def test_repr(self):
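        # The string representation should advertise the opaque MKL-DNN layout.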
        self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
                        dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))

    def test_conv2d(self):
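        # MKL-DNN convolution must match the reference dense convolution
        # across grouped/ungrouped and biased/bias-free configurations.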
        for groups in [1, 4]:
            N = torch.randint(3, 10, (1,)).item()
            C = torch.randint(1, 3, (1,)).item() * groups
            M = torch.randint(1, 3, (1,)).item() * groups
            x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
            for bias in [True, False]:
                conv2d = torch.nn.Conv2d(in_channels=C,
                                         out_channels=M,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1,
                                         bias=bias,
                                         groups=groups).float()
                mkldnn_conv2d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv2d))
                self.assertEqual(
                    conv2d(x),
                    mkldnn_conv2d(x.to_mkldnn()).to_dense())

    def test_relu(self):
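        # Out-of-place relu on an MKL-DNN tensor must match the dense result.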
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())

    def test_relu_(self):
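        # In-place relu_ must produce the same values on both layouts.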
        x1 = torch.randn((4, 5), dtype=torch.float32) * 10
        x2 = x1.clone().to_mkldnn()
        self.assertEqual(torch.relu_(x1), torch.relu_(x2).to_dense())

    def test_max_pool2d(self):
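        # Max pooling on MKL-DNN tensors must match the dense reference.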
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=2,
            padding=1)

        self.assertEqual(
            max_pool2d(x),
            max_pool2d(x.to_mkldnn()).to_dense())

    def test_avg_pool2d(self):
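        # Average pooling must agree with the dense reference for both
        # count_include_pad settings.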
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        for count_include_pad in [True, False]:
            avg_pool2d = torch.nn.AvgPool2d(
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)

            self.assertEqual(
                avg_pool2d(x),
                avg_pool2d(x.to_mkldnn()).to_dense())

    def test_adaptive_avg_pool2d(self):
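        # Adaptive average pooling must agree with the dense reference.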
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100

        adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d(7)

        self.assertEqual(
            adaptive_avg_pool2d(x),
            adaptive_avg_pool2d(x.to_mkldnn()).to_dense())

    def test_batch_norm2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10

        # TODO: support training
        for train in [False]:
            bn = torch.nn.BatchNorm2d(C).float().train(train)
            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
            self.assertEqual(
                bn(x),
                mkldnn_bn(x.to_mkldnn()).to_dense())

    def test_add(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        alpha = torch.randn(1, dtype=torch.float32).item()

        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        mx = x.to_mkldnn()
        my = y.to_mkldnn()

        # add
        self.assertEqual(
            x + y,
            (mx + my).to_dense())

        self.assertEqual(
            torch.add(x, y, alpha=alpha),
            torch.add(mx, my, alpha=alpha).to_dense())

        # add_
        x += y
        mx += my
        self.assertEqual(x, mx.to_dense())

        # add_out
        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.add(x, y, alpha=alpha, out=out)
        torch.add(mx, my, alpha=alpha, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

    def test_view(self):
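        # view() is unsupported on opaque MKL-DNN tensors; the error message
        # should steer users toward reshape() instead.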
        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        self.assertRaisesRegex(RuntimeError,
                               "Change to use reshape",
                               lambda: x.view(x.size(0), -1))

    def test_reshape(self):
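        # reshape() on an MKL-DNN tensor must match the dense reference.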
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        size = (x.size(0), -1)

        self.assertEqual(
            x.reshape(size),
            x.to_mkldnn().reshape(size).to_dense(),
        )

    def test_clone(self):
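        # clone() must produce an independent copy with identical values.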
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        self.assertEqual(
            x.clone(),
            x.to_mkldnn().clone().to_dense(),
        )

    def test_linear(self):
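        # MKL-DNN linear must match the dense reference, with and without bias.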
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10

        for bias in [True, False]:
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            self.assertEqual(
                linear(x),
                mkldnn_linear(x.to_mkldnn()).to_dense())


if __name__ == '__main__':
    run_tests()