from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import unittest

import torch
from torch.utils import mkldnn as mkldnn_utils
from common_utils import TestCase, run_tests
from torch.autograd.gradcheck import gradgradcheck, gradcheck


# Comment out the line below to find out which CI machines have MKL-DNN builds disabled.
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMkldnn(TestCase):
    def test_conversion(self):
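        # cover both a contiguous tensor and a non-contiguous slice of a 5-d tensor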
        for cpu_tensor in [torch.randn((1, 2, 3, 4),
                                       dtype=torch.float, device=torch.device('cpu')),
                           torch.randn((1, 2, 3, 4, 5),
                                       dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
            cpu_tensor.requires_grad_()
            mkldnn_tensor = cpu_tensor.to_mkldnn()
            cpu_tensor_1 = mkldnn_tensor.to_dense()
            self.assertEqual(cpu_tensor, cpu_tensor_1)
            self.assertEqual(mkldnn_tensor.dtype, torch.float)
            self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
            self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
            self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
            self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
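            # MKL-DNN tensors are opaque and have no storage, so data_ptr() must raise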
            self.assertRaisesRegex(RuntimeError,
                                   "Cannot access data pointer of Tensor that doesn't have storage",
                                   lambda: mkldnn_tensor.data_ptr() != 0)

    def test_unsupported(self):
        # unsupported dtypes, on CPU and, when CUDA is available, on GPU
        for dtype in [torch.double, torch.half, torch.uint8, torch.int8,
                      torch.short, torch.int, torch.long]:
            with self.assertRaises(RuntimeError):
                torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn()
            if torch.cuda.is_available():
                with self.assertRaises(RuntimeError):
                    torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn()
        # even the supported dtype (float32) fails on GPU: MKL-DNN tensors are CPU-only
        if torch.cuda.is_available():
            with self.assertRaises(RuntimeError):
                torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn()
        # factory functions cannot create MKL-DNN tensors via the layout argument
        for creator in [torch.empty, torch.ones, torch.zeros, torch.randn, torch.rand]:
            with self.assertRaises(RuntimeError):
                creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)

    def test_autograd_to_mkldnn(self):
        # MKL-DNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

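        # func round-trips through the MKL-DNN layout and back to dense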
        def func(root):
            return root.to_mkldnn().to_dense()

        # Because MKL-DNN only supports float32, we need to loosen the precision.
        # These tolerances are empirical values that happen to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')
        self.assertWarnsRegex(lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')

    def test_autograd_from_mkldnn(self):
        # MKL-DNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

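        # here only the conversion from MKL-DNN layout back to dense is differentiated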
        def func(root):
            return root.to_dense()

        # Because MKL-DNN only supports float32, we need to loosen the precision.
        # These tolerances are empirical values that happen to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')

    def test_detach(self):
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

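        # detach() returns a new tensor that no longer requires grad; root is unchanged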
        detach = root.detach()
        self.assertEqual((4, 5), detach.size())
        self.assertFalse(detach.requires_grad)
        self.assertTrue(root.requires_grad)

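        # detach_() detaches in place, so root itself stops requiring grad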
        detach_ = root.detach_()
        self.assertEqual((4, 5), detach_.size())
        self.assertFalse(detach_.requires_grad)
        self.assertFalse(root.requires_grad)

    def test_repr(self):
        self.assertTrue("layout=torch._mkldnn" in str(torch.randn((1, 2, 3, 4),
                        dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))

    def test_conv2d(self):
        for groups in [1, 4]:
            N = torch.randint(3, 10, (1,)).item()
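            # channel counts must be multiples of the group count for grouped convolution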
            C = torch.randint(1, 3, (1,)).item() * groups
            M = torch.randint(1, 3, (1,)).item() * groups
            x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
            for bias in [True, False]:
                conv2d = torch.nn.Conv2d(in_channels=C,
                                         out_channels=M,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1,
                                         bias=bias,
                                         groups=groups).float()
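                # convert a deep copy so the reference module keeps its dense weights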
                mkldnn_conv2d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv2d))
                self.assertEqual(
                    conv2d(x),
                    mkldnn_conv2d(x.to_mkldnn()).to_dense())

    def test_relu(self):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())

    def test_relu_(self):
        x1 = torch.randn((4, 5), dtype=torch.float32) * 10
        x2 = x1.clone().to_mkldnn()
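        # the in-place op on the dense tensor must match the one on its MKL-DNN clone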
        self.assertEqual(torch.relu_(x1), torch.relu_(x2).to_dense())

    def test_max_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=2,
            padding=1)

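        # pooling the MKL-DNN tensor should match the dense reference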
        self.assertEqual(
            max_pool2d(x),
            max_pool2d(x.to_mkldnn()).to_dense())

    def test_avg_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        for count_include_pad in [True, False]:
            avg_pool2d = torch.nn.AvgPool2d(
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)

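            # both padding modes should agree with the dense reference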
            self.assertEqual(
                avg_pool2d(x),
                avg_pool2d(x.to_mkldnn()).to_dense())


if __name__ == '__main__':
    run_tests()