| # Owner(s): ["module: nn"] |
| import itertools |
| import math |
| import unittest |
| import warnings |
| from itertools import product |
| |
| import torch |
| |
| import torch.autograd.forward_ad as fwAD |
| import torch.backends.cudnn as cudnn |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| from torch.testing import make_tensor |
| from torch.testing._internal.common_cuda import ( |
| TEST_CUDA, |
| TEST_CUDNN, |
| tf32_is_not_fp32, |
| tf32_on_and_off, |
| ) |
| from torch.testing._internal.common_device_type import ( |
| disablecuDNN, |
| disableMkldnn, |
| dtypes, |
| dtypesIfCUDA, |
| instantiate_device_type_tests, |
| largeTensorTest, |
| onlyCPU, |
| onlyCUDA, |
| onlyNativeDeviceTypes, |
| precisionOverride, |
| skipCPUIfNoMkldnn, |
| skipCUDAIfCudnnVersionLessThan, |
| skipCUDAIfMiopen, |
| skipCUDAIfNoCudnn, |
| skipCUDAIfNoMiopen, |
| skipCUDAIfNotMiopenSuggestNHWC, |
| skipCUDAIfRocm, |
| skipCUDAIfRocmVersionLessThan, |
| skipMeta, |
| ) |
| from torch.testing._internal.common_dtype import ( |
| floating_and_complex_types_and, |
| floating_types_and, |
| ) |
| from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase |
| from torch.testing._internal.common_utils import ( |
| download_file, |
| dtype2prec_DONTUSE, |
| gradcheck, |
| GRADCHECK_NONDET_TOL, |
| gradgradcheck, |
| instantiate_parametrized_tests, |
| parametrize as parametrize_test, |
| run_tests, |
| set_default_dtype, |
| skipIfNotMiopenSuggestNHWC, |
| skipIfRocmVersionLessThan, |
| subtest, |
| TEST_SCIPY, |
| TEST_WITH_ROCM, |
| ) |
| |
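| # True on ROCm, or on NVIDIA GPUs new enough to support TF32 (Ampere and later); |
| # used below to also include bfloat16 in the tested dtypes. |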
| AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() |
| |
| |
| if TEST_SCIPY: |
| import scipy.ndimage |
| import scipy.signal |
| |
| |
| class TestConvolutionNN(NNTestCase): |
| _do_cuda_memory_leak_check = True |
| _do_cuda_non_default_stream = True |
| |
| def test_conv_backcompat(self): |
| from torch.serialization import SourceChangeWarning |
| |
| # This file was generated by running the following on PyTorch 1.0.1 under Python 2: |
| # |
| # import torch |
| # from torch import nn |
| # m = nn.Conv2d(1, 1, 1) |
| # torch.save(m, 'legacy_conv2d.pt') |
| # |
| # NB: This Pickle also contains some Unicode data! |
| path = download_file("https://download.pytorch.org/test_data/legacy_conv2d.pt") |
| with warnings.catch_warnings(): |
| warnings.simplefilter("ignore", SourceChangeWarning) |
| m = torch.load(path, encoding="utf-8") |
| input = torch.randn((1, 1, 1, 1), dtype=torch.float) |
| self.assertEqual(m(input).size(), (1, 1, 1, 1)) |
| |
| def test_invalid_conv1d(self): |
| for dtype in [ |
| torch.half, |
| torch.bfloat16, |
| torch.float, |
| torch.double, |
| torch.cfloat, |
| torch.cdouble, |
| ]: |
| module = nn.Conv1d( |
| in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True |
| ).to(dtype) |
| input = torch.randn(1, 3, 4).to(dtype) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"Calculated padded input size per channel: \(4\). " |
| + r"Kernel size: \(10\). Kernel size can\'t be greater than actual input size", |
| ): |
| module(input) |
| |
| # Negative stride check |
| module = nn.Conv1d( |
| in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True |
| ).to(dtype) |
| input = torch.randn(1, 3, 4).to(dtype) |
| with self.assertRaisesRegex( |
| RuntimeError, "non-positive stride is not supported" |
| ): |
| module(input) |
| |
| def test_mismatch_shape_conv2d(self): |
| for dtype in (torch.float, torch.cfloat): |
| x = torch.randn(1, 10, 1, 28, 28, dtype=dtype) |
| w = torch.randn(6, 1, 5, 5, dtype=dtype) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"Expected 3D \(unbatched\) or 4D \(batched\) input to conv2d, but got " |
| + r"input of size: \[1, 10, 1, 28, 28\]", |
| ): |
| F.conv2d(x, w) |
| |
| def test_conv2d_discontiguous_weight(self): |
| for dtype in (torch.float, torch.cfloat): |
| # Test for https://github.com/pytorch/pytorch/issues/55781 |
| x = torch.ones(64, 16, 16, 16, dtype=dtype) |
| weight = ( |
| torch.arange(0, 1.0, 1 / 2.0**10) |
| .reshape(32, 16, 1, 2) |
| .to(dtype)[:, :, :, ::2] |
| ) |
| self.assertFalse(weight.is_contiguous()) |
| y = torch.nn.functional.conv2d(x, weight, None) |
| if torch.backends.mkldnn.is_available(): |
| # Disable MKLDNN explicitly, so that either NNPACK or THCNN will be used |
| with torch.backends.mkldnn.flags(enabled=False): |
| y_ = torch.nn.functional.conv2d(x, weight, None) |
| self.assertEqual(y, y_) |
| self.assertEqual(y.sum(), 4186112.0) |
| |
| def test_invalid_conv2d(self): |
| for dtype in [ |
| torch.half, |
| torch.bfloat16, |
| torch.float, |
| torch.double, |
| torch.cfloat, |
| torch.cdouble, |
| ]: |
| module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to( |
| dtype |
| ) |
| input = torch.empty(1, 1, 4, 4).to(dtype) |
| self.assertRaises(RuntimeError, lambda: module(input)) |
| |
| module = nn.Conv2d( |
| in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True |
| ) |
| input = torch.randn(1, 3, 1, 1) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"Calculated padded input size per channel: \(1 x 1\). " |
| + r"Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size", |
| ): |
| module(input) |
| |
| # Negative stride check |
| module = nn.Conv2d( |
| in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True |
| ).to(dtype) |
| input = torch.randn(1, 3, 4, 4).to(dtype) |
| with self.assertRaisesRegex( |
| RuntimeError, "non-positive stride is not supported" |
| ): |
| module(input) |
| |
| # Zero stride check |
| module = nn.Conv2d( |
| in_channels=3, out_channels=6, kernel_size=4, stride=0, bias=True |
| ).to(dtype) |
| input = torch.randn(1, 3, 4, 4).to(dtype) |
| with self.assertRaisesRegex( |
| RuntimeError, "non-positive stride is not supported" |
| ): |
| module(input) |
| |
| def test_invalid_conv3d(self): |
| for dtype in [ |
| torch.half, |
| torch.bfloat16, |
| torch.float, |
| torch.double, |
| torch.cfloat, |
| torch.cdouble, |
| ]: |
| module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to( |
| dtype |
| ) |
| input = torch.empty(1, 1, 4, 4, 4).to(dtype) |
| self.assertRaises(RuntimeError, lambda: module(input)) |
| |
| # Negative stride check |
| module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2) |
| input = torch.empty(1, 1, 4, 4, 4) |
| with self.assertRaisesRegex( |
| RuntimeError, "non-positive stride is not supported" |
| ): |
| module(input) |
| |
| def test_conv_invalid_groups(self): |
| with self.assertRaisesRegex(ValueError, "groups must be a positive integer"): |
| torch.nn.Conv1d(1, 1, kernel_size=3, dilation=2, stride=2, groups=0) |
| with self.assertRaisesRegex(ValueError, "groups must be a positive integer"): |
| torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-1) |
| with self.assertRaisesRegex(ValueError, "groups must be a positive integer"): |
| torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2, groups=-2) |
| |
| def test_Conv1d_module_same_padding(self): |
| # Compare module against functional: without strides/dilation, asymmetric padding |
| x = torch.rand(1, 1, 20) |
| module = nn.Conv1d( |
| in_channels=1, out_channels=1, kernel_size=10, padding="same" |
| ) |
| expect = F.conv1d(x, module.weight, module.bias, padding="same") |
| self.assertEqual(expect, module(x)) |
| |
| # Test dilation, symmetric padding |
| module = nn.Conv1d( |
| in_channels=1, out_channels=1, kernel_size=10, padding="same", dilation=2 |
| ) |
| expect = F.conv1d(x, module.weight, module.bias, padding="same", dilation=2) |
| self.assertEqual(expect, module(x)) |
| |
| # Test non-zero padding_mode, requiring explicit padding |
| module = nn.Conv1d( |
| in_channels=1, |
| out_channels=1, |
| kernel_size=10, |
| padding="same", |
| padding_mode="replicate", |
| ) |
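| # padding='same' with kernel_size=10 needs 9 total padding; PyTorch splits it |
| # as 4 on the left and 5 on the right, which the explicit F.pad below mirrors. |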
| x_padded = F.pad(x, [4, 5], mode="replicate") |
| expect = F.conv1d(x_padded, module.weight, module.bias, padding="valid") |
| self.assertEqual(expect, module(x)) |
| self.assertEqual(x.size(), expect.size()) |
| |
| # Test construction with an invalid padding string raises |
| with self.assertRaisesRegex(ValueError, "Invalid padding string"): |
| module = nn.Conv1d( |
| in_channels=3, out_channels=33, kernel_size=10, padding="foo" |
| ) |
| |
| # Test construction with same padding and strides raises |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv1d( |
| in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2 |
| ) |
| |
| def test_Conv2d_module_same_padding(self): |
| # Compare module against functional: |
| # without strides/dilation, both symmetric and asymmetric padding |
| x = torch.rand(1, 1, 9, 20) |
| module = nn.Conv2d( |
| in_channels=1, out_channels=1, kernel_size=(5, 10), padding="same" |
| ) |
| expect = F.conv2d(x, module.weight, module.bias, padding="same") |
| self.assertEqual(expect, module(x)) |
| |
| # with dilation, symmetric padding |
| module = nn.Conv2d( |
| in_channels=1, |
| out_channels=1, |
| kernel_size=(3, 4), |
| padding="same", |
| dilation=(1, 2), |
| ) |
| expect = F.conv2d( |
| x, module.weight, module.bias, padding="same", dilation=(1, 2) |
| ) |
| self.assertEqual(expect, module(x)) |
| |
| # Test non-zero padding_mode, requiring explicit padding |
| module = nn.Conv2d( |
| in_channels=1, |
| out_channels=1, |
| kernel_size=(3, 4), |
| padding="same", |
| padding_mode="reflect", |
| ) |
| x_padded = F.pad(x, [1, 2, 1, 1], mode="reflect") |
| expect = F.conv2d(x_padded, module.weight, module.bias, padding="valid") |
| self.assertEqual(expect, module(x)) |
| self.assertEqual(x.size(), expect.size()) |
| |
| # Test construction with an invalid padding string raises |
| with self.assertRaisesRegex(ValueError, "Invalid padding string"): |
| module = nn.Conv2d( |
| in_channels=3, out_channels=33, kernel_size=10, padding="foo" |
| ) |
| |
| # Test construction with same padding and strides raises |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv2d( |
| in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2 |
| ) |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv2d( |
| in_channels=3, |
| out_channels=33, |
| kernel_size=10, |
| padding="same", |
| stride=(1, 3), |
| ) |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv2d( |
| in_channels=3, |
| out_channels=33, |
| kernel_size=10, |
| padding="same", |
| stride=(4, 1), |
| ) |
| |
| def test_Conv3d_module_same_padding(self): |
| # Compare module against functional: |
| x = torch.rand(1, 1, 4, 4, 4) |
| # without dilation, both symmetric and asymmetric padding |
| module = nn.Conv3d( |
| in_channels=1, out_channels=1, kernel_size=(2, 3, 4), padding="same" |
| ) |
| expect = F.conv3d(x, module.weight, module.bias, padding="same") |
| self.assertEqual(expect, module(x)) |
| |
| # with dilation, both symmetric and asymmetric padding |
| module = nn.Conv3d( |
| in_channels=1, |
| out_channels=1, |
| kernel_size=(2, 3, 4), |
| padding="same", |
| dilation=(3, 2, 1), |
| ) |
| expect = F.conv3d( |
| x, module.weight, module.bias, padding="same", dilation=(3, 2, 1) |
| ) |
| self.assertEqual(expect, module(x)) |
| |
| # Test non-zero padding_mode, requiring explicit padding |
| module = nn.Conv3d( |
| in_channels=1, |
| out_channels=1, |
| kernel_size=(2, 3, 4), |
| padding="same", |
| padding_mode="circular", |
| ) |
| x_padded = F.pad(x, [1, 2, 1, 1, 0, 1], mode="circular") |
| expect = F.conv3d(x_padded, module.weight, module.bias, padding="valid") |
| self.assertEqual(expect, module(x)) |
| self.assertEqual(x.size(), expect.size()) |
| |
| # Test construction with an invalid padding string raises |
| with self.assertRaisesRegex(ValueError, "Invalid padding string"): |
| module = nn.Conv3d( |
| in_channels=3, out_channels=33, kernel_size=10, padding="foo" |
| ) |
| |
| # Test construction with same padding and strides raises |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv3d( |
| in_channels=3, out_channels=33, kernel_size=10, padding="same", stride=2 |
| ) |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv3d( |
| in_channels=3, |
| out_channels=33, |
| kernel_size=10, |
| padding="same", |
| stride=(1, 1, 3), |
| ) |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv3d( |
| in_channels=3, |
| out_channels=33, |
| kernel_size=10, |
| padding="same", |
| stride=(1, 4, 1), |
| ) |
| with self.assertRaisesRegex(ValueError, "padding='same'"): |
| module = nn.Conv3d( |
| in_channels=3, |
| out_channels=33, |
| kernel_size=10, |
| padding="same", |
| stride=(5, 1, 1), |
| ) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA not available") |
| def test_thnn_conv_strided_padded_dilated(self): |
| for convfn, dims, transposed in ( |
| (torch.nn.functional.conv2d, 2, False), |
| (torch.nn.functional.conv_transpose2d, 2, True), |
| (torch.nn.functional.conv3d, 3, False), |
| (torch.nn.functional.conv_transpose3d, 3, True), |
| ): |
| for stride, padding, dilation in ( |
| (2, 0, 1), |
| (1, 1, 1), |
| (2, 1, 1), |
| (1, 0, 2), |
| ): |
| kwargs = {"stride": stride, "padding": padding, "dilation": dilation} |
| inp_shape = (1, 2) + dims * (4,) |
| weight_shape = (2, 2) + dims * (1,) |
| inputs = torch.randn( |
| inp_shape, dtype=torch.double, device="cuda", requires_grad=True |
| ) |
| weight = torch.randn( |
| weight_shape, dtype=torch.double, device="cuda", requires_grad=True |
| ) |
| bias = torch.randn( |
| 2, dtype=torch.double, device="cuda", requires_grad=True |
| ) |
| with torch.backends.cudnn.flags(enabled=False): |
| res = convfn(inputs, weight, bias, **kwargs) |
| res_cpu = convfn(inputs.cpu(), weight.cpu(), bias.cpu(), **kwargs) |
| self.assertEqual(res, res_cpu) |
| with torch.backends.cudnn.flags(enabled=False): |
| torch.autograd.gradcheck( |
| lambda x, w, b: convfn(x, w, b, **kwargs), |
| (inputs, weight, bias), |
| ) |
| torch.autograd.gradcheck( |
| lambda x, w, b: convfn(x, w, b, **kwargs), |
| (inputs.cpu(), weight.cpu(), bias.cpu()), |
| ) |
| |
| def test_Conv2d_inconsistent_types(self): |
| inputs = torch.randn(4, 1, 7, 7, dtype=torch.float) |
| weights = torch.randn(1, 1, 3, 3, dtype=torch.double) |
| # inconsistent types should raise an exception |
| self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) |
| # but it should work with the same type |
| nn.functional.conv2d(inputs.float(), weights.float()) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA not available") |
| def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self): |
| inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") |
| weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") |
| bias = torch.randn(1, dtype=torch.double, device="cuda") |
| |
| with torch.backends.cudnn.flags(enabled=False): |
| # inconsistent types should raise an exception |
| self.assertRaises( |
| RuntimeError, lambda: nn.functional.conv2d(inputs, weights) |
| ) |
| self.assertRaises( |
| RuntimeError, |
| lambda: nn.functional.conv2d(inputs, weights.float(), bias), |
| ) |
| |
| # but it should work with the same type |
| nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) |
| |
| def test_Conv2d_1x1(self): |
| in_channels = 2 |
| out_channels = 2 |
| mod = torch.nn.Conv2d(2, 2, 1, bias=False).to(dtype=torch.double) |
| input = torch.randn( |
| 1, in_channels, 5, 5, requires_grad=True, dtype=torch.double |
| ) |
| for enabled in (False, True): |
| with torch.backends.mkldnn.flags(enabled=enabled): |
| gradcheck(F.conv2d, (input, mod.weight)) |
| |
| def test_Conv2d_OneDNN(self): |
| def run_once(group_val=24, dilation=1): |
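| # Depthwise convolution (groups == in_channels == out_channels) with all-ones |
| # input and weights; returns the weight gradient after backpropagating a |
| # gradient of ones, so the oneDNN and non-oneDNN paths can be compared. |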
| ifm = torch.ones([1, group_val, 6, 6], dtype=torch.float32) |
| weights = torch.ones([group_val, 1, 3, 3], dtype=torch.float32) |
| op = torch.nn.Conv2d( |
| in_channels=group_val, |
| out_channels=group_val, |
| kernel_size=[3, 3], |
| stride=[2, 2], |
| padding=[1, 1], |
| dilation=[dilation, dilation], |
| groups=group_val, |
| bias=False, |
| padding_mode="zeros", |
| ) |
| |
| op.weight.data = weights |
| res = op(ifm) |
| grad_in = torch.ones(res.shape, dtype=torch.float32) |
| res.backward(grad_in) |
| return op.weight.grad |
| |
| for group_val in (24, 48, 23, 25): |
| for dilation in (1, 2): |
| with torch.backends.mkldnn.flags(enabled=False): |
| without_onednn = run_once(group_val, dilation) |
| |
| with torch.backends.mkldnn.flags(enabled=True): |
| with_onednn = run_once(group_val, dilation) |
| |
| self.assertEqual(without_onednn, with_onednn) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA not available") |
| @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") |
| def test_cudnn_non_contiguous(self): |
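| # Permuting twice yields a tensor with the original shape but transposed, |
| # non-contiguous strides; the cuDNN conv path should accept it without error. |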
| x = torch.randn(192, 16, 50).cuda() |
| x = x.permute(0, 2, 1).contiguous().permute(0, 2, 1) |
| m = torch.nn.Conv1d( |
| in_channels=16, out_channels=32, kernel_size=2, bias=True |
| ).cuda() |
| result = m(x) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA not available") |
| @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") |
| def test_cudnn_not_mutate_stride(self): |
| weight = torch.randn(64, 64, 1, 1) |
| x = torch.randn(2, 64, 10, 10).to(memory_format=torch.channels_last) |
| weight_stride = weight.stride() |
| |
| def conv(x, weight): |
| return torch.convolution( |
| x, |
| weight, |
| stride=(1, 1), |
| padding=(0, 0), |
| dilation=(1, 1), |
| transposed=False, |
| output_padding=(0, 0), |
| groups=1, |
| bias=None, |
| ) |
| |
| # should have run in nhwc without mutating input strides |
| out_nhwc = conv(x, weight) |
| self.assertEqual(weight.stride(), weight_stride) |
| self.assertTrue(out_nhwc.is_contiguous(memory_format=torch.channels_last)) |
| |
| x = x.contiguous(memory_format=torch.contiguous_format) |
| out_c = conv(x, weight) |
| self.assertTrue(out_c.is_contiguous(memory_format=torch.contiguous_format)) |
| self.assertEqual(out_c, out_nhwc) |
| self.assertEqual(weight.stride(), weight_stride) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA not available") |
| @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") |
| def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self): |
| inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") |
| weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") |
| bias = torch.randn(1, dtype=torch.double, device="cuda") |
| |
| with torch.backends.cudnn.flags(enabled=True): |
| # inconsistent types should raise an exception |
| self.assertRaises( |
| RuntimeError, lambda: nn.functional.conv2d(inputs, weights) |
| ) |
| self.assertRaises( |
| RuntimeError, |
| lambda: nn.functional.conv2d(inputs, weights.float(), bias), |
| ) |
| |
| # but it should work with the same type |
| nn.functional.conv2d(inputs.float(), weights.float(), bias.float()) |
| |
| def test_Conv2d_missing_argument(self): |
| c = nn.Conv2d(3, 3, 3) |
| self.assertRaises(TypeError, lambda: c(None)) |
| |
| def test_Conv2d_backward_twice(self): |
| input = torch.randn(2, 3, 5, 5) |
| c = nn.Conv2d(3, 3, 3) |
| o1 = c(input) |
| o1.sum().backward() |
| self.assertRaisesRegex( |
| RuntimeError, "Specify retain_graph=True", lambda: o1.sum().backward() |
| ) |
| |
| def test_conv_modules_raise_error_on_incorrect_input_size(self): |
| for dtype in [torch.half, torch.bfloat16, torch.double, torch.float]: |
| modules = [ |
| nn.Conv1d(3, 8, 3).to(dtype), |
| nn.ConvTranspose1d(3, 8, 3).to(dtype), |
| nn.Conv2d(3, 8, 3).to(dtype), |
| nn.ConvTranspose2d(3, 8, 3).to(dtype), |
| nn.Conv3d(3, 8, 3).to(dtype), |
| nn.ConvTranspose3d(3, 8, 3).to(dtype), |
| ] |
| |
| invalid_input_dims = [(1, 4), (1, 4), (2, 5), (2, 5), (3, 6), (3, 6)] |
| |
| for invalid_dims, module in zip(invalid_input_dims, modules): |
| for dims in invalid_dims: |
| input = torch.empty(torch.Size((3,) * dims)) |
| self.assertRaises(RuntimeError, lambda: module(input)) |
| |
| def test_conv_shapecheck(self): |
| def test(should_raise, module, input_size, dtype): |
| input = torch.empty(3, *input_size).to(dtype) |
| if should_raise: |
| self.assertRaises(RuntimeError, lambda: module(input)) |
| else: |
| # just run it to ensure no exception is raised. |
| module(input) |
| |
| for dtype in [ |
| torch.half, |
| torch.bfloat16, |
| torch.float, |
| torch.double, |
| torch.cfloat, |
| torch.cdouble, |
| ]: |
| # Conv1d |
| test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype) |
| test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype) |
| test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype) |
| test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype) |
| test( |
| False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype |
| ) |
| |
| # Conv2d |
| test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype) |
| test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype) |
| test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype) |
| |
| # Conv3D |
| test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype) |
| test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype) |
| test( |
| False, |
| nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype), |
| (1, 2, 2, 2), |
| dtype, |
| ) |
| |
| def test_ConvTranspose2d_output_size(self): |
| m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2) |
| i = torch.randn(2, 3, 6, 6) |
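| # With input 6x6, kernel_size=3, stride=3, padding=0, the smallest producible |
| # output is (6 - 1) * 3 + 3 = 18; output_size may request anything in |
| # [18, 18 + stride - 1] = [18, 20], and sizes outside that range must raise. |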
| for h in range(15, 22): |
| for w in range(15, 22): |
| if 18 <= h <= 20 and 18 <= w <= 20: |
| output = m(i, output_size=(h, w)) |
| self.assertEqual(output.size()[2:], (h, w)) |
| else: |
| self.assertRaises(ValueError, lambda: m(i, (h, w))) |
| |
| def test_ConvTranspose2d_output_size_downsample_upsample(self): |
| b, c, hid_c = 2, 3, 2 |
| for h in range(13, 24): |
| for w in range(13, 17): |
| for k in range(2, 5): |
| for d in range(1, 5): |
| for s in range(1, 4): |
| for p in range(3): |
| conv = nn.Conv2d( |
| in_channels=c, |
| out_channels=hid_c, |
| kernel_size=k, |
| stride=s, |
| padding=p, |
| dilation=d, |
| ) |
| |
| t_conv = nn.ConvTranspose2d( |
| in_channels=hid_c, |
| out_channels=c, |
| kernel_size=k, |
| stride=s, |
| padding=p, |
| dilation=d, |
| ) |
| |
| i = torch.randn(b, c, h, w) |
| |
| out = t_conv(conv(i), output_size=i.shape) |
| |
| self.assertEqual(out.size()[2:], i.size()[2:]) |
| |
| def test_ConvTranspose3d_correct_output_size(self): |
| # Check that ConvTranspose3d can take a 5d output_size. |
| m = nn.ConvTranspose3d(2, 2, 2) |
| i = torch.rand(1, 2, 1, 1, 1) |
| out = m(i, output_size=(1, 2, 2, 2, 2)) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA not available") |
| def test_ConvTranspose2d_half_cublas_gemm(self): |
| with torch.backends.cudnn.flags(enabled=False): |
| inputs = torch.randn(1, 1, 16, 16, device="cuda", dtype=torch.half) |
| deconv = ( |
| nn.ConvTranspose2d(1, 1, 3, stride=2, padding=1, output_padding=1) |
| .cuda() |
| .half() |
| ) |
| output = deconv(inputs) |
| output.mean().backward() |
| |
| # For https://github.com/pytorch/pytorch/pull/1273 |
| # Almost identical to the above `test_Conv2d_naive_groups` |
| @torch.backends.cudnn.flags(enabled=True, benchmark=False) |
| @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") |
| def test_Conv2d_groups_nobias(self): |
| dev_dtypes = [("cpu", torch.float)] |
| if TEST_CUDA: |
| dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] |
| if AMPERE_OR_ROCM: |
| dev_dtypes += [("cuda", torch.bfloat16)] |
| for device, dtype in dev_dtypes: |
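| # Split the grouped convolution into two independent single-group convolutions |
| # and check that outputs and gradients match after concatenating along the |
| # channel dimension. |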
| m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype) |
| i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) |
| output = m(i) |
| grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) |
| output.backward(grad_output) |
| |
| m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) |
| m1.weight.data.copy_(m.weight.data[:2]) |
| i1 = i.data[:, :2].contiguous().requires_grad_(True) |
| output1 = m1(i1) |
| output1.backward(grad_output[:, :2].contiguous()) |
| |
| m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) |
| m2.weight.data.copy_(m.weight.data[2:]) |
| i2 = i.data[:, 2:].contiguous().requires_grad_(True) |
| output2 = m2(i2) |
| output2.backward(grad_output[:, 2:].contiguous()) |
| |
| self.assertEqual(output, torch.cat([output1, output2], 1)) |
| self.assertEqual( |
| i.grad.data, |
| torch.cat([i1.grad.data, i2.grad.data], 1), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.weight.grad.data, |
| torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), |
| atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| |
| # Almost identical to the above `test_Conv2d_naive_groups` |
| # Covers the special case where groups > 1, input channels / groups < 16, and output channels is a multiple of 16 |
| # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686 |
| # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 |
| @torch.backends.cudnn.flags(enabled=True, benchmark=False) |
| @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") |
| def test_Conv2d_groups_nobias_v2(self): |
| torch.manual_seed(123) |
| dev_dtypes = [("cpu", torch.float)] |
| if TEST_CUDA: |
| dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)] |
| if AMPERE_OR_ROCM: |
| dev_dtypes += [("cuda", torch.bfloat16)] |
| for device, dtype in dev_dtypes: |
| m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype) |
| i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) |
| output = m(i) |
| grad_output = torch.randn(2, 16, 4, 4, device=device, dtype=dtype) |
| output.backward(grad_output) |
| |
| m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) |
| m1.weight.data.copy_(m.weight.data[:8]) |
| i1 = i.data[:, :2].contiguous().requires_grad_(True) |
| output1 = m1(i1) |
| output1.backward(grad_output[:, :8].contiguous()) |
| |
| m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) |
| m2.weight.data.copy_(m.weight.data[8:]) |
| i2 = i.data[:, 2:].contiguous().requires_grad_(True) |
| output2 = m2(i2) |
| output2.backward(grad_output[:, 8:].contiguous()) |
| |
| self.assertEqual(output, torch.cat([output1, output2], 1)) |
| self.assertEqual( |
| i.grad.data, |
| torch.cat([i1.grad.data, i2.grad.data], 1), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.weight.grad.data, |
| torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), |
| atol=1e-1 if dtype == torch.half else dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| |
| # CPU-only test for group conv3d fast implementation using bmm |
| # See: https://github.com/pytorch/pytorch/pull/36355 |
| def test_Conv3d_groups_nobias(self): |
| torch.manual_seed(123) |
| m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=False).to("cpu", torch.float) |
| i = torch.randn( |
| 2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True |
| ) |
| output = m(i) |
| grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float) |
| output.backward(grad_output) |
| |
| m1 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float) |
| m1.weight.data.copy_(m.weight.data[:8]) |
| i1 = i.data[:, :2].contiguous().requires_grad_(True) |
| output1 = m1(i1) |
| output1.backward(grad_output[:, :8].contiguous()) |
| |
| m2 = nn.Conv3d(2, 8, kernel_size=3, bias=False).to("cpu", torch.float) |
| m2.weight.data.copy_(m.weight.data[8:]) |
| i2 = i.data[:, 2:].contiguous().requires_grad_(True) |
| output2 = m2(i2) |
| output2.backward(grad_output[:, 8:].contiguous()) |
| |
| self.assertEqual(output, torch.cat([output1, output2], 1)) |
| self.assertEqual( |
| i.grad.data, |
| torch.cat([i1.grad.data, i2.grad.data], 1), |
| atol=dtype2prec_DONTUSE[torch.float], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.weight.grad.data, |
| torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), |
| atol=dtype2prec_DONTUSE[torch.float], |
| rtol=dtype2prec_DONTUSE[torch.float], |
| ) |
| |
| def test_Conv3d_groups_wbias(self): |
| torch.manual_seed(123) |
| m = nn.Conv3d(4, 16, kernel_size=3, groups=2, bias=True).to("cpu", torch.float) |
| i = torch.randn( |
| 2, 4, 6, 6, 6, device="cpu", dtype=torch.float, requires_grad=True |
| ) |
| output = m(i) |
| grad_output = torch.randn(2, 16, 4, 4, 4, device="cpu", dtype=torch.float) |
| output.backward(grad_output) |
| |
| m1 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float) |
| m1.weight.data.copy_(m.weight.data[:8]) |
| m1.bias.data.copy_(m.bias.data[:8]) |
| i1 = i.data[:, :2].contiguous().requires_grad_(True) |
| output1 = m1(i1) |
| output1.backward(grad_output[:, :8].contiguous()) |
| |
| m2 = nn.Conv3d(2, 8, kernel_size=3, bias=True).to("cpu", torch.float) |
| m2.weight.data.copy_(m.weight.data[8:]) |
| m2.bias.data.copy_(m.bias.data[8:]) |
| i2 = i.data[:, 2:].contiguous().requires_grad_(True) |
| output2 = m2(i2) |
| output2.backward(grad_output[:, 8:].contiguous()) |
| |
| self.assertEqual(output, torch.cat([output1, output2], 1)) |
| self.assertEqual( |
| i.grad.data, |
| torch.cat([i1.grad.data, i2.grad.data], 1), |
| atol=dtype2prec_DONTUSE[torch.float], |
| rtol=dtype2prec_DONTUSE[torch.float], |
| ) |
| self.assertEqual( |
| m.weight.grad.data, |
| torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), |
| atol=dtype2prec_DONTUSE[torch.float], |
| rtol=dtype2prec_DONTUSE[torch.float], |
| ) |
| self.assertEqual( |
| m.bias.grad.data, |
| torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), |
| atol=dtype2prec_DONTUSE[torch.float], |
| rtol=dtype2prec_DONTUSE[torch.float], |
| ) |
| |
| def test_conv_tbc(self): |
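| # F.conv_tbc takes input laid out as (time, batch, channels) and a weight of |
| # shape (kernel_width, in_channels, out_channels); the last argument (3) is |
| # the amount of temporal padding. |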
| with set_default_dtype(torch.double): |
| inp = torch.randn(9, 4, 5, requires_grad=True) |
| weight = torch.randn(3, 5, 6, requires_grad=True) |
| bias = torch.randn(6, requires_grad=True) |
| |
| gradcheck( |
| lambda i, w, b, pad: F.conv_tbc(i, w, b, pad), (inp, weight, bias, 3) |
| ) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") |
| @unittest.skipIf(not TEST_CUDNN, "needs cudnn") |
| @skipIfRocmVersionLessThan((4, 3)) |
| @skipIfNotMiopenSuggestNHWC |
| def test_grouped_conv_cudnn_nhwc_support(self): |
| # in order to catch holes in grouped convolution nhwc support for earlier cudnn versions |
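| # torch.convolution positional arguments below are: (input, weight, bias, |
| # stride, padding, dilation, transposed, output_padding, groups). |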
| input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to( |
| memory_format=torch.channels_last |
| ) |
| weight = torch.randn((8, 4, 3, 3), dtype=torch.float16, device="cuda").to( |
| memory_format=torch.channels_last |
| ) |
| out = torch.convolution( |
| input, weight, None, (1, 1), (1, 1), (1, 1), False, (0, 0), 4 |
| ) |
| input = torch.randn((16, 8, 8, 8), dtype=torch.float16, device="cuda").to( |
| memory_format=torch.channels_last |
| ) |
| out_transpose = torch.convolution( |
| input, weight, None, (1, 1), (1, 1), (1, 1), True, (0, 0), 4 |
| ) |
| |
| @unittest.expectedFailure |
| @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") |
| @unittest.skipIf(not TEST_CUDNN, "needs cudnn") |
| def test_conv_cudnn_memory_layout_dominance(self): |
| # The desired behavior here is for the memory layout of conv.weight to |
| # dominate the layout of the output, which is not the current behavior. |
| # We'll fix this in follow-up PRs and then remove the `expectedFailure` tag. |
| input = torch.randint( |
| 1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True |
| ) |
| conv = nn.Conv2d(8, 4, 3).cuda().float() |
| |
| out = conv(input) |
| self.assertTrue(out.is_contiguous()) |
| |
| input = input.contiguous(memory_format=torch.channels_last) |
| out = conv(input) |
| self.assertTrue(out.is_contiguous()) |
| |
| conv.weight.data = conv.weight.contiguous(memory_format=torch.channels_last) |
| out = conv(input) |
| self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) |
| |
| input = input.contiguous() |
| out = conv(input) |
| self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) |
| |
| @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") |
| def test_cudnn_noncontiguous_weight(self): |
| # Noncontiguous weights must be made contiguous() before being |
| # passed to cuDNN |
| input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3) |
| weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2) |
| weights2 = ( |
| torch.tensor([1], dtype=torch.double, device="cuda") |
| .expand(1, 1, 2) |
| .contiguous() |
| ) |
| self.assertEqual( |
| F.conv1d(input, weights1, bias=None, stride=2, dilation=2), |
| F.conv1d(input, weights2, bias=None, stride=2, dilation=2), |
| ) |
| |
| def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient="input"): |
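| # Helper: checks torch.nn.grad.convNd_input / convNd_weight (func_backward) |
| # against gradients computed by autograd through func_forward (F.convNd) |
| # for a sweep of kernel sizes, strides, paddings, and dilations. |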
| for kern, inp_size in [(3, 6), (3, 7), (4, 9)]: |
| for batch, stride, padding, chan_in, chan_out, dilation in product( |
| [1, 2], [1, 2], [0, 1, 2], [2], [3], [1] |
| ): |
| for has_bias in [True, False]: |
| input_shape = [batch, chan_in] |
| weight_shape = [chan_out, chan_in] |
| for _ in range(dim): |
| input_shape.append(inp_size) |
| weight_shape.append(kern) |
| |
| input = torch.randn(input_shape, requires_grad=True) |
| weight = torch.randn(weight_shape, requires_grad=True) |
| if has_bias: |
| bias = torch.randn([chan_out], requires_grad=True) |
| output = func_forward( |
| input, |
| weight, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| bias=bias, |
| ) |
| |
| gradient_o = torch.randn(output.shape) |
| gradient_w = torch.autograd.grad( |
| output, input if (gradient == "input") else weight, gradient_o |
| ) |
| |
| self.assertEqual( |
| gradient_w[0], |
| func_backward( |
| input_shape if (gradient == "input") else input, |
| weight_shape if (gradient == "weight") else weight, |
| gradient_o, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| ), |
| ) |
| |
| def test_grad_conv1d_input(self): |
| self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, "input") |
| |
| def test_grad_conv1d_weight(self): |
| self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, "weight") |
| |
| def test_grad_conv2d_input(self): |
| self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, "input") |
| |
| def test_grad_conv2d_weight(self): |
| self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, "weight") |
| |
| def test_grad_conv3d_input(self): |
| self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, "input") |
| |
| def test_grad_conv3d_weight(self): |
| self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, "weight") |
| |
| @unittest.skipIf(not torch._nnpack_available(), "NNPACK unavailable") |
| def test_nnpack_conv(self): |
| for kern, inp_size in [(3, 6), (3, 7), (4, 9)]: |
| for batch, stride, padding, chan_in, chan_out in product( |
| [1, 2, 3, 4], [1, 2], [0, 1, 2], [2], [3] |
| ): |
| for has_bias in [True, False]: |
| input_shape = [batch, chan_in] |
| weight_shape = [chan_out, chan_in] |
| for _ in range(2): |
| input_shape.append(inp_size) |
| weight_shape.append(kern) |
| |
| input = torch.randn( |
| input_shape, requires_grad=True, dtype=torch.float |
| ) |
| weight = torch.randn( |
| weight_shape, requires_grad=True, dtype=torch.float |
| ) |
| if has_bias: |
| bias = torch.randn( |
| [chan_out], requires_grad=True, dtype=torch.float |
| ) |
| output = torch._nnpack_spatial_convolution( |
| input, weight, stride=stride, padding=padding, bias=bias |
| ) |
| output_expected = torch.nn.functional.conv2d( |
| input, weight, stride=stride, padding=padding, bias=bias |
| ) |
| self.assertEqual(output, output_expected, atol=3e-4, rtol=0) |
| |
| gradient_o = torch.randn(output.shape, dtype=torch.float) |
| |
| grads = torch.autograd.grad(output, [input, weight], gradient_o) |
| grads_expected = torch.autograd.grad( |
| output_expected, [input, weight], gradient_o |
| ) |
| for gr, gr_expected in zip(grads, grads_expected): |
| self.assertEqual(gr, gr_expected, atol=3e-4, rtol=0) |
| |
| def test_conv_padding_mode(self): |
| with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): |
| nn.Conv2d(3, 3, 3, padding_mode="xyz") |
| |
| with self.assertRaisesRegex(ValueError, "padding_mode must be one of"): |
| nn.Conv2d(3, 3, 3, padding_mode=3) |
| |
| with self.assertRaisesRegex(ValueError, 'Only "zeros" '): |
| nn.ConvTranspose2d(3, 3, 3, padding_mode="reflect") |
| |
| def test_functional_grad_conv(self): |
| # Conv 1D |
| input = torch.randn(1, 1, 5, requires_grad=True) |
| weight = torch.randn(1, 1, 3, requires_grad=True) |
| output = F.conv1d(input, weight, dilation=2) |
| grad_output = torch.randn(output.shape) |
| |
| grad_input_autograd, grad_weight_autograd = torch.autograd.grad( |
| output, (input, weight), grad_output |
| ) |
| |
| grad_input_functional = torch.nn.grad.conv1d_input( |
| input.shape, weight, grad_output, dilation=2 |
| ) |
| self.assertEqual(grad_input_functional, grad_input_autograd) |
| |
| grad_weight_functional = torch.nn.grad.conv1d_weight( |
| input, weight.shape, grad_output, dilation=2 |
| ) |
| self.assertEqual(grad_weight_functional, grad_weight_autograd) |
| |
| # Conv 2D |
| input = torch.randn(1, 1, 5, 5, requires_grad=True) |
| weight = torch.randn(1, 1, 3, 3, requires_grad=True) |
| output = F.conv2d(input, weight, dilation=2) |
| grad_output = torch.randn(output.shape) |
| |
| (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad( |
| output, (input, weight), grad_output |
| ) |
| |
| grad_input_functional = torch.nn.grad.conv2d_input( |
| input.shape, weight, grad_output, dilation=2 |
| ) |
| self.assertEqual(grad_input_functional, grad_input_autograd) |
| |
| grad_weight_functional = torch.nn.grad.conv2d_weight( |
| input, weight.shape, grad_output, dilation=2 |
| ) |
| self.assertEqual(grad_weight_functional, grad_weight_autograd) |
| |
| # Conv 3D |
| input = torch.randn(1, 1, 5, 5, 5, requires_grad=True) |
| weight = torch.randn(1, 1, 3, 3, 3, requires_grad=True) |
| output = F.conv3d(input, weight, dilation=2) |
| grad_output = torch.randn(output.shape) |
| |
| (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad( |
| output, (input, weight), grad_output |
| ) |
| |
| grad_input_functional = torch.nn.grad.conv3d_input( |
| input.shape, weight, grad_output, dilation=2 |
| ) |
| self.assertEqual(grad_input_functional, grad_input_autograd) |
| |
| grad_weight_functional = torch.nn.grad.conv3d_weight( |
| input, weight.shape, grad_output, dilation=2 |
| ) |
| self.assertEqual(grad_weight_functional, grad_weight_autograd) |
| |
| def test_functional_grad_conv2d(self): |
| BATCH_SIZE = 4 |
| IN_CH = 8 |
| OUT_CH = 16 |
| SPATIAL = 32 |
| |
| def _test_conv2d(stride, kernel_size, groups, dilation): |
| padding = kernel_size // 2 |
| |
| input = ( |
| torch.empty(BATCH_SIZE, IN_CH, SPATIAL, SPATIAL) |
| .uniform_(-8.0, 8.0) |
| .requires_grad_(True) |
| ) |
| |
| weight = ( |
| torch.empty(OUT_CH, IN_CH // groups, kernel_size, kernel_size) |
| .uniform_(-4.0, 4.0) |
| .requires_grad_(True) |
| ) |
| |
| output = F.conv2d( |
| input, |
| weight, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| groups=groups, |
| ) |
| |
| grad_output = torch.randn(output.shape) |
| |
| (grad_input_autograd, grad_weight_autograd) = torch.autograd.grad( |
| output, (input, weight), grad_output |
| ) |
| |
| grad_input_functional = torch.nn.grad.conv2d_input( |
| input.shape, |
| weight, |
| grad_output, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| groups=groups, |
| ) |
| self.assertEqual(grad_input_functional, grad_input_autograd) |
| |
| grad_weight_functional = torch.nn.grad.conv2d_weight( |
| input, |
| weight.shape, |
| grad_output, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| groups=groups, |
| ) |
| self.assertEqual(grad_weight_functional, grad_weight_autograd) |
| |
| strides = [1, 2] |
| kernel_sizes = [1, 3, 5] |
| groups = [1, 2, 4] |
| dilates = [1, 2] |
| |
| for s, k, g, d in product(strides, kernel_sizes, groups, dilates): |
| _test_conv2d(s, k, g, d) |
| |
| def test_permute_conv2d_issue_120211(self): |
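| # Regression test for https://github.com/pytorch/pytorch/issues/120211: |
| # depthwise conv2d (groups == number of channels) over a permuted, |
| # non-contiguous input should not fail for any of these kernel widths. |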
| def reproducer(radius: int): |
| image = torch.rand(1, 1024, 1024, 3) |
| image = image.permute(0, 3, 1, 2) |
| kernel_x = torch.zeros([3, 1, 1, radius * 2 + 1], device=image.device) |
| image = torch.nn.functional.conv2d(image, kernel_x, groups=image.shape[-3]) |
| |
| for i in range(0, 128): |
| # This should not fail |
| reproducer(radius=i) |
| |
| def test_conv3d_issue_120406(self): |
| # This should not fail |
| F.conv3d(torch.ones(2, 3, 8, 9, 26), torch.ones(3, 1, 1, 1, 17), groups=3) |
| |
| def test_conv1d_issue_120547(self): |
| weight = torch.ones([16, 1, 32]) |
| bias = torch.ones([16]) |
| stride, padding, dilation, groups = (1, 16, 1, 16) |
| input = torch.rand((1, 1, 16)) |
| input = input.transpose(1, 2) |
| # This should not fail |
| F.conv1d(input, weight, bias, stride, padding, dilation, groups) |
| |
| |
| class TestConvolutionNNDeviceType(NNTestCase): |
| def run_conv_double_back_test( |
| self, |
| kern, |
| stride, |
| padding, |
| chan_in, |
| chan_out, |
| batch_size, |
| inp_size, |
| dilation, |
| no_weight, |
| groups=1, |
| use_cuda=False, |
| use_bias=True, |
| dtype=torch.double, |
| ): |
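| # Helper: builds random input/weight/bias for F.conv2d and runs gradgradcheck |
| # (cuDNN is disabled in the forward to avoid finite-difference imprecision). |
| # For float32 it only checks that the first-order gradient is itself |
| # differentiable, i.e. double backward is wired up (see issue #15353). |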
| if use_cuda: |
| device = torch.device("cuda") |
| else: |
| device = torch.device("cpu") |
| |
| x = torch.randn( |
| batch_size, |
| chan_in, |
| inp_size, |
| inp_size, |
| device=device, |
| dtype=dtype, |
| requires_grad=True, |
| ) |
| weight = torch.randn( |
| chan_out, |
| chan_in // groups, |
| kern, |
| kern, |
| device=device, |
| dtype=dtype, |
| requires_grad=not no_weight, |
| ) |
| if use_bias: |
| bias = torch.randn(chan_out, device=device, dtype=dtype, requires_grad=True) |
| else: |
| bias = None |
| |
| def func(*inputs): |
| if use_bias: |
| lx, lweight, lbias = inputs |
| else: |
| lx, lweight = inputs |
| lbias = None |
| # We disable cudnn during forward to avoid finite difference imprecision issues |
| with cudnn.flags(enabled=False): |
| out = F.conv2d(lx, lweight, lbias, stride, padding, dilation, groups) |
| return out |
| |
| if use_bias: |
| inputs = x, weight, bias |
| else: |
| inputs = x, weight |
| |
| dummy_out = func(*inputs) |
| grad_y = torch.randn_like( |
| dummy_out, device=device, dtype=dtype, requires_grad=True |
| ) |
| |
| # Issue #15353: test mkldnn double backward, don't run gradgradcheck due |
| # to imprecision issues |
| if dtype == torch.float: |
| (g,) = torch.autograd.grad(dummy_out.sum(), x, create_graph=True) |
| return g.requires_grad |
| |
| return gradgradcheck(func, inputs, (grad_y,)) |
| |
| @onlyCUDA |
| @skipCUDAIfNoCudnn |
| @dtypes( |
| *floating_and_complex_types_and( |
| torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else [] |
| ) |
| ) |
| def test_Conv2d_deterministic_cudnn(self, device, dtype): |
| inputs = torch.randn(2, 3, 5, 5, device=device, dtype=dtype, requires_grad=True) |
| with cudnn.flags(enabled=True, benchmark=True, deterministic=True): |
| conv1 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) |
| conv2 = torch.nn.Conv2d(3, 3, 3).to(device, dtype) |
| conv2.bias.data.copy_(conv1.bias.data) |
| conv2.weight.data.copy_(conv1.weight.data) |
| out1 = conv1(inputs) |
| out2 = conv2(inputs) |
| self.assertEqual(out1, out2, atol=0.0, rtol=0) |
| y = torch.randn(out1.size(), device=device, dtype=dtype) |
| out1.backward(y) |
| out2.backward(y) |
| self.assertEqual( |
| conv1.bias.grad.data, conv2.bias.grad.data, atol=0.0, rtol=0 |
| ) |
| self.assertEqual( |
| conv1.weight.grad.data, conv2.weight.grad.data, atol=0.0, rtol=0 |
| ) |
| |
| @onlyCUDA |
| @dtypes( |
| *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []) |
| ) |
| def test_Conv2d_large_workspace(self, device, dtype): |
| # These sizes require huge cuDNN workspaces. Make sure we choose a |
| # reasonable algorithm that does not run out of memory |
| sizes = [ |
| (1, 256, 109, 175), |
| (1, 256, 80, 128), |
| (1, 256, 120, 192), |
| ] |
| |
| def run_test(benchmark): |
| with torch.backends.cudnn.flags(enabled=True, benchmark=benchmark): |
| conv = torch.nn.Conv2d(256, 256, kernel_size=3, padding=1).to( |
| device, dtype |
| ) |
| for size in sizes: |
| x = torch.randn(size, device=device, dtype=dtype) |
| out = conv(x.detach().clone().requires_grad_()) |
| out.backward(torch.ones_like(out)) |
| |
| run_test(benchmark=False) |
| run_test(benchmark=True) |
| |
| @onlyCUDA |
| @dtypes(torch.half, torch.float) |
| def test_ConvTranspose2d_large_output_padding(self, device, dtype): |
| net1 = torch.nn.ConvTranspose2d( |
| 128, 64, kernel_size=3, stride=2, padding=1, output_padding=1 |
| ).to(device=device, dtype=dtype) |
| net2 = torch.nn.ConvTranspose2d( |
| 64, 32, kernel_size=3, stride=2, padding=1, output_padding=1 |
| ).to(device=device, dtype=dtype) |
| net3 = torch.nn.ConvTranspose2d( |
| 32, 3, kernel_size=3, stride=2, padding=1, output_padding=1 |
| ).to(device=device, dtype=dtype) |
| x = torch.rand(1, 128, 6, 6, device=device, dtype=dtype, requires_grad=True) |
| x = net1(x) |
| x = net2(x) |
| x = net3(x) |
| x.backward(torch.randn_like(x)) |
| torch.cuda.synchronize() |
| |
| @onlyCUDA |
| @dtypes(torch.float, torch.double, torch.half) |
| # Very similar to test_Conv2d_naive_groups but with special care to handle |
| # the number of groups == number of input channels |
| @torch.backends.cudnn.flags(enabled=True, benchmark=False) |
| @tf32_on_and_off(0.01) |
| def test_Conv2d_depthwise_naive_groups(self, device, dtype): |
| for depth_multiplier in [1, 2]: |
| m = nn.Conv2d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to( |
| device, dtype |
| ) |
| i = ( |
| torch.randn(2, 2, 6, 6, device="cuda", dtype=dtype) |
| .div_(2) |
| .requires_grad_() |
| ) |
| output = m(i) |
| grad_output = ( |
| torch.randn(2, 2 * depth_multiplier, 4, 4, device=device, dtype=dtype) |
| / 2 |
| ) |
| output.backward(grad_output) |
| |
| offset = 1 * depth_multiplier |
| |
| m1 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) |
| m1.weight.data = m.weight.data[:offset].clone() |
| m1.bias.data = m.bias.data[:offset].clone() |
| i1 = i.detach()[:, :1].clone().requires_grad_() |
| output1 = m1(i1) |
| output1.backward(grad_output[:, :offset].contiguous()) |
| |
| m2 = nn.Conv2d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) |
| m2.weight.data.copy_(m.weight.data[offset:]) |
| m2.bias.data.copy_(m.bias.data[offset:]) |
| i2 = i.detach()[:, 1:].clone().requires_grad_() |
| output2 = m2(i2) |
| output2.backward(grad_output[:, offset:].contiguous()) |
| |
| self.assertEqual( |
| output, |
| torch.cat([output1, output2], 1), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| i.grad.data, |
| torch.cat([i1.grad.data, i2.grad.data], 1), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.bias.grad.data, |
| torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.weight.grad.data, |
| torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| |
| @onlyCUDA |
| @dtypes(torch.float, torch.double, torch.half) |
| @torch.backends.cudnn.flags(enabled=True, benchmark=False) |
| @tf32_on_and_off(0.01) |
| def test_Conv3d_depthwise_naive_groups(self, device, dtype): |
| for depth_multiplier in [1, 2]: |
| m = nn.Conv3d(2, 2 * depth_multiplier, kernel_size=3, groups=2).to( |
| device, dtype |
| ) |
| i = ( |
| torch.randn(2, 2, 6, 6, 6, device="cuda", dtype=dtype) |
| .div_(2) |
| .requires_grad_() |
| ) |
| output = m(i) |
| grad_output = ( |
| torch.randn( |
| 2, 2 * depth_multiplier, 4, 4, 4, device=device, dtype=dtype |
| ) |
| / 2 |
| ) |
| output.backward(grad_output) |
| |
| offset = 1 * depth_multiplier |
| |
| m1 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) |
| m1.weight.data = m.weight.data[:offset].clone() |
| m1.bias.data = m.bias.data[:offset].clone() |
| i1 = i.detach()[:, :1].clone().requires_grad_() |
| output1 = m1(i1) |
| output1.backward(grad_output[:, :offset].contiguous()) |
| |
| m2 = nn.Conv3d(1, 1 * depth_multiplier, kernel_size=3).to(device, dtype) |
| m2.weight.data.copy_(m.weight.data[offset:]) |
| m2.bias.data.copy_(m.bias.data[offset:]) |
| i2 = i.detach()[:, 1:].clone().requires_grad_() |
| output2 = m2(i2) |
| output2.backward(grad_output[:, offset:].contiguous()) |
| is_cuda_sm86 = device.startswith( |
| "cuda" |
| ) and torch.cuda.get_device_capability(0) == (8, 6) |
| atol, rtol = ( |
| (3e-4, 3e-2) |
| if dtype == torch.float32 and is_cuda_sm86 |
| else (dtype2prec_DONTUSE[dtype], 0) |
| ) |
| |
| self.assertEqual( |
| output, torch.cat([output1, output2], 1), atol=atol, rtol=rtol |
| ) |
| self.assertEqual( |
| i.grad.data, |
| torch.cat([i1.grad.data, i2.grad.data], 1), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.bias.grad.data, |
| torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.weight.grad.data, |
| torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), |
| atol=atol, |
| rtol=rtol, |
| ) |
| |
| @onlyCUDA |
| @dtypes( |
| *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []) |
| ) |
| def test_noncontig_conv_grad(self, device, dtype): |
| # FIXME: remove after adding non-contiguous grad tests for all modules |
| module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device, dtype) |
| input = torch.randn( |
| 2, 3, 10, 10, dtype=dtype, device=device, requires_grad=True |
| ) |
| output = module(input) |
| |
| grad = torch.randn(2, 2, 5, 10, 10, dtype=dtype, device=device)[:, 1] |
| assert not grad.is_contiguous() |
| output.backward(grad, retain_graph=True) |
| self.assertIsNotNone(input.grad) |
| result = input.grad.data.clone() |
| input.grad.data.zero_() |
| |
| output.backward(grad.contiguous()) |
| self.assertEqual( |
| result, input.grad.data, atol=dtype2prec_DONTUSE[dtype], rtol=0 |
| ) |
| |
| @onlyCUDA |
| @dtypes(torch.double) |
| def test_conv_double_backward(self, device, dtype): |
| with torch.backends.cudnn.flags(enabled=True, deterministic=True): |
| # Double backward is only run with DoubleTensor for precision reasons |
| batch_size = 1 |
| for kern, inp_size, dilations in [(3, 5, [1, 2]), (4, 9, [1])]: |
| for stride, padding, chan_in, chan_out, dilation in product( |
| [1], [2], [2], [3], dilations |
| ): |
| no_weight = stride == 2 |
| result = self.run_conv_double_back_test( |
| kern, |
| stride, |
| padding, |
| chan_in, |
| chan_out, |
| batch_size, |
| inp_size, |
| dilation, |
| no_weight, |
| use_cuda=True, |
| dtype=dtype, |
| ) |
| self.assertTrue( |
| result, |
| "Conv double backward test failed with parameters:" |
| f"\nkern: {kern}\nstride: {stride}\npadding: {padding}" |
| f"\nchan_in: {chan_in}\nchan_out: {chan_out}" |
| f"\nbatch_size: {batch_size}\ninp_size: {inp_size}" |
| f"\ndilation: {dilation}", |
| ) |
| |
| def test_conv_double_backward_no_bias(self): |
| kern = 3 |
| stride = 2 |
| chan_in, chan_out = 2, 4 |
| batch_size = 2 |
| inp_size = 5 |
| padding = 1 |
| dilation = 1 |
| no_weight = False |
| use_bias = False |
| result = self.run_conv_double_back_test( |
| kern, |
| stride, |
| padding, |
| chan_in, |
| chan_out, |
| batch_size, |
| inp_size, |
| dilation, |
| no_weight, |
| use_bias=use_bias, |
| ) |
| self.assertTrue( |
| result, |
| "Conv double backward test failed with parameters:" |
| f"\nkern: {kern}\nstride: {stride}\npadding: {padding}" |
| f"\nchan_in: {chan_in}\nchan_out: {chan_out}" |
| f"\nbatch_size: {batch_size}\ninp_size: {inp_size}" |
| f"\ndilation: {dilation}", |
| ) |
| |
| def test_conv_double_backward_groups(self): |
| kern = 3 |
| stride = 1 |
| padding = 2 |
| chan_in, chan_out = 2, 4 |
| batch_size = 2 |
| inp_size = 6 |
| dilation = 1 |
| no_weight = False |
| groups = 2 |
| result = self.run_conv_double_back_test( |
| kern, |
| stride, |
| padding, |
| chan_in * groups, |
| chan_out * groups, |
| batch_size, |
| inp_size, |
| dilation, |
| no_weight, |
| groups=groups, |
| ) |
| self.assertTrue( |
| result, |
| "Conv double backward test failed with parameters:" |
| f"\nkern: {kern}\nstride: {stride}\npadding: {padding}" |
| f"\nchan_in: {chan_in}\nchan_out: {chan_out}" |
| f"\nbatch_size: {batch_size}\ninp_size: {inp_size}" |
| f"\ndilation: {dilation}\ngroups: {groups}", |
| ) |
| |
| def test_conv_double_backward_stride(self): |
| batch_size = 2 |
| |
| # Cannot provide ggW when stride is > 1 |
| for kern, inp_size, dilations in [(3, 5, [1, 2]), (3, 7, [1])]: |
| for stride, padding, chan_in, chan_out, dilation in product( |
| [2], [0, 1], [1], [2], dilations |
| ): |
| no_weight = False |
| self.run_conv_double_back_test( |
| kern, |
| stride, |
| padding, |
| chan_in, |
| chan_out, |
| batch_size, |
| inp_size, |
| dilation, |
| no_weight, |
| ) |
| |
| @dtypes(torch.float, torch.cfloat) |
| @torch.backends.cudnn.flags(enabled=True, benchmark=False) |
| def test_conv1d_same_padding(self, device, dtype): |
| # Test padding='same' outputs the correct shape |
| test_args = [ |
| # in_size |
| range(50, 55), |
| # kernel_size |
| [1, 2, 3, 8], |
| # dilation |
| range(1, 4), |
| # stride |
| [1], |
| ] |
| for in_size, k_size, dilation, stride in itertools.product(*test_args): |
| x = torch.rand(1, 1, in_size, device=device, dtype=dtype) |
| y = torch.rand(1, 1, k_size, device=device, dtype=dtype) |
| z = F.conv1d(x, y, padding="same", dilation=dilation, stride=stride) |
| self.assertEqual(z.size(2), int(math.ceil(in_size / stride))) |
| |
| # Compare F.conv1d padding='same' output against manual padding |
| # Without strides/dilation |
| x = torch.rand(1, 1, 12, device=device, dtype=dtype) |
| y = torch.rand(1, 1, 3, device=device, dtype=dtype) |
| expect = F.conv1d(x, y, padding=1) |
| actual = F.conv1d(x, y, padding="same") |
| self.assertEqual(expect, actual) |
| |
| # With dilation |
| x = torch.rand(1, 1, 12, device=device, dtype=dtype) |
| y = torch.rand(1, 1, 4, device=device, dtype=dtype) |
| expect = F.conv1d(x, y, padding=3, dilation=2) |
| actual = F.conv1d(x, y, padding="same", dilation=2) |
| self.assertEqual(expect, actual) |
| |
| # Dilation with asymmetric padding |
| expect = F.conv1d(x, y, padding=5, dilation=3)[..., 1:] |
| actual = F.conv1d(x, y, padding="same", dilation=3) |
| self.assertEqual(expect, actual) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv2d_same_padding(self, device, dtype): |
| if dtype is torch.cfloat: |
| rtol, atol = 2e-6, 2e-6 |
| else: |
| rtol, atol = None, None |
| # Compare F.conv2d padding='same' output against manual padding |
| # Without strides/dilation |
| x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype) |
| y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype) |
| expect = F.conv2d(x, y, padding=(2, 2))[..., 1:, :] |
| actual = F.conv2d(x, y, padding="same") |
| self.assertEqual(expect, actual, rtol=rtol, atol=atol) |
| |
| # With dilation |
| y = torch.rand(1, 1, 3, 4, device=device, dtype=dtype) |
| expect = F.conv2d(x, y, padding=(2, 3), dilation=2) |
| actual = F.conv2d(x, y, padding="same", dilation=2) |
| self.assertEqual(expect, actual, rtol=rtol, atol=atol) |
| |
| # Dilation with asymmetric padding |
| y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype) |
| expect = F.conv2d(x, y, padding=5, dilation=3)[..., 1:, 1:] |
| actual = F.conv2d(x, y, padding="same", dilation=3) |
| self.assertEqual(expect, actual, rtol=rtol, atol=atol) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv3d_same_padding(self, device, dtype): |
| if dtype is torch.cfloat: |
| rtol, atol = 2e-6, 2e-6 |
| else: |
| rtol, atol = None, None |
| # Compare F.conv3d padding='same' output against manual padding |
| # Without strides/dilation |
| x = torch.rand(1, 1, 10, 11, 12, device=device, dtype=dtype) |
| y = torch.rand(1, 1, 1, 2, 5, device=device, dtype=dtype) |
| expect = F.conv3d(x, y, padding=(0, 1, 2))[..., :, 1:, :] |
| actual = F.conv3d(x, y, padding="same") |
| self.assertEqual(expect, actual, rtol=rtol, atol=atol) |
| |
| # With dilation |
| expect = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) |
| actual = F.conv3d(x, y, padding="same", dilation=2) |
| self.assertEqual(expect, actual, rtol=rtol, atol=atol) |
| |
| # Dilation with asymmetric padding |
| y = torch.rand(1, 1, 4, 4, 4, device=device, dtype=dtype) |
| expect = F.conv3d(x, y, padding=5, dilation=3)[..., 1:, 1:, 1:] |
| actual = F.conv3d(x, y, padding="same", dilation=3) |
| self.assertEqual(expect, actual, rtol=rtol, atol=atol) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv1d_valid_padding(self, device, dtype): |
| # Test F.conv1d padding='valid' is the same as no padding |
| x = torch.rand(1, 1, 10, device=device, dtype=dtype) |
| y = torch.rand(1, 1, 4, device=device, dtype=dtype) |
| expect = F.conv1d(x, y) |
| actual = F.conv1d(x, y, padding="valid") |
| self.assertEqual(expect, actual) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv2d_valid_padding(self, device, dtype): |
| # Test F.conv2d padding='valid' is the same as no padding |
| x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype) |
| y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype) |
| expect = F.conv2d(x, y) |
| actual = F.conv2d(x, y, padding="valid") |
| self.assertEqual(expect, actual) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv3d_valid_padding(self, device, dtype): |
| # Test F.conv3d padding='valid' is the same as no padding |
| x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device) |
| y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device) |
| expect = F.conv3d(x, y) |
| actual = F.conv3d(x, y, padding="valid") |
| self.assertEqual(expect, actual) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv1d_same_padding_backward(self, device, dtype): |
| # Test F.conv1d gradients work with padding='same' |
| x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True) |
| y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) |
| |
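| # Compute reference gradients with explicit integer padding, then check |
| # that padding='same' produces identical gradients for both inputs. |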
| # Symmetric padding |
| z = F.conv1d(x, y, padding=3, dilation=2) |
| z.sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| z = F.conv1d(x, y, padding="same", dilation=2) |
| z.sum().abs().backward() |
| self.assertEqual(gx_expect, x.grad) |
| self.assertEqual(gy_expect, y.grad) |
| x.grad, y.grad = None, None |
| |
| # Asymmetric padding |
| z = F.conv1d(x, y, padding=2)[..., 1:] |
| z.sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| z = F.conv1d(x, y, padding="same") |
| z.sum().abs().backward() |
| self.assertEqual(gx_expect, x.grad) |
| self.assertEqual(gy_expect, y.grad) |
| |
| @dtypes(torch.float, torch.cfloat) |
| @tf32_on_and_off(0.001) |
| def test_conv2d_same_padding_backward(self, device, dtype): |
| # Test F.conv2d gradients work with padding='same' |
| x = torch.rand(1, 1, 10, 11, device=device, dtype=dtype, requires_grad=True) |
| y = torch.rand(1, 1, 4, 5, device=device, dtype=dtype, requires_grad=True) |
| |
| # Symmetric padding |
| z = F.conv2d(x, y, padding=(3, 4), dilation=2) |
| z.sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| z = F.conv2d(x, y, padding="same", dilation=2) |
| z.sum().abs().backward() |
| self.assertEqual(gx_expect, x.grad) |
| self.assertEqual(gy_expect, y.grad) |
| x.grad, y.grad = None, None |
| |
| # Asymmetric padding |
| y = torch.rand(1, 1, 4, 4, device=device, dtype=dtype, requires_grad=True) |
| z = F.conv2d(x, y, padding=2)[..., 1:, 1:] |
| z.sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| z = F.conv2d(x, y, padding="same") |
| z.sum().abs().backward() |
| self.assertEqual(gx_expect, x.grad) |
| self.assertEqual(gy_expect, y.grad) |
| |
| @dtypes(torch.double, torch.cdouble) |
| def test_conv3d_same_padding_backward(self, device, dtype): |
| check_forward_ad = torch.device(device).type != "xla" |
| |
| # Test F.conv3d gradients work with padding='same' |
| x = torch.rand(1, 1, 1, 11, 12, dtype=dtype, device=device, requires_grad=True) |
| y = torch.rand(1, 1, 1, 2, 5, dtype=dtype, device=device, requires_grad=True) |
| |
| # Symmetric padding |
| z = F.conv3d(x, y, padding=(0, 1, 4), dilation=2) |
| z.sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| z = F.conv3d(x, y, padding="same", dilation=2) |
| z.sum().abs().backward() |
| self.assertEqual(gx_expect, x.grad) |
| self.assertEqual(gy_expect, y.grad) |
| x.grad, y.grad = None, None |
| |
| gradcheck( |
| lambda x, y: F.conv3d(x, y, padding="same", dilation=2), |
| (x, y), |
| check_forward_ad=check_forward_ad, |
| nondet_tol=1e-5, |
| ) |
| if torch.device(device).type != "cuda": |
| # https://github.com/pytorch/pytorch/issues/70702 |
| gradgradcheck( |
| lambda x, y: F.conv3d(x, y, padding="same", dilation=2), |
| (x, y), |
| check_fwd_over_rev=True, |
| ) |
| |
| # Asymmetric padding |
| y = torch.rand(1, 1, 1, 4, 4, dtype=dtype, device=device, requires_grad=True) |
| z = F.conv3d(x, y, padding=2)[..., 1:, 1:] |
| z.sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| z = F.conv3d(x, y, padding="same") |
| z.sum().abs().backward() |
| self.assertEqual(gx_expect, x.grad) |
| self.assertEqual(gy_expect, y.grad) |
| |
| gradcheck( |
| lambda x, y: F.conv3d(x, y, padding="same"), |
| (x, y), |
| check_forward_ad=check_forward_ad, |
| nondet_tol=1e-5, |
| ) |
| if torch.device(device).type != "cuda": |
| # https://github.com/pytorch/pytorch/issues/70702 |
| gradgradcheck( |
| lambda x, y: F.conv3d(x, y, padding="same"), |
| (x, y), |
| check_fwd_over_rev=True, |
| ) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv1d_valid_padding_backward(self, device, dtype): |
| # Test F.conv1d gradients work with padding='valid' |
| x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True) |
| y = torch.rand(1, 1, 4, dtype=dtype, device=device, requires_grad=True) |
| F.conv1d(x, y, padding=0).sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| F.conv1d(x, y, padding="valid").sum().abs().backward() |
| gx_actual, gy_actual = x.grad, y.grad |
| self.assertEqual(gx_expect, gx_actual) |
| self.assertEqual(gy_expect, gy_actual) |
| |
| @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") |
| @dtypes(torch.float, torch.cfloat) |
| @parametrize_test("mode", ("valid", "same")) |
| def test_conv1d_vs_scipy(self, device, dtype, mode): |
| t = make_tensor((1, 10), device=device, dtype=dtype) |
| feat_dim = t.shape[1] |
| weight_even = make_tensor((1, 1, 4), device=device, dtype=dtype) |
| weight_odd = make_tensor((1, 1, 5), device=device, dtype=dtype) |
| |
| def _test(t, weight, mode): |
| # SciPy expects two 1-D inputs. |
| t_a = t.view(-1).cpu().numpy() |
| w_a = weight.view(-1).cpu().numpy() |
| expected = scipy.signal.convolve(t_a, w_a, mode=mode) |
| |
| kwargs = {"padding": mode} |
| if mode == "same": |
| # `same` padding in PyTorch conv1d is different |
| # from SciPy |
| p = weight.shape[2] // 2 |
| t = torch.nn.functional.pad(t, (p, p)) |
| # We have already taken care of padding |
| kwargs.pop("padding") |
| |
| # SciPy's convolve flips the kernel (true convolution), while PyTorch's |
| # conv1d computes cross-correlation, so flip the weight to match. |
| weight_flipped = torch.flip(weight, (2,)) |
| actual = torch.nn.functional.conv1d(t, weight_flipped, **kwargs).squeeze(0) |
| if mode == "same": |
| actual = actual[:feat_dim] |
| |
| self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5) |
| |
| # The global dtype for this test suite is torch.double, which changes |
| # type promotion so that conv1d would output `complex128` for a |
| # `complex64` input; pin the default dtype to float to avoid this. |
| with set_default_dtype(torch.float): |
| _test(t, weight_even, mode) |
| _test(t, weight_odd, mode) |
| |
| @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") |
| @dtypes(torch.float, torch.cfloat) |
| @parametrize_test("mode", ("valid", "same")) |
| def test_conv2d_vs_scipy(self, device, dtype, mode): |
| t = make_tensor((1, 5, 10), device=device, dtype=dtype) |
| weight_even = make_tensor((1, 1, 2, 4), device=device, dtype=dtype) |
| weight_odd = make_tensor((1, 1, 3, 5), device=device, dtype=dtype) |
| |
| def _test(t, weight, mode): |
| # SciPy expects two 2-D inputs. |
| t_a = t.squeeze(0).cpu().numpy() |
| w_a = weight.squeeze(0).squeeze(0).cpu().numpy() |
| expected = scipy.signal.convolve2d(t_a, w_a, mode=mode) |
| |
| kwargs = {"padding": mode} |
| if mode == "same": |
| # `same` padding in PyTorch conv2d is different |
| # from SciPy |
| left_right_pad = weight.shape[3] // 2 |
| top_bottom_pad = weight.shape[2] // 2 |
| p = (left_right_pad, left_right_pad, top_bottom_pad, top_bottom_pad) |
| t = torch.nn.functional.pad(t, p) |
| # We have already taken care of padding |
| kwargs.pop("padding") |
| |
| # second input is flipped in SciPy's convolve2d |
| weight_flipped = torch.flip(weight, (2, 3)) |
| actual = torch.nn.functional.conv2d(t, weight_flipped, **kwargs).squeeze(0) |
| if mode == "same": |
| actual = actual[:5, :10] |
| |
| self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) |
| |
| # The global dtype for this test suite is torch.double, which changes |
| # type promotion so that conv2d would output `complex128` for a |
| # `complex64` input; pin the default dtype to float to avoid this. |
| with set_default_dtype(torch.float): |
| _test(t, weight_even, mode) |
| _test(t, weight_odd, mode) |
| |
| @unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.") |
| @dtypes(torch.float, torch.cfloat) |
| @parametrize_test("mode", ("valid", "same")) |
| def test_conv3d_vs_scipy(self, device, dtype, mode): |
| t = make_tensor((1, 5, 5, 10), device=device, dtype=dtype) |
| weight_even = make_tensor((1, 1, 2, 2, 4), device=device, dtype=dtype) |
| weight_odd = make_tensor((1, 1, 2, 3, 5), device=device, dtype=dtype) |
| |
| def _test(t, weight, mode): |
| # SciPy expects two 3-D inputs. |
| t_a = t.squeeze(0).cpu().numpy() |
| w_a = weight.squeeze(0).squeeze(0).cpu().numpy() |
| expected = scipy.signal.convolve(t_a, w_a, mode=mode) |
| |
| kwargs = {"padding": mode} |
| if mode == "same": |
| # `same` padding in PyTorch conv3d is different |
| # from SciPy |
| left_right_pad = weight.shape[4] // 2 |
| top_bottom_pad = weight.shape[3] // 2 |
| front_back_pad = weight.shape[2] // 2 |
| p = ( |
| left_right_pad, |
| left_right_pad, |
| top_bottom_pad, |
| top_bottom_pad, |
| front_back_pad, |
| front_back_pad, |
| ) |
| t = torch.nn.functional.pad(t, p) |
| # We have already taken care of padding |
| kwargs.pop("padding") |
| |
| # second input is flipped in SciPy's convolve |
| weight_flipped = torch.flip(weight, (2, 3, 4)) |
| actual = torch.nn.functional.conv3d(t, weight_flipped, **kwargs).squeeze(0) |
| if mode == "same": |
| actual = actual[:5, :5, :10] |
| |
| if tf32_is_not_fp32() and ( |
| dtype == torch.float or dtype == torch.complex64 |
| ): |
| self.assertEqual(actual, expected, atol=0.05, rtol=0.05) |
| else: |
| self.assertEqual(actual, expected, rtol=2e-5, atol=5e-6) |
| |
| # The global dtype for this test suite is torch.double, which changes |
| # type promotion so that conv3d would output `complex128` for a |
| # `complex64` input; pin the default dtype to float to avoid this. |
| with set_default_dtype(torch.float): |
| _test(t, weight_even, mode) |
| _test(t, weight_odd, mode) |
| |
| @dtypes(torch.float, torch.complex64) |
| def test_conv2d_valid_padding_backward(self, device, dtype): |
| # Test F.conv2d gradients work with padding='valid' |
| x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True) |
| y = torch.rand(1, 1, 1, 4, device=device, dtype=dtype, requires_grad=True) |
| F.conv2d(x, y, padding=0).sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| F.conv2d(x, y, padding="valid").sum().abs().backward() |
| gx_actual, gy_actual = x.grad, y.grad |
| self.assertEqual(gx_expect, gx_actual) |
| self.assertEqual(gy_expect, gy_actual) |
| |
| @dtypes(torch.double, torch.cdouble) |
| def test_conv3d_valid_padding_backward(self, device, dtype): |
| check_forward_ad = torch.device(device).type != "xla" |
| |
| # Test F.conv3d gradients work with padding='valid' |
| x = torch.rand(1, 1, 1, 1, 10, dtype=dtype, device=device, requires_grad=True) |
| y = torch.rand(1, 1, 1, 1, 4, dtype=dtype, device=device, requires_grad=True) |
| F.conv3d(x, y, padding=0).sum().abs().backward() |
| gx_expect, gy_expect = x.grad, y.grad |
| x.grad, y.grad = None, None |
| |
| F.conv3d(x, y, padding="valid").sum().abs().backward() |
| gx_actual, gy_actual = x.grad, y.grad |
| self.assertEqual(gx_expect, gx_actual) |
| self.assertEqual(gy_expect, gy_actual) |
| |
| gradcheck( |
| lambda x, y: F.conv3d(x, y, padding="valid"), |
| (x, y), |
| check_forward_ad=check_forward_ad, |
| ) |
| gradgradcheck( |
| lambda x, y: F.conv3d(x, y, padding="valid"), |
| (x, y), |
| check_fwd_over_rev=check_forward_ad, |
| ) |
| |
| @parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d") |
| def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N): |
| # For inputs with no batch dim, verify output is the correct shape when output_size is set. |
| # See https://github.com/pytorch/pytorch/issues/75889 |
| inp = torch.randn((1, 15, 13) if N == 2 else (1, 15, 13, 13), device=device) |
| output_size = (1, 240, 200) if N == 2 else (1, 240, 200, 200) |
| ConvTransposeNd = getattr(nn, f"ConvTranspose{N}d") |
| m = ConvTransposeNd( |
| 1, 1, kernel_size=16, stride=16, padding=7, bias=False, device=device |
| ) |
| output = m(inp, output_size=output_size) |
| self.assertEqual(output.shape, output_size) |
| |
| @skipMeta |
| @parametrize_test( |
| "input_shape,transposed,dilated,groups,layout,backend_expected", |
| [ |
| # === slow === |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Slow2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow1d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7), |
| True, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowTranspose2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow1d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowDilated2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow1d_dilated", |
| ), |
| subtest( |
| ( |
| (2, 6, 7), |
| True, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowTranspose2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow1d_dilated_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Slow2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow2d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| True, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowTranspose2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow2d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowDilated2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow2d_dilated", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| True, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowTranspose2d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow2d_dilated_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Slow3d, |
| ), |
| decorators=[onlyCPU, disableMkldnn], |
| name="slow3d_cpu", |
| ), |
| # CUDA doesn't have a slow 3D implementation, so it goes to the dilated 3D implementation instead |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowDilated3d, |
| ), |
| decorators=[onlyCUDA, disablecuDNN], |
| name="slow3d_cuda", |
| ), |
| # FIXME: RuntimeError: CUDA out of memory. |
| # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), |
| # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_transposed'), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.SlowDilated3d, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], |
| name="slow3d_dilated", |
| ), |
| # FIXME: RuntimeError: CUDA out of memory. |
| # subtest(((2, 6, 7, 8, 9), True, True, 3, torch.strided, torch._C._ConvBackend.SlowTranspose3d), |
| # decorators=[onlyNativeDeviceTypes, disableMkldnn, disablecuDNN], name='slow3d_dilated_transposed'), |
| subtest( |
| ( |
| (0, 6, 7), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_batch1d", |
| ), |
| subtest( |
| ( |
| (2, 0, 7), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_channel1d", |
| ), |
| subtest( |
| ( |
| (0, 0, 7), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_batch_channel1d", |
| ), |
| subtest( |
| ( |
| (0, 6, 7, 8), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_batch2d", |
| ), |
| subtest( |
| ( |
| (2, 0, 7, 8), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_channel2d", |
| ), |
| subtest( |
| ( |
| (0, 0, 7, 8), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_batch_channel2d", |
| ), |
| subtest( |
| ( |
| (0, 6, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_batch3d", |
| ), |
| subtest( |
| ( |
| (2, 0, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_channel3d", |
| ), |
| subtest( |
| ( |
| (0, 0, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Empty, |
| ), |
| decorators=[onlyNativeDeviceTypes, disableMkldnn], |
| name="empty_batch_channel3d", |
| ), |
| # === cuda === |
| # Note that disablecuDNN disables miopen as well. |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| False, |
| 6, |
| torch.strided, |
| torch._C._ConvBackend.CudaDepthwise2d, |
| ), |
| decorators=[onlyCUDA, disablecuDNN], |
| name="cuda_depthwise1d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| False, |
| 6, |
| torch.strided, |
| torch._C._ConvBackend.CudaDepthwise2d, |
| ), |
| decorators=[onlyCUDA, disablecuDNN], |
| name="cuda_depthwise2d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| False, |
| 6, |
| torch.strided, |
| torch._C._ConvBackend.CudaDepthwise3d, |
| ), |
| decorators=[onlyCUDA, disablecuDNN], |
| name="cuda_depthwise3d", |
| ), |
| # === cudnn === |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Cudnn, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], |
| name="cudnn1d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Cudnn, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], |
| name="cudnn2d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Cudnn, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], |
| name="cudnn3d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7), |
| True, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.CudnnTranspose, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], |
| name="cudnn1d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| True, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.CudnnTranspose, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], |
| name="cudnn2d_transposed", |
| ), |
| # FIXME: RuntimeError: CUDA out of memory. |
| # subtest(((2, 6, 7, 8, 9), True, False, 3, torch.strided, torch._C._ConvBackend.CudnnTranspose), |
| # decorators=[onlyCUDA, skipCUDAIfNoCudnn, skipCUDAIfMiopen], name='cudnn3d_transposed'), |
| # === miopen === |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Miopen, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen1d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Miopen, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen2d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Miopen, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen3d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7), |
| True, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.MiopenTranspose, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen1d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| True, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.MiopenTranspose, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen2d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| True, |
| False, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.MiopenTranspose, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen3d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| False, |
| 6, |
| torch.strided, |
| torch._C._ConvBackend.MiopenDepthwise, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen_depthwise1d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| False, |
| 6, |
| torch.strided, |
| torch._C._ConvBackend.MiopenDepthwise, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen_depthwise2d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| False, |
| 6, |
| torch.strided, |
| torch._C._ConvBackend.MiopenDepthwise, |
| ), |
| decorators=[onlyCUDA, skipCUDAIfNoMiopen], |
| name="miopen_depthwise3d", |
| ), |
| # === mkldnn === |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn1d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn2d", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn3d", |
| ), |
| # Transposed convolution is broken for mkldnn. See https://github.com/pytorch/pytorch/issues/68775. |
| subtest( |
| ( |
| (2, 6, 7), |
| True, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], |
| name="mkldnn1d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| True, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], |
| name="mkldnn2d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| True, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn, unittest.expectedFailure], |
| name="mkldnn3d_transposed", |
| ), |
| subtest( |
| ( |
| (2, 6, 7), |
| False, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn1d_cpu_input", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8), |
| False, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn2d_cpu_input", |
| ), |
| subtest( |
| ( |
| (2, 6, 7, 8, 9), |
| False, |
| True, |
| 3, |
| torch.strided, |
| torch._C._ConvBackend.Mkldnn, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn3d_cpu_input", |
| ), |
| subtest( |
| ( |
| (0, 6, 7), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_batch1d", |
| ), |
| subtest( |
| ( |
| (2, 0, 7), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_channel1d", |
| ), |
| subtest( |
| ( |
| (0, 0, 7), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_batch_channel1d", |
| ), |
| subtest( |
| ( |
| (0, 6, 7, 8), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_batch2d", |
| ), |
| subtest( |
| ( |
| (2, 0, 7, 8), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_channel2d", |
| ), |
| subtest( |
| ( |
| (0, 0, 7, 8), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_batch_channel2d", |
| ), |
| subtest( |
| ( |
| (0, 6, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_batch3d", |
| ), |
| subtest( |
| ( |
| (2, 0, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_channel3d", |
| ), |
| subtest( |
| ( |
| (0, 0, 7, 8, 9), |
| False, |
| False, |
| 3, |
| torch._mkldnn, |
| torch._C._ConvBackend.MkldnnEmpty, |
| ), |
| decorators=[onlyCPU, skipCPUIfNoMkldnn], |
| name="mkldnn_empty_batch_channel3d", |
| ), |
| # Note: Tests for mobile backends are not currently supported. This comprises |
| # NnpackSpatial, Winograd3x3Depthwise, and Xnnpack2d backends. Testing these |
| # requires the ability to gate tests by whether PyTorch is built with USE_MOBILE=1. |
| ], |
| ) |
| # Test with both bias and no bias. |
| @parametrize_test("has_bias", [False, True]) |
| # Test with both stride=1 and stride>1 cases. |
| @parametrize_test("strided", [False, True]) |
| # Test with both contiguous and non-contiguous inputs. |
| @parametrize_test("contiguous", [False, True]) |
| def test_conv_backend( |
| self, |
| device, |
| input_shape, |
| has_bias, |
| strided, |
| contiguous, |
| transposed, |
| dilated, |
| groups, |
| layout, |
| backend_expected, |
| ): |
| # Build up inputs. |
| dtype = torch.float32 |
| C_in, C_out, dim, kernel_size = input_shape[1], 12, len(input_shape) - 2, 3 |
| x = torch.randn(*input_shape, device=device, dtype=dtype, requires_grad=True) |
| weight = torch.randn( |
| C_in if transposed else C_out, |
| C_out // groups if transposed else C_in // groups, |
| *[kernel_size for _ in range(dim)], |
| device=device, |
| dtype=dtype, |
| requires_grad=True, |
| ) |
| bias = ( |
| torch.randn(C_out, device=device, dtype=dtype, requires_grad=True) |
| if has_bias |
| else None |
| ) |
| |
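| # Helper that keeps the same values but makes the memory layout |
| # non-contiguous by doubling the last dim and taking every other element. |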
| def _make_noncontiguous(inp): |
| if inp is None: |
| return None |
| old_requires_grad = inp.requires_grad |
| inp = torch.repeat_interleave(inp, 2, dim=-1) |
| inp = inp[..., ::2].detach().requires_grad_(old_requires_grad) |
| return inp |
| |
| if not contiguous: |
| x = _make_noncontiguous(x) |
| weight = _make_noncontiguous(weight) |
| bias = _make_noncontiguous(bias) |
| |
| if layout is torch._mkldnn: |
| x = x.to_mkldnn() |
| # Note that weight and bias are not supported as mkldnn tensors during training. |
| |
| stride = (2,) * dim if strided else (1,) * dim |
| padding = (0,) * dim |
| dilation = (2,) * dim if dilated else (1,) * dim |
| output_padding = (0,) * dim |
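| # Arguments in the order expected by torch._C._select_conv_backend and |
| # torch.ops.aten.convolution below. |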
| inputs = [ |
| x, |
| weight, |
| bias, |
| stride, |
| padding, |
| dilation, |
| transposed, |
| output_padding, |
| groups, |
| ] |
| |
| # Ensure correct backend is selected. |
| backend_actual = torch._C._select_conv_backend(*inputs) |
| self.assertEqual(backend_actual, backend_expected) |
| |
| # Ensure backward call succeeds. |
| convolution = torch.ops.aten.convolution |
| output = convolution(*inputs) |
| grad_output = torch.randn(output.shape, device=device, dtype=dtype) |
| if not contiguous: |
| grad_output = _make_noncontiguous(grad_output) |
| if layout is torch._mkldnn: |
| grad_output = grad_output.to_mkldnn() |
| output.backward(grad_output) |
| |
| # mkldnn doesn't support gradcheck :( |
| if layout is torch._mkldnn: |
| return |
| |
| if backend_actual != torch._C._ConvBackend.Empty: # FIXME: forward AD fails |
| # Forward AD and forward-over-reverse AD smoke test in float32 |
| # TODO: remove this if we introduce per-op gradient tests for float32 |
| with fwAD.dual_level(): |
| dual_inputs = [ |
| ( |
| fwAD.make_dual(i, torch.rand_like(i)) |
| if isinstance(i, torch.Tensor) |
| else i |
| ) |
| for i in inputs |
| ] |
| # Forward AD |
| output = convolution(*dual_inputs) |
| # Forward over reverse AD |
| grad_output_d = fwAD.make_dual( |
| torch.rand_like(output), torch.rand_like(output) |
| ) |
| if has_bias: |
| torch.autograd.grad(output, [x, weight, bias], grad_output_d) |
| else: |
| torch.autograd.grad(output, [x, weight], grad_output_d) |
| |
| # Convert to float64 for gradcheck. |
| x = x.to(torch.float64).detach().requires_grad_(True) |
| weight = weight.to(torch.float64).detach().requires_grad_(True) |
| if bias is not None: |
| bias = bias.to(torch.float64).detach().requires_grad_(True) |
| inputs = [ |
| x, |
| weight, |
| bias, |
| stride, |
| padding, |
| dilation, |
| transposed, |
| output_padding, |
| groups, |
| ] |
| |
| # Set some backend-specific validation settings. |
| gradcheck_nondet_tol = 0.0 |
| if torch.backends.cudnn.is_available(): |
| # cuDNN introduces non-determinism |
| gradcheck_nondet_tol = GRADCHECK_NONDET_TOL |
| |
| self.assertTrue(gradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol)) |
| |
| # double backward doesn't support bias gradients |
| if bias is not None: |
| bias.requires_grad_(False) |
| self.assertTrue( |
| gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol) |
| ) |
| |
| @onlyCPU |
| def test_conv_contiguous_for_oneDNN(self): |
| # See https://github.com/pytorch/pytorch/issues/80837. |
| for dtype in [torch.float, torch.bfloat16, torch.half]: |
| conv = nn.Conv2d( |
| 1, |
| 128, |
| kernel_size=(5, 2), |
| stride=(2, 1), |
| padding=(0, 1), |
| dilation=(1, 1), |
| groups=1, |
| bias=True, |
| padding_mode="zeros", |
| ).to(dtype=dtype) |
| |
| x = torch.rand([1, 2, 321, 201, 1]).to(dtype=dtype) |
| x = torch.transpose(x, 1, 4) |
| x2 = x[..., 0] |
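| # x2 has a plain (1, 1, 321, 201) shape but unusual strides from the |
| # transpose + slice; the oneDNN path must handle it and match the |
| # reference result computed with oneDNN disabled. |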
| inputs = [ |
| x2, |
| conv.weight, |
| conv.bias, |
| (2, 1), |
| (0, 1), |
| (1, 1), |
| False, |
| (0, 1), |
| 1, |
| ] |
| if torch.backends.mkldnn.is_available(): |
| y = conv(x2) |
| # Disable MKLDNN explicitly |
| with torch.backends.mkldnn.flags(enabled=False): |
| y_ = conv(x2) |
| self.assertEqual(y, y_) |
| |
| @onlyCPU |
| def test_conv_ic1_channels_last_for_oneDNN(self): |
| # See https://github.com/pytorch/pytorch/issues/82060; N > 1 takes the oneDNN path. |
| for dtype in [torch.float, torch.bfloat16, torch.half]: |
| conv = torch.nn.Conv2d( |
| 1, 64, kernel_size=(3, 3), padding=(1, 1), bias=False |
| ) |
| conv = conv.to(memory_format=torch.channels_last).to(dtype=dtype) |
| x = torch.rand(2, 1, 100, 100).to(dtype=dtype) |
| if torch.backends.mkldnn.is_available(): |
| y = conv(x) |
| # Disable MKLDNN explicitly |
| with torch.backends.mkldnn.flags(enabled=False): |
| y_ = conv(x) |
| self.assertEqual(y, y_) |
| |
| @dtypes(torch.float, torch.cfloat) |
| def test_conv_empty_channel(self, device, dtype): |
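| # Modules built with in_channels == 0 should accept zero-channel inputs; |
| # inputs whose channel count does not match the weight should raise. |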
| in_channels = 0 |
| mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2, dtype=dtype).to(device) |
| inp = torch.randn(2, 0, 15, device=device, dtype=dtype) |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| |
| with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): |
| inp = torch.randn(2, 1, 0, device=device, dtype=dtype) |
| mod(inp) |
| |
| mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) |
| inp = torch.randn(2, 0, 50, 100, device=device, dtype=dtype) |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| |
| with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): |
| inp = torch.randn(2, 1, 40, 0, device=device, dtype=dtype) |
| mod(inp) |
| |
| mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2, dtype=dtype).to(device) |
| inp = torch.randn(2, 0, 50, 20, 40, device=device, dtype=dtype) |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| |
| with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"): |
| inp = torch.randn(2, 1, 50, 0, 40, device=device, dtype=dtype) |
| mod(inp) |
| |
| def test_group_conv_empty(self, device): |
| mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to( |
| device |
| ) |
| inp = torch.randn(0, 4, 4, 4, device=device) |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| if self.device_type == "cuda" and self.has_cudnn(): |
| with torch.backends.cudnn.flags(enabled=False): |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| |
| def test_group_convTranspose_empty(self, device): |
| mod = torch.nn.ConvTranspose2d( |
| 4, 4, stride=2, kernel_size=3, padding=1, groups=4 |
| ).to(device) |
| inp = torch.randn(0, 4, 4, 4, device=device) |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| if self.device_type == "cuda" and self.has_cudnn(): |
| with torch.backends.cudnn.flags(enabled=False): |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| |
| def test_convTranspose_empty(self, device): |
| mod = torch.nn.ConvTranspose2d(4, 4, stride=2, kernel_size=3, padding=1).to( |
| device |
| ) |
| inp = torch.randn(0, 4, 4, 4, device=device) |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| if self.device_type == "cuda" and self.has_cudnn(): |
| with torch.backends.cudnn.flags(enabled=False): |
| _test_module_empty_input(self, mod, inp, check_size=False) |
| |
| @onlyCUDA |
| @largeTensorTest("12GB") |
| def test_conv_large_nosplit(self, device): |
| # Here we just test that the convolution correctly routes to the fallback |
| # implementation, that is, it does not crash. The correctness of the fallback |
| # implementation should be covered by other tests. |
| dtype = torch.half if self.device_type == "cuda" else torch.float |
| conv1 = nn.Conv2d(2, 2, 8, 8).to(device).to(dtype) |
| input_large = torch.randn(1, 2, 1024, 1024 * 1024, dtype=dtype, device=device) |
| conv1(input_large) |
| conv2 = torch.nn.Conv2d(1, 1024, 1, 1).to(device).to(dtype) |
| input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device) |
| conv2(input_large) |
| |
| def test_conv_noncontig_weights(self, device): |
| for dim in (1, 2, 3): |
| for grouped in (False, True): |
| nc = 3 |
| groups = 3 if grouped else 1 |
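| # expand() below yields a non-contiguous (broadcasted) weight view; both |
| # conv and conv_transpose should accept it in forward and backward. |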
| w = torch.randn([3] * dim, device=device) |
| w = w.expand([nc, int(nc / groups)] + list(w.shape)) |
| w = w.detach().requires_grad_() |
| x = torch.randn( |
| [1, nc] + ([5] * dim), device=device, requires_grad=True |
| ) |
| y = getattr(F, f"conv{dim}d")(x, w, groups=groups) |
| y.sum().backward() |
| y = getattr(F, f"conv_transpose{dim}d")(x, w, groups=groups) |
| y.sum().backward() |
| |
| def test_conv_noncontig_weights_and_bias(self, device): |
| # need floats to exercise https://github.com/pytorch/pytorch/issues/16018 |
| for bias in [True, False]: |
| conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=bias).to( |
| device, torch.float |
| ) |
| |
| input_nc = torch.randn( |
| (1, 3, 224, 224, 2), device=device, dtype=torch.float |
| )[:, :, :, :, 1] |
| input_c = input_nc.contiguous() |
| |
| weight_nc = torch.randn((64, 3, 7, 7, 2), device=device, dtype=torch.float)[ |
| :, :, :, :, 1 |
| ] |
| conv1.weight = nn.Parameter(weight_nc) |
| weight_c = conv1.weight.contiguous() |
| |
| if bias: |
| bias_nc = torch.randn((64, 2), device=device, dtype=torch.float)[:, 1] |
| conv1.bias = nn.Parameter(bias_nc) |
| bias_c = conv1.bias.contiguous() |
| |
| out1 = conv1(input_nc) |
| conv1.weight = nn.Parameter(weight_c) |
| if bias: |
| conv1.bias = nn.Parameter(bias_c) |
| out2 = conv1(input_c) |
| self.assertEqual(out1, out2) |
| |
| @onlyCUDA |
| @largeTensorTest("12GB") |
| @skipIfRocmVersionLessThan((6, 0)) |
| def test_conv_transposed_large(self, device): |
| dtype = torch.half if self.device_type == "cuda" else torch.float |
| conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype) |
| input_large = torch.randn(4096, 1, 512, 1024, dtype=dtype, device=device) |
| # forward |
| ret = conv(input_large) |
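| # Compare slices of the full-batch output against running the conv on the |
| # corresponding 1024-sample chunks. |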
| maxdiff0 = ( |
| (ret.narrow(0, 0, 1024) - conv(input_large.narrow(0, 0, 1024))) |
| .abs_() |
| .max() |
| .item() |
| ) |
| maxdiff1 = ( |
| (ret.narrow(0, 1024, 1024) - conv(input_large.narrow(0, 1024, 1024))) |
| .abs_() |
| .max() |
| .item() |
| ) |
| maxdiff2 = ( |
| (ret.narrow(0, 2048, 1024) - conv(input_large.narrow(0, 2048, 1024))) |
| .abs_() |
| .max() |
| .item() |
| ) |
| maxdiff3 = ( |
| (ret.narrow(0, 3072, 1024) - conv(input_large.narrow(0, 3072, 1024))) |
| .abs_() |
| .max() |
| .item() |
| ) |
| if self.device_type == "cuda": |
| # cuDNN may use algorithms such as FFT that don't guarantee a diff of 0 |
| self.assertEqual(maxdiff0, 0, atol=2e-3, rtol=1e-5) |
| self.assertEqual(maxdiff1, 0, atol=2e-3, rtol=1e-5) |
| self.assertEqual(maxdiff2, 0, atol=2e-3, rtol=1e-5) |
| self.assertEqual(maxdiff3, 0, atol=2e-3, rtol=1e-5) |
| else: |
| self.assertEqual(maxdiff0, 0) |
| self.assertEqual(maxdiff1, 0) |
| self.assertEqual(maxdiff2, 0) |
| self.assertEqual(maxdiff3, 0) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @largeTensorTest("12GB") |
| def test_conv_large(self, device): |
| dtype = torch.half if self.device_type == "cuda" else torch.float |
| conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype) |
| input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device) |
| # forward |
| ret = conv(input_large) |
| self.assertEqual(ret[:2048], conv(input_large[:2048])) |
| self.assertEqual(ret[2048:4096], conv(input_large[2048:4096])) |
| self.assertEqual(ret[4096:], conv(input_large[4096:])) |
| |
| # backward |
| conv.zero_grad() |
| # When computing the backward, we use `max(dim=1)` to create some sparsity. |
| # Without this sparsity, the rounding error would be too large (as large as |
| # 1e-5) to satisfy the criterion (1e-6) of `assertEqual`. |
| ret.view(4097, -1).max(dim=1).values.sum().backward() |
| del ret |
| grad1 = conv.weight.grad.detach().clone() |
| conv.zero_grad() |
| conv(input_large[:2048]).view(2048, -1).max(dim=1).values.sum().backward() |
| conv(input_large[2048:4096]).view(2048, -1).max(dim=1).values.sum().backward() |
| conv(input_large[4096:]).view(1, -1).max(dim=1).values.sum().backward() |
| grad2 = conv.weight.grad.detach().clone() |
| # The gradients are on the order of hundreds; scale them to the order of |
| # one so that they can be compared. |
| scale = 1 / grad2.abs().mean() |
| grad1 = grad1 * scale |
| grad2 = grad2 * scale |
| self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @largeTensorTest("20GB", "cpu") |
| @largeTensorTest("60GB", "cuda") |
| def test_conv_large_batch_1(self, device): |
| in_channels = 514 |
| dim = 2048 |
| out_channels = 1 |
| kernel_size = 3 |
| stride = 1 |
| padding = 1 |
| |
| input_tensor = torch.ones(1, in_channels, dim, dim).cuda().half() |
| model = ( |
| nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) |
| .cuda() |
| .half() |
| ) |
| output = model(input_tensor) |
| model_cpu = model.cpu().float() |
| # Module.cpu()/.float() modify the module in place, so model_cpu is the |
| # same object as model; use the explicit name for clarity. |
| output_cpu = model_cpu(input_tensor.float().cpu()) |
| self.assertEqual(output.cpu().float(), output_cpu, atol=1e-3, rtol=1e-3) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @largeTensorTest("24GB", "cpu") |
| @largeTensorTest("20GB", "cuda") |
| def test_conv3d_large_batch_1(self, device): |
| x = torch.rand(1, 32, 512, 512, 256) |
| m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False) |
| yref = m(x) |
| y = m.to(device=device)(x.to(device=device)) |
| self.assertEqual(yref, y.cpu()) |
| |
| @onlyCUDA |
| @skipCUDAIfNoCudnn |
| def test_contig_wrong_stride_cudnn(self, device): |
| # x has to have batch_size 1 to test contiguous checks |
| x = torch.randn(1, 16, 5, 5, device=device) |
| stride = list(x.stride()) |
| stride[0] = 20 |
| # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 |
| x.set_(x.storage(), 0, x.size(), stride) |
| self.assertTrue(x.is_contiguous()) |
| F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device)) |
| F.conv2d(x, torch.randn(1, 16, 1, 1, device=device)) |
| |
| @onlyCUDA |
| @tf32_on_and_off(0.005) |
| def test_Conv2d_size_1_kernel(self, device): |
| x_cpu = torch.randn(2, 3, 5, 5) |
| conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1) |
| y_cpu = conv_cpu(x_cpu) |
| y = torch.rand_like(y_cpu) |
| y_cpu.backward(y) |
| |
| with cudnn.flags(enabled=False): |
| conv_cuda = torch.nn.Conv2d(3, 3, kernel_size=1).to(device) |
| conv_cuda.bias.data.copy_(conv_cpu.bias.data) |
| conv_cuda.weight.data.copy_(conv_cpu.weight.data) |
| y_cuda = conv_cuda(x_cpu.to(device)) |
| y_cuda.backward(y.to(device)) |
| |
| self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) |
| self.assertEqual( |
| conv_cpu.bias.grad.data, |
| conv_cuda.bias.grad.data, |
| atol=1e-5, |
| rtol=0, |
| exact_device=False, |
| ) |
| self.assertEqual( |
| conv_cpu.weight.grad.data, |
| conv_cuda.weight.grad.data, |
| atol=1e-5, |
| rtol=0, |
| exact_device=False, |
| ) |
| |
| @onlyCUDA |
| @tf32_on_and_off(0.005) |
| def test_ConvTranspose2d_size_1_kernel(self, device): |
| x_cpu = torch.randn(2, 3, 5, 5) |
| conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1) |
| y_cpu = conv_cpu(x_cpu) |
| y = torch.rand_like(y_cpu) |
| y_cpu.backward(y) |
| |
| with cudnn.flags(enabled=False): |
| conv_cuda = torch.nn.ConvTranspose2d(3, 3, kernel_size=1).to(device) |
| conv_cuda.bias.data.copy_(conv_cpu.bias.data) |
| conv_cuda.weight.data.copy_(conv_cpu.weight.data) |
| y_cuda = conv_cuda(x_cpu.to(device)) |
| y_cuda.backward(y.to(device)) |
| |
| self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) |
| self.assertEqual( |
| conv_cpu.bias.grad.data, |
| conv_cuda.bias.grad.data, |
| atol=1e-5, |
| rtol=0, |
| exact_device=False, |
| ) |
| self.assertEqual( |
| conv_cpu.weight.grad.data, |
| conv_cuda.weight.grad.data, |
| atol=1e-5, |
| rtol=0, |
| exact_device=False, |
| ) |
| |
| @onlyCUDA |
| def test_ConvTranspose3d_size_1_kernel(self, device): |
| with set_default_dtype(torch.double): |
| x_cpu = torch.randn(2, 3, 3, 5, 5) |
| conv_cpu = torch.nn.ConvTranspose3d(3, 3, kernel_size=1) |
| y_cpu = conv_cpu(x_cpu) |
| y = torch.rand_like(y_cpu) |
| y_cpu.backward(y) |
| |
| with cudnn.flags(enabled=False): |
| conv_cuda = torch.nn.ConvTranspose3d(3, 3, kernel_size=1).to(device) |
| conv_cuda.bias.data.copy_(conv_cpu.bias.data) |
| conv_cuda.weight.data.copy_(conv_cpu.weight.data) |
| y_cuda = conv_cuda(x_cpu.to(device)) |
| y_cuda.backward(y.to(device)) |
| |
| self.assertEqual(y_cpu, y_cuda, atol=1e-5, rtol=0, exact_device=False) |
| self.assertEqual( |
| conv_cpu.bias.grad.data, |
| conv_cuda.bias.grad.data, |
| atol=1e-5, |
| rtol=0, |
| exact_device=False, |
| ) |
| self.assertEqual( |
| conv_cpu.weight.grad.data, |
| conv_cuda.weight.grad.data, |
| atol=1e-5, |
| rtol=0, |
| exact_device=False, |
| ) |
| |
| @dtypesIfCUDA( |
| *floating_types_and(torch.half, *[torch.bfloat16] if AMPERE_OR_ROCM else []) |
| ) |
| @dtypes(torch.float) |
| @torch.backends.cudnn.flags(enabled=True, benchmark=False) |
| @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") |
| def test_Conv2d_naive_groups(self, device, dtype): |
| # Check that a grouped convolution matches two half convolutions |
| m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype) |
| i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True) |
| output = m(i) |
| grad_output = torch.randn(2, 4, 4, 4, device=device, dtype=dtype) |
| output.backward(grad_output) |
| |
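| # m1 and m2 each take half of the input channels; their concatenated |
| # outputs and gradients should match the grouped conv above. |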
| m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) |
| m1.weight.data.copy_(m.weight.data[:2]) |
| m1.bias.data.copy_(m.bias.data[:2]) |
| i1 = i.data[:, :2].contiguous().requires_grad_(True) |
| output1 = m1(i1) |
| output1.backward(grad_output[:, :2].contiguous()) |
| |
| m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) |
| m2.weight.data.copy_(m.weight.data[2:]) |
| m2.bias.data.copy_(m.bias.data[2:]) |
| i2 = i.data[:, 2:].contiguous().requires_grad_(True) |
| output2 = m2(i2) |
| output2.backward(grad_output[:, 2:].contiguous()) |
| |
| self.assertEqual(output, torch.cat([output1, output2], 1)) |
| self.assertEqual( |
| i.grad.data, |
| torch.cat([i1.grad.data, i2.grad.data], 1), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.bias.grad.data, |
| torch.cat([m1.bias.grad.data, m2.bias.grad.data], 0), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| self.assertEqual( |
| m.weight.grad.data, |
| torch.cat([m1.weight.grad.data, m2.weight.grad.data], 0), |
| atol=dtype2prec_DONTUSE[dtype], |
| rtol=0, |
| ) |
| |
| @dtypes(torch.double, torch.cdouble) |
| def test_Conv2d_backward_depthwise(self, device, dtype): |
| x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True) |
| weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True) |
| |
| def conv2d_depthwise(x, weight): |
| return torch.nn.functional.conv2d( |
| x, weight, bias=None, stride=(1, 10), groups=2 |
| ) |
| |
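| # groups == in_channels == 2 makes this a depthwise convolution; run |
| # gradcheck with the cuDNN backend both disabled and enabled. |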
| for cudnn_enabled in [False, True]: |
| with torch.backends.cudnn.flags(enabled=cudnn_enabled): |
| torch.autograd.gradcheck(conv2d_depthwise, (x, weight)) |
| |
| @onlyCPU |
| @dtypes(torch.float, torch.double) |
| def test_conv_thnn_nhwc(self, device, dtype): |
| def helper( |
| mod, |
| n, |
| c, |
| h, |
| w, |
| out_channels, |
| kernel_size, |
| dilation, |
| groups, |
| input_format, |
| weight_format, |
| ): |
| input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to( |
| memory_format=input_format |
| ) |
| input.requires_grad_() |
| conv = mod( |
| c, out_channels, kernel_size, dilation=dilation, groups=groups |
| ).to(device="cpu", dtype=dtype, memory_format=weight_format) |
| for p in conv.parameters(): |
| p.data = torch.randint_like(p, -3, 3) |
| |
| ref_input = input.detach().clone().contiguous().requires_grad_() |
| ref_conv = mod( |
| c, out_channels, kernel_size, dilation=dilation, groups=groups |
| ) |
| # load_state_dict will restore the stride & memory_layout on ref_conv.weight. |
| ref_conv.load_state_dict(conv.state_dict()) |
| ref_conv = ref_conv.to( |
| device="cpu", dtype=dtype, memory_format=torch.contiguous_format |
| ) |
| |
| out = conv(input) |
| ref_out = ref_conv(ref_input) |
| |
| grad = torch.randint_like(out, -3, 3) |
| ref_grad = grad.detach().clone().contiguous() |
| |
| out.backward(grad) |
| ref_out.backward(ref_grad) |
| |
| self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) |
| self.assertTrue(ref_out.is_contiguous()) |
| self.assertEqual(out, ref_out, exact_dtype=False) |
| self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) |
| self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) |
| self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) |
| |
| with torch.backends.mkldnn.flags(enabled=False): |
| formats = [ |
| [torch.channels_last, torch.channels_last], |
| [torch.channels_last, torch.contiguous_format], |
| [torch.contiguous_format, torch.channels_last], |
| ] |
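| # Mixed memory-format combinations for input and weight; the all-contiguous |
| # (input, weight) case is omitted here. |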
| for input_format, weight_format in formats: |
| # non-dilated conv: thnn_conv2d normal path (with im2col) |
| helper( |
| nn.Conv2d, |
| 2, |
| 8, |
| 4, |
| 4, |
| out_channels=4, |
| kernel_size=3, |
| dilation=1, |
| groups=1, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| helper( |
| nn.Conv2d, |
| 2, |
| 8, |
| 4, |
| 4, |
| out_channels=8, |
| kernel_size=3, |
| dilation=1, |
| groups=8, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| # Test when input channels == 1 and the input is not converted to channels last |
| helper( |
| nn.Conv2d, |
| 2, |
| 1, |
| 10, |
| 10, |
| out_channels=8, |
| kernel_size=3, |
| dilation=1, |
| groups=1, |
| input_format=torch.contiguous_format, |
| weight_format=torch.channels_last, |
| ) |
| # non-dilated conv: thnn_conv2d fast path (skip im2col) |
| helper( |
| nn.Conv2d, |
| 1, |
| 16, |
| 56, |
| 56, |
| out_channels=16, |
| kernel_size=1, |
| dilation=1, |
| groups=1, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| # Per-group ic == oc == 1 and the kernel is 1x1, so the weight's memory |
| # format is ambiguous; stick the input to channels last to activate the |
| # channels-last path. |
| helper( |
| nn.Conv2d, |
| 1, |
| 16, |
| 56, |
| 56, |
| out_channels=16, |
| kernel_size=1, |
| dilation=1, |
| groups=16, |
| input_format=torch.channels_last, |
| weight_format=weight_format, |
| ) |
| # dilated conv: slow_conv_dilated2d |
| helper( |
| nn.Conv2d, |
| 2, |
| 8, |
| 11, |
| 13, |
| out_channels=16, |
| kernel_size=3, |
| dilation=2, |
| groups=1, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| helper( |
| nn.Conv2d, |
| 2, |
| 16, |
| 11, |
| 13, |
| out_channels=32, |
| kernel_size=3, |
| dilation=2, |
| groups=16, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| # transposed-conv: slow_conv_transpose2d |
| helper( |
| nn.ConvTranspose2d, |
| 2, |
| 8, |
| 4, |
| 4, |
| out_channels=4, |
| kernel_size=3, |
| dilation=1, |
| groups=1, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| helper( |
| nn.ConvTranspose2d, |
| 2, |
| 8, |
| 4, |
| 4, |
| out_channels=8, |
| kernel_size=3, |
| dilation=1, |
| groups=8, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| helper( |
| nn.ConvTranspose2d, |
| 1, |
| 16, |
| 56, |
| 56, |
| out_channels=16, |
| kernel_size=1, |
| dilation=1, |
| groups=1, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| helper( |
| nn.ConvTranspose2d, |
| 1, |
| 16, |
| 56, |
| 56, |
| out_channels=32, |
| kernel_size=1, |
| dilation=1, |
| groups=16, |
| input_format=input_format, |
| weight_format=weight_format, |
| ) |
| |
| @onlyCUDA |
| @skipCUDAIfRocmVersionLessThan((4, 3)) |
| @skipCUDAIfNotMiopenSuggestNHWC |
| @skipCUDAIfCudnnVersionLessThan(7603) |
| @dtypes(torch.half, torch.float, torch.cfloat) |
| def test_conv_cudnn_nhwc(self, device, dtype): |
| def helper(n, c, h, w, out_channels, kernel_size, groups): |
| input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to( |
| memory_format=torch.channels_last |
| ) |
| input.requires_grad_() |
| conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to( |
| device="cuda", dtype=dtype, memory_format=torch.channels_last |
| ) |
| for p in conv.parameters(): |
| p.data = torch.randint_like(p, -3, 3) |
| |
| # use FP64 channels-first conv as reference |
| ref_input = input.detach().clone().contiguous().double().requires_grad_() |
| ref_conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups) |
| # load_state_dict will restore the stride & memory_layout on ref_conv.weight. |
| ref_conv.load_state_dict(conv.state_dict()) |
| ref_conv = ref_conv.to( |
| device="cuda", dtype=torch.double, memory_format=torch.contiguous_format |
| ) |
| |
| out = conv(input) |
| ref_out = ref_conv(ref_input) |
| |
| grad = torch.randint_like(out, -3, 3) |
| ref_grad = grad.detach().clone().double().contiguous() |
| |
| out.backward(grad) |
| ref_out.backward(ref_grad) |
| |
| self.assertTrue(out.is_contiguous(memory_format=torch.channels_last)) |
| self.assertTrue(input.grad.is_contiguous(memory_format=torch.channels_last)) |
| self.assertTrue( |
| conv.weight.grad.is_contiguous(memory_format=torch.channels_last) |
| ) |
| |
| self.assertTrue(ref_out.is_contiguous()) |
| self.assertTrue(ref_input.grad.is_contiguous()) |
| self.assertTrue(ref_conv.weight.grad.is_contiguous()) |
| |
| self.assertEqual(out, ref_out, exact_dtype=False) |
| self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) |
| self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) |
| self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) |
| |
| helper(2, 8, 4, 4, out_channels=4, kernel_size=3, groups=1) |
| helper(2, 8, 4, 4, out_channels=8, kernel_size=3, groups=8) |
| helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=1) |
| helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @skipCUDAIfCudnnVersionLessThan(8005) |
| @dtypes(torch.half, torch.float) |
| def test_conv_cudnn_ndhwc(self, device, dtype): |
| def helper(n, c, d, h, w, out_channels, kernel_size, groups): |
| input = torch.randint( |
| -2, 2, (n, c, d, h, w), dtype=dtype, device=device |
| ).to(memory_format=torch.channels_last_3d) |
| input.requires_grad_() |
| conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups).to( |
| device="cuda", dtype=dtype, memory_format=torch.channels_last_3d |
| ) |
| for p in conv.parameters(): |
| p.data = torch.randint_like(p, -2, 2) |
| |
| # use FP64 channels-first conv as reference |
| ref_input = input.detach().clone().contiguous().double().requires_grad_() |
| ref_conv = nn.Conv3d(c, out_channels, kernel_size, groups=groups) |
| # load_state_dict will restore the stride & memory_layout on ref_conv.weight. |
| ref_conv.load_state_dict(conv.state_dict()) |
| ref_conv = ref_conv.to( |
| device="cuda", dtype=torch.double, memory_format=torch.contiguous_format |
| ) |
| |
| out = conv(input) |
| ref_out = ref_conv(ref_input) |
| |
| grad = torch.randint_like(out, -2, 2) |
| ref_grad = grad.detach().clone().double().contiguous() |
| |
| out.backward(grad) |
| ref_out.backward(ref_grad) |
| |
| self.assertTrue(out.is_contiguous(memory_format=torch.channels_last_3d)) |
| self.assertTrue( |
| input.grad.is_contiguous(memory_format=torch.channels_last_3d) |
| ) |
| self.assertTrue( |
| conv.weight.grad.is_contiguous(memory_format=torch.channels_last_3d) |
| ) |
| |
| self.assertTrue(ref_out.is_contiguous()) |
| self.assertTrue(ref_input.grad.is_contiguous()) |
| self.assertTrue(ref_conv.weight.grad.is_contiguous()) |
| |
| self.assertEqual(out, ref_out, exact_dtype=False) |
| self.assertEqual(conv.weight.grad, ref_conv.weight.grad, exact_dtype=False) |
| self.assertEqual(conv.bias.grad, ref_conv.bias.grad, exact_dtype=False) |
| self.assertEqual(input.grad, ref_input.grad, exact_dtype=False) |
| |
| helper(2, 8, 4, 4, 4, out_channels=4, kernel_size=3, groups=1) |
| helper(2, 8, 4, 4, 4, out_channels=8, kernel_size=3, groups=8) |
| helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=1) |
| helper(1, 16, 18, 18, 18, out_channels=16, kernel_size=3, groups=16) |
| |
| def _run_conv( |
| self, |
| layer, |
| device, |
| inp, |
| grad, |
| ref_conv, |
| ref_input, |
| ref_out, |
| input_format, |
| weight_format, |
| grad_format, |
| output_format, |
| ): |
| conv = ( |
| layer(inp.size(1), grad.size(1), ref_conv.weight.size(2)).float().to(device) |
| ) |
| # load_state_dict will restore the stride & memory_layout on ref_conv.weight. |
| conv.load_state_dict(ref_conv.state_dict()) |
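| # contiguous(memory_format=...) can be a no-op for ambiguous shapes (e.g. |
| # size-1 dims), so resize_ with an explicit memory_format is used below to |
| # force the desired strides on the weight, input, and grad. |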
| weight_data = ( |
| conv.weight.detach().clone().contiguous(memory_format=weight_format) |
| ) |
| conv.weight.data = weight_data.resize_( |
| weight_data.size(), memory_format=weight_format |
| ) |
| input = inp.clone().contiguous(memory_format=input_format) |
| input.resize_(input.size(), memory_format=input_format) |
| input = input.requires_grad_() |
| grad = grad.contiguous(memory_format=grad_format) |
| grad.resize_(grad.size(), memory_format=grad_format) |
| out = conv(input) |
| out.backward(grad) |
| self.assertTrue(out.is_contiguous(memory_format=output_format)) |
| self.assertEqual(out, ref_out) |
| self.assertEqual(conv.weight.grad, ref_conv.weight.grad) |
| self.assertEqual(conv.bias.grad, ref_conv.bias.grad) |
| self.assertEqual(input.grad, ref_input.grad) |
| |
| def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device): |
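| # Run an NCHW (contiguous) reference once, then sweep every combination of |
| # weight / grad / input memory formats through _run_conv and verify the |
| # expected output memory format. |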
| data = torch.randint(1, 10, (n, c, h, w), dtype=torch.float32, device=device) |
| ref_input = data.clone().contiguous().requires_grad_(True) |
| ref_conv = layer(c, k, filter_size).float().to(device) |
| ref_out = ref_conv(ref_input) |
| grad = torch.randint(1, 10, ref_out.size(), dtype=torch.float32, device=device) |
| ref_out.backward(grad) |
| |
| for w_f in [torch.contiguous_format, torch.channels_last]: |
| for g_f in [torch.contiguous_format, torch.channels_last]: |
| for input_format in [torch.contiguous_format, torch.channels_last]: |
| output_format = torch.contiguous_format |
| # cuDNN versions older than 7603 have channels_last support disabled |
| if torch.backends.cudnn.version() >= 7603: |
| if input_format == torch.channels_last: |
| output_format = torch.channels_last |
| # An N111 weight is ambiguous between memory formats, so channels_last |
| # cannot be propagated from the weight alone. |
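| # (e.g. a (4, 1, 1, 1) weight has strides (1, 1, 1, 1) under both contiguous |
| # and channels_last, so neither format can be inferred from it) |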
| if w_f == torch.channels_last: |
| if layer == nn.Conv2d and filter_size * c != 1: |
| output_format = torch.channels_last |
| if layer == nn.ConvTranspose2d and filter_size * k != 1: |
| output_format = torch.channels_last |
| self._run_conv( |
| layer, |
| device, |
| data, |
| grad, |
| ref_conv, |
| ref_input, |
| ref_out, |
| input_format, |
| w_f, |
| g_f, |
| output_format, |
| ) |
| |
| @onlyCUDA |
| @skipCUDAIfRocmVersionLessThan((4, 3)) |
| @skipCUDAIfNotMiopenSuggestNHWC |
| @skipCUDAIfCudnnVersionLessThan(7603) |
| @tf32_on_and_off(0.05) |
| def test_conv_cudnn_mismatch_memory_format(self, device): |
| configs = [ |
| [4, 2, 8, 8, 4, 2], |
| [4, 1, 8, 8, 4, 2], |
| [1, 1, 8, 8, 4, 2], |
| [4, 2, 2, 8, 4, 1], |
| [4, 2, 1, 8, 4, 1], |
| [4, 2, 8, 8, 4, 1], |
| [4, 1, 8, 8, 4, 1], |
| ] |
| for n, c, h, w, k, filter_size in configs: |
| self._test_conv_cudnn_nhwc_nchw( |
| nn.Conv2d, n, c, h, w, k, filter_size, device |
| ) |
| self._test_conv_cudnn_nhwc_nchw( |
| nn.ConvTranspose2d, n, c, h, w, k, filter_size, device |
| ) |
| |
| # torch.half errors out on Windows with CUDA 10.1 + cuDNN 7.6.4, |
| # returning CUDNN_STATUS_BAD_PARAM. |
| # That dtype is disabled here for now (see issue #33918). |
| @onlyCUDA |
| @skipCUDAIfNoCudnn |
| @dtypes(torch.float, torch.double) |
| def test_conv_cudnn_nhwc_support(self, device, dtype): |
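| # The 1x1 spatial extent makes the input's memory format ambiguous; with a |
| # channels_last weight the output is still expected to come out channels_last |
| # and backward should succeed. |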
| input = torch.randn( |
| (1, 16, 1, 1), dtype=dtype, device="cuda", requires_grad=True |
| ) |
| weight = torch.randn( |
| (8, 16, 3, 3), dtype=dtype, device=device, requires_grad=True |
| ) |
| weight = weight.to(memory_format=torch.channels_last) |
| o = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) |
| self.assertTrue(o.is_contiguous(memory_format=torch.channels_last)) |
| o.sum().backward() |
| |
| # Test that the faster algorithms used for inference (under torch.no_grad()) |
| # produce the same results as the regular autograd-enabled path. |
| # Validates the depthwise3x3 bug reported in https://github.com/pytorch/pytorch/issues/60176 |
| @onlyCPU |
| @dtypes(torch.float) |
| def test_conv2d_no_grad(self, device, dtype): |
| for batch in [1, 2, 3]: |
| for groups in [1, 2, 4]: |
| input = torch.rand(batch, groups, 8, 8, dtype=dtype, device=device) |
| m = nn.Conv2d( |
| groups, |
| 8, |
| kernel_size=(3, 3), |
| groups=groups, |
| dtype=dtype, |
| device=device, |
| ) |
| with torch.no_grad(): |
| output_ng = m(input) |
| output = m(input) |
| self.assertEqual(output, output_ng, rtol=1e-2, atol=1e-5) |
| |
| @onlyCUDA |
| @skipCUDAIfNoCudnn |
| @dtypes(torch.float, torch.float16) |
| @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) |
| def test_cudnn_convolution_relu(self, device, dtype): |
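| # Compare the fused conv+ReLU kernel (cudnn_convolution_relu, or |
| # miopen_convolution_relu on ROCm) against an unfused conv2d followed by relu(). |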
| for batch, groups, image_size, kernel_size, memory_format in product( |
| (1, 2, 3), |
| (1, 2, 4), |
| ((1, 1), (8, 8)), |
| ((1, 1), (3, 3)), |
| (torch.channels_last, torch.contiguous_format), |
| ): |
| if image_size[0] < kernel_size[0]: |
| continue |
| inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) |
| w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) |
| conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) |
| inp = inp.to(memory_format=memory_format) |
| w = w.to(memory_format=memory_format) |
| if torch.version.hip: |
| cudnn_out = torch.miopen_convolution_relu( |
| inp, w, None, (1, 1), (0, 0), (1, 1), 1 |
| ) |
| else: |
| cudnn_out = torch.cudnn_convolution_relu( |
| inp, w, None, (1, 1), (0, 0), (1, 1), 1 |
| ) |
| self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) |
| if tf32_is_not_fp32() and dtype == torch.float: |
| self.assertEqual(conv2d_out.relu(), cudnn_out, atol=4e-3, rtol=0.006) |
| else: |
| self.assertEqual(conv2d_out.relu(), cudnn_out) |
| |
| @onlyCUDA |
| @skipCUDAIfNoCudnn |
| @dtypes(torch.float, torch.float16) |
| @precisionOverride({torch.half: 0.002, torch.float: 1e-4}) |
| def test_cudnn_convolution_add_relu(self, device, dtype): |
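| # Same as above, but for the fused conv+add+ReLU kernel: the reference is |
| # F.relu(conv2d_out + alpha * z). |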
| for batch, groups, image_size, kernel_size, memory_format in product( |
| (1, 2, 3), |
| (1, 2, 4), |
| ((1, 1), (8, 8)), |
| ((1, 1), (3, 3)), |
| (torch.channels_last, torch.contiguous_format), |
| ): |
| if image_size[0] < kernel_size[0]: |
| continue |
| inp = torch.rand(batch, groups, *image_size, dtype=dtype, device=device) |
| w = torch.randn(8, groups, *kernel_size, dtype=dtype, device=device) |
| conv2d_out = torch.conv2d(inp, w, None, (1, 1), (0, 0), (1, 1), 1) |
| alpha = 2.0 |
| z = torch.randn_like(conv2d_out) |
| |
| inp = inp.to(memory_format=memory_format) |
| w = w.to(memory_format=memory_format) |
| z = z.to(memory_format=memory_format) |
| if torch.version.hip: |
| cudnn_out = torch.miopen_convolution_add_relu( |
| inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1 |
| ) |
| else: |
| cudnn_out = torch.cudnn_convolution_add_relu( |
| inp, w, z, alpha, None, (1, 1), (0, 0), (1, 1), 1 |
| ) |
| |
| self.assertTrue(cudnn_out.is_contiguous(memory_format=memory_format)) |
| if tf32_is_not_fp32() and dtype == torch.float: |
| self.assertEqual( |
| F.relu(conv2d_out + alpha * z), cudnn_out, atol=2e-3, rtol=0.006 |
| ) |
| else: |
| self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @skipCUDAIfCudnnVersionLessThan(7603) |
| def test_convert_conv2d_weight_memory_format(self, device): |
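| # convert_conv2d_weight_memory_format converts the convolution weights inside |
| # the module tree to the requested memory format; the Sequential's output is |
| # then expected to follow that format. |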
| input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device) |
| model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float() |
| for memory_format in [torch.channels_last, torch.contiguous_format]: |
| model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) |
| out = model(input) |
| self.assertTrue(out.is_contiguous(memory_format=memory_format)) |
| |
| model = ( |
| nn.Sequential(nn.ConvTranspose2d(8, 4, 3), nn.BatchNorm2d(4)) |
| .to(device) |
| .float() |
| ) |
| for memory_format in [torch.channels_last, torch.contiguous_format]: |
| model = nn.utils.convert_conv2d_weight_memory_format(model, memory_format) |
| out = model(input) |
| self.assertTrue(out.is_contiguous(memory_format=memory_format)) |
| |
| @onlyCUDA |
| @skipCUDAIfRocm |
| @skipCUDAIfCudnnVersionLessThan(7603) |
| def test_convert_conv3d_weight_memory_format(self, device): |
| input = torch.randint( |
| 1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device |
| ) |
| model = ( |
| nn.Sequential(nn.ConvTranspose3d(8, 4, 3), nn.BatchNorm3d(4)) |
| .to(device) |
| .float() |
| ) |
| for memory_format in [torch.channels_last_3d, torch.contiguous_format]: |
| model = nn.utils.convert_conv3d_weight_memory_format(model, memory_format) |
| out = model(input) |
| self.assertTrue(out.is_contiguous(memory_format=memory_format)) |
| |
| def test_conv_double_backward_strided_with_3D_input_and_weight(self, device): |
| # Test that _convolution_double_backward() outputs the correct grad shapes |
| # for 3D input / weight when stride > 1. This is an ad-hoc regression test for a |
| # specific case that was uncovered during the convolution consolidation effort. |
| # The test can be safely deleted if _convolution_double_backward() is removed. |
| |
| input = torch.randn(2, 3, 6, device=device) |
| weight = torch.randn(3, 3, 3, device=device) |
| bias = torch.randn(3, device=device) |
| stride = (2,) |
| padding = (1,) |
| dilation = (1,) |
| transposed = False |
| output_padding = (0,) |
| groups = 1 |
| output = torch.ops.aten.convolution( |
| input, |
| weight, |
| bias, |
| stride, |
| padding, |
| dilation, |
| transposed, |
| output_padding, |
| groups, |
| ) |
| |
| ggI = torch.randn(input.shape, device=device) |
| ggW = torch.randn(weight.shape, device=device) |
| ggB = torch.randn(bias.shape, device=device) |
| gO = torch.randn(output.shape, device=device) |
| output_mask = [True, True, True] |
| ( |
| grad_grad_output, |
| grad_input, |
| grad_weight, |
| ) = torch.ops.aten._convolution_double_backward( |
| ggI, |
| ggW, |
| ggB, |
| gO, |
| weight, |
| input, |
| stride, |
| padding, |
| dilation, |
| transposed, |
| output_padding, |
| groups, |
| output_mask, |
| ) |
| |
| # Make sure the correct shapes are computed. |
| self.assertEqual(grad_grad_output.shape, gO.shape) |
| self.assertEqual(grad_input.shape, input.shape) |
| self.assertEqual(grad_weight.shape, weight.shape) |
| |
| @onlyCUDA |
| @largeTensorTest("40GB") |
| @largeTensorTest("24GB", "cpu") |
| def test_conv3d_64bit_indexing(self, device): |
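| # 1 * 32 * 512 * 512 * 256 = 2**31 elements, past the 32-bit indexing limit; |
| # compare the GPU result against a CPU reference. |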
| x = torch.rand(1, 32, 512, 512, 256) |
| m = torch.nn.Conv3d(32, 1, kernel_size=1, padding=0, stride=1, bias=False) |
| yref = m(x) |
| y = m.to(device=device)(x.to(device=device)) |
| self.assertEqual(yref, y) |
| |
| |
| instantiate_device_type_tests(TestConvolutionNNDeviceType, globals()) |
| instantiate_parametrized_tests(TestConvolutionNN) |
| |
| if __name__ == "__main__": |
| run_tests() |