ENH Enables No-batch for *Pad1d Modules (#61060)
Summary:
Toward https://github.com/pytorch/pytorch/issues/60585
This PR adds a `single_batch_reference_fn` that uses the single batch implementation to check no-batch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61060
Reviewed By: mrshenli
Differential Revision: D29739823
Pulled By: jbschlosser
fbshipit-source-id: d90d88a3671177a647171801cc6ec7aa3df35482
diff --git a/test/test_nn.py b/test/test_nn.py
index 49d1cde..91afa62 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13050,11 +13050,6 @@
(torch.nn.ReplicationPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
self._test_module_empty_input(mod, inp, check_size=False)
- with self.assertRaisesRegex(NotImplementedError, 'Only 3D'):
- mod = torch.nn.ReplicationPad1d(2)
- inp = torch.randn(3, 10, device=device, dtype=dtype)
- mod(inp)
-
with self.assertRaisesRegex(RuntimeError, 'Expected 2D or 3D'):
mod = torch.nn.ReplicationPad1d(2)
inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
diff --git a/torch/csrc/api/include/torch/nn/functional/padding.h b/torch/csrc/api/include/torch/nn/functional/padding.h
index 7a9554b..3781fb1 100644
--- a/torch/csrc/api/include/torch/nn/functional/padding.h
+++ b/torch/csrc/api/include/torch/nn/functional/padding.h
@@ -44,8 +44,7 @@
"Padding mode \"",
torch::enumtype::get_enum_name(mode),
"\" doesn't take in value argument");
- if (input.dim() == 3) {
- TORCH_CHECK(pad.size() == 2, "3D tensors expect 2 values for padding");
+ if (pad.size() == 2 && (input.dim() == 2 || input.dim() == 3)) {
if (c10::get_if<enumtype::kReflect>(&mode)) {
return torch::reflection_pad1d(input, pad);
} else if (c10::get_if<enumtype::kReplicate>(&mode)) {
@@ -78,7 +77,7 @@
TORCH_CHECK(false, "NotImplementedError");
}
} else {
- TORCH_CHECK(false, "Only 3D, 4D, 5D padding with non-constant padding are supported for now");
+ TORCH_CHECK(false, "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now");
}
}
}
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index df75e55..3621d70 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4154,8 +4154,7 @@
return _VF.constant_pad_nd(input, pad, value)
else:
assert value == 0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode)
- if input.dim() == 3:
- assert len(pad) == 2, "3D tensors expect 2 values for padding"
+ if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3):
if mode == "reflect":
return torch._C._nn.reflection_pad1d(input, pad)
elif mode == "replicate":
@@ -4187,7 +4186,7 @@
else:
raise NotImplementedError
else:
- raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now")
+ raise NotImplementedError("Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now")
# We define this function as _pad because it takes an argument
diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py
index 381141d..bc13afc 100644
--- a/torch/nn/modules/padding.py
+++ b/torch/nn/modules/padding.py
@@ -37,8 +37,8 @@
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
Shape:
- - Input: :math:`(N, C, W_{in})`
- - Output: :math:`(N, C, W_{out})` where
+ - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+ - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
@@ -188,8 +188,8 @@
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
Shape:
- - Input: :math:`(N, C, W_{in})`
- - Output: :math:`(N, C, W_{out})` where
+ - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+ - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
@@ -341,8 +341,8 @@
(:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
Shape:
- - Input: :math:`(N, C, W_{in})`
- - Output: :math:`(N, C, W_{out})` where
+ - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+ - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
:math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 6a05e4f..92f5089 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -2215,6 +2215,14 @@
module_name='ReflectionPad1d',
constructor_args=((1, 2),),
cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
+ input_size=(3, 8),
+ reference_fn=single_batch_reference_fn,
+ desc='batch',
+ ),
+ dict(
+ module_name='ReflectionPad1d',
+ constructor_args=((1, 2),),
+ cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
input_fn=lambda: torch.rand(2, 3, 8, dtype=torch.complex128, requires_grad=True),
skip_half=True,
desc='complex'
@@ -2257,6 +2265,14 @@
module_name='ReplicationPad1d',
constructor_args=((1, 2),),
cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
+ input_size=(3, 4),
+ reference_fn=single_batch_reference_fn,
+ desc='batch',
+ ),
+ dict(
+ module_name='ReplicationPad1d',
+ constructor_args=((1, 2),),
+ cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
input_fn=lambda: torch.rand(2, 3, 4, dtype=torch.complex128, requires_grad=True),
skip_half=True,
desc='complex'
@@ -2306,6 +2322,14 @@
module_name='ConstantPad1d',
constructor_args=((1, 2), 2.),
cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
+ input_size=(3, 4),
+ reference_fn=single_batch_reference_fn,
+ desc='batch',
+ ),
+ dict(
+ module_name='ConstantPad1d',
+ constructor_args=((1, 2), 2.),
+ cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
input_fn=lambda: torch.rand(2, 3, 4, dtype=torch.complex128, requires_grad=True),
skip_half=True,
desc='complex'