ENH: Enable no-batch-dim support for *Pad1d modules (#61060)

Summary:
Toward https://github.com/pytorch/pytorch/issues/60585

This PR enables `ReflectionPad1d`, `ReplicationPad1d`, and `ConstantPad1d` to accept unbatched 2D `(C, W)` inputs, and adds a `single_batch_reference_fn` that runs the single-item batched implementation as a reference against which the no-batch behavior is checked.
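
The helper itself lives in `torch/testing/_internal/common_nn.py` and its definition is not shown in this diff; the core idea is small enough to sketch here (a simplified illustration with a hypothetical two-argument signature, so the real helper's interface may differ):

    import torch

    def single_batch_reference_fn(input, module):
        # Sketch only: derive the expected no-batch output by running the
        # already-tested batched path on a batch of one, then stripping the
        # batch dimension again.
        return module(input.unsqueeze(0)).squeeze(0)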

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61060

Reviewed By: mrshenli

Differential Revision: D29739823

Pulled By: jbschlosser

fbshipit-source-id: d90d88a3671177a647171801cc6ec7aa3df35482
diff --git a/test/test_nn.py b/test/test_nn.py
index 49d1cde..91afa62 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13050,11 +13050,6 @@
                 (torch.nn.ReplicationPad3d(3), torch.randn(0, 3, 10, 10, 10, device=device, dtype=dtype))]:
             self._test_module_empty_input(mod, inp, check_size=False)
 
-        with self.assertRaisesRegex(NotImplementedError, 'Only 3D'):
-            mod = torch.nn.ReplicationPad1d(2)
-            inp = torch.randn(3, 10, device=device, dtype=dtype)
-            mod(inp)
-
         with self.assertRaisesRegex(RuntimeError, 'Expected 2D or 3D'):
             mod = torch.nn.ReplicationPad1d(2)
             inp = torch.randn(3, 0, 10, device=device, dtype=dtype)
diff --git a/torch/csrc/api/include/torch/nn/functional/padding.h b/torch/csrc/api/include/torch/nn/functional/padding.h
index 7a9554b..3781fb1 100644
--- a/torch/csrc/api/include/torch/nn/functional/padding.h
+++ b/torch/csrc/api/include/torch/nn/functional/padding.h
@@ -44,8 +44,7 @@
       "Padding mode \"",
       torch::enumtype::get_enum_name(mode),
       "\" doesn't take in value argument");
-    if (input.dim() == 3) {
-      TORCH_CHECK(pad.size() == 2, "3D tensors expect 2 values for padding");
+    if (pad.size() == 2 && (input.dim() == 2 || input.dim() == 3)) {
       if (c10::get_if<enumtype::kReflect>(&mode)) {
         return torch::reflection_pad1d(input, pad);
       } else if (c10::get_if<enumtype::kReplicate>(&mode)) {
@@ -78,7 +77,7 @@
         TORCH_CHECK(false, "NotImplementedError");
       }
     } else {
-      TORCH_CHECK(false, "Only 3D, 4D, 5D padding with non-constant padding are supported for now");
+      TORCH_CHECK(false, "Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now");
     }
   }
 }
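
With the new condition, the C++ functional dispatch selects the 1D kernels based on the pad length, so an unbatched `(C, W)` tensor takes the same code path as a batched `(N, C, W)` one. Because a 2D input can no longer trigger the old "Only 3D" message, the matching `assertRaisesRegex` check is dropped from `test_nn.py` above. A minimal equivalence check through the Python binding (illustrative, for a build containing this PR):

    import torch
    import torch.nn.functional as F

    x = torch.randn(3, 8)
    unbatched = F.pad(x, (1, 2), mode="reflect")             # (C, W) path, new in this PR
    batched = F.pad(x.unsqueeze(0), (1, 2), mode="reflect")  # (N, C, W) path, unchanged
    assert torch.equal(unbatched, batched.squeeze(0))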
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index df75e55..3621d70 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4154,8 +4154,7 @@
         return _VF.constant_pad_nd(input, pad, value)
     else:
         assert value == 0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode)
-        if input.dim() == 3:
-            assert len(pad) == 2, "3D tensors expect 2 values for padding"
+        if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3):
             if mode == "reflect":
                 return torch._C._nn.reflection_pad1d(input, pad)
             elif mode == "replicate":
@@ -4187,7 +4186,7 @@
             else:
                 raise NotImplementedError
         else:
-            raise NotImplementedError("Only 3D, 4D, 5D padding with non-constant padding are supported for now")
+            raise NotImplementedError("Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now")
 
 
 # We define this function as _pad because it takes an argument
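
The Python side mirrors the C++ change: before this PR, a 2D input with a non-constant mode raised `NotImplementedError`; now it is padded along the last dimension. A short before/after illustration (behavior as of builds containing this PR):

    import torch
    import torch.nn.functional as F

    x = torch.randn(3, 10)
    # Previously: NotImplementedError("Only 3D, 4D, 5D padding with
    # non-constant padding are supported for now").
    # Now the (C, W) input is padded along W.
    out = F.pad(x, (2, 2), mode="replicate")
    assert out.shape == (3, 14)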
diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py
index 381141d..bc13afc 100644
--- a/torch/nn/modules/padding.py
+++ b/torch/nn/modules/padding.py
@@ -37,8 +37,8 @@
             (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
 
     Shape:
-        - Input: :math:`(N, C, W_{in})`
-        - Output: :math:`(N, C, W_{out})` where
+        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
 
           :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
 
@@ -188,8 +188,8 @@
             (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
 
     Shape:
-        - Input: :math:`(N, C, W_{in})`
-        - Output: :math:`(N, C, W_{out})` where
+        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
 
           :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
 
@@ -341,8 +341,8 @@
             (:math:`\text{padding\_left}`, :math:`\text{padding\_right}`)
 
     Shape:
-        - Input: :math:`(N, C, W_{in})`
-        - Output: :math:`(N, C, W_{out})` where
+        - Input: :math:`(C, W_{in})` or :math:`(N, C, W_{in})`.
+        - Output: :math:`(C, W_{out})` or :math:`(N, C, W_{out})`, where
 
           :math:`W_{out} = W_{in} + \text{padding\_left} + \text{padding\_right}`
 
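Per the updated shape contract, the modules accept either form directly. A quick example with `ReflectionPad1d` (the same applies to `ReplicationPad1d` and `ConstantPad1d`):

    import torch
    import torch.nn as nn

    m = nn.ReflectionPad1d((1, 2))
    m(torch.randn(3, 8)).shape     # torch.Size([3, 11]),   unbatched (C, W)
    m(torch.randn(2, 3, 8)).shape  # torch.Size([2, 3, 11]), batched (N, C, W)
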
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index 6a05e4f..92f5089 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -2215,6 +2215,14 @@
         module_name='ReflectionPad1d',
         constructor_args=((1, 2),),
         cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
+        input_size=(3, 8),
+        reference_fn=single_batch_reference_fn,
+        desc='batch',
+    ),
+    dict(
+        module_name='ReflectionPad1d',
+        constructor_args=((1, 2),),
+        cpp_constructor_args='torch::nn::ReflectionPad1dOptions({1, 2})',
         input_fn=lambda: torch.rand(2, 3, 8, dtype=torch.complex128, requires_grad=True),
         skip_half=True,
         desc='complex'
@@ -2257,6 +2265,14 @@
         module_name='ReplicationPad1d',
         constructor_args=((1, 2),),
         cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
+        input_size=(3, 4),
+        reference_fn=single_batch_reference_fn,
+        desc='batch',
+    ),
+    dict(
+        module_name='ReplicationPad1d',
+        constructor_args=((1, 2),),
+        cpp_constructor_args='torch::nn::ReplicationPad1dOptions({1, 2})',
         input_fn=lambda: torch.rand(2, 3, 4, dtype=torch.complex128, requires_grad=True),
         skip_half=True,
         desc='complex'
@@ -2306,6 +2322,14 @@
         module_name='ConstantPad1d',
         constructor_args=((1, 2), 2.),
         cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
+        input_size=(3, 4),
+        reference_fn=single_batch_reference_fn,
+        desc='batch',
+    ),
+    dict(
+        module_name='ConstantPad1d',
+        constructor_args=((1, 2), 2.),
+        cpp_constructor_args='torch::nn::ConstantPad1dOptions({1, 2}, 2.)',
         input_fn=lambda: torch.rand(2, 3, 4, dtype=torch.complex128, requires_grad=True),
         skip_half=True,
         desc='complex'
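
Each new `desc='batch'` entry feeds an unbatched `input_size` through the module and compares the result against `single_batch_reference_fn`. The effective property being asserted can be sketched as follows (a simplification of the `common_nn.py` harness, not its actual code):

    import torch
    import torch.nn as nn

    for mod, shape in [
        (nn.ReflectionPad1d((1, 2)), (3, 8)),
        (nn.ReplicationPad1d((1, 2)), (3, 4)),
        (nn.ConstantPad1d((1, 2), 2.), (3, 4)),
    ]:
        x = torch.randn(shape)
        # The no-batch output must match the single-item batched reference.
        assert torch.equal(mod(x), mod(x.unsqueeze(0)).squeeze(0))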