[pt2] add metas for `pad` ops (#103815)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103815
Approved by: https://github.com/ezyang
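
This change adds Python meta functions for the 1d/2d/3d reflection and replication pad ops (forward and backward), tidies a few pad error messages in the C++/CUDA/MPS kernels, and removes the corresponding `nn.functional.pad` xfails in `test_aotdispatch.py` and `test_proxy_tensor.py`. A minimal sketch of the kind of trace the removed xfails cover (assumed usage, not part of the diff):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Symbolic tracing through a replicate pad; the removed xfails note that this
# previously failed with "couldn't find symbolic meta function/decomposition".
def f(x):
    return torch.nn.functional.pad(x, (2, 3), mode="replicate")

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 4, 8))
print(gm.graph)  # graph contains aten.replication_pad1d with symbolic sizes
```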
diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp
index 9252b36..25b3f8e 100644
--- a/aten/src/ATen/native/ReflectionPad.cpp
+++ b/aten/src/ATen/native/ReflectionPad.cpp
@@ -61,10 +61,9 @@
 
   TORCH_CHECK(
       output_w >= 1,
-      2,
       "input (W: ",
       input_w,
-      ")is too small. Calculated output W: ",
+      ") is too small. Calculated output W: ",
       output_w);
 
   if (input.ndimension() == 2) {
@@ -202,7 +201,7 @@
   TORCH_CHECK(output_h == grad_output.size(dim_h), "grad_output height unexpected."
     " Expected: ", output_h, ", Got: ", grad_output.size(dim_h));
   TORCH_CHECK(output_d == grad_output.size(dim_d), "grad_output depth unexpected."
-    " Expected: ", output_h, ", Got: ", grad_output.size(dim_d));
+    " Expected: ", output_d, ", Got: ", grad_output.size(dim_d));
 
   set_output_raw_strided(0, input.sizes(), {}, input.options());
 }
@@ -244,15 +243,15 @@
   TORCH_CHECK(pad_l < input_w && pad_r < input_w,
     "Argument #4: Padding size should be less than the corresponding "
     "input dimension, but got: padding (", pad_l, ", ", pad_r,
-    ") at dimension ", dim_w, " of input ", ndim);
+    ") at dimension ", dim_w, " of input ", input.sizes());
 
   TORCH_CHECK(pad_t < input_h && pad_b < input_h,
     "Argument #6: Padding size should be less than the corresponding "
     "input dimension, but got: padding (", pad_t, ", ", pad_b,
-    ") at dimension ", dim_h, " of input ", ndim);
+    ") at dimension ", dim_h, " of input ", input.sizes());
 
   TORCH_CHECK(output_w >= 1 || output_h >= 1,
-    "input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated "
+    "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
     "output H: ", output_h, " W: ", output_w);
 
   /* resize output */
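
(Not part of the diff: a hedged illustration.) Printing `input.sizes()` instead of the bare rank makes the padding check actionable, e.g.:

```python
import torch

x = torch.randn(1, 3, 4, 4)
try:
    # pad_l (5) is not smaller than input_w (4), so the check fires
    torch.ops.aten.reflection_pad2d(x, (5, 0, 0, 0))
except RuntimeError as e:
    print(e)  # "... padding (5, 0) at dimension 3 of input [1, 3, 4, 4]"
```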
diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu
index 4372e6c..4bb1bdc 100644
--- a/aten/src/ATen/native/cuda/ReflectionPad.cu
+++ b/aten/src/ATen/native/cuda/ReflectionPad.cu
@@ -311,7 +311,7 @@
   int output_w  = input_w + pad_l + pad_r;
 
   TORCH_CHECK(output_w >= 1 || output_h >= 1,
-    "input (H: ", input_h, ", W: ", input_w, ")is too small.  Calculated "
+    "input (H: ", input_h, ", W: ", input_w, ") is too small.  Calculated "
     "output H: ", output_h, " W: ", output_w);
 
   if (input_.ndimension() == 3) {
diff --git a/aten/src/ATen/native/mps/operations/Pad.mm b/aten/src/ATen/native/mps/operations/Pad.mm
index e61d681..b608e68 100644
--- a/aten/src/ATen/native/mps/operations/Pad.mm
+++ b/aten/src/ATen/native/mps/operations/Pad.mm
@@ -135,7 +135,7 @@
                     ") at dimension ",
                     dim_w,
                     " of input ",
-                    ndims);
+                    input_.sizes());
 
         if (padding_dim > 1) {
           TORCH_CHECK(pad_t < input_h && pad_b < input_h,
@@ -147,7 +147,7 @@
                       ") at dimension ",
                       dim_h,
                       " of input ",
-                      ndims);
+                      input_.sizes());
         }
         if (padding_dim > 2) {
           TORCH_CHECK(pad_front < input_d && pad_back < input_d,
@@ -159,7 +159,7 @@
                       ") at dimension ",
                       dim_d,
                       " of input ",
-                      ndims);
+                      input_.sizes());
         }
       }
       outputSizes.insert(outputSizes.begin(), output_w);
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index ddedd05..003dc69 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2845,8 +2845,6 @@
     xfail('nn.functional.multi_margin_loss', ''),  # could not find kernel
     xfail('nn.functional.multilabel_margin_loss', ''),  # could not find kernel
     xfail('nn.functional.nll_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('nn.functional.pad', 'reflect'),  # aten.reflection_pad1d.default - couldn't find symbolic meta fu...
-    xfail('nn.functional.pad', 'replicate'),  # aten.replication_pad1d.default - couldn't find symbolic meta...
     xfail('nn.functional.pdist', ''),  # could not find kernel
     xfail('nn.functional.pixel_shuffle', ''),  # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta...
@@ -2982,11 +2980,6 @@
     torch.nn.Transformer,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
     torch.nn.TransformerEncoder,  # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
     torch.nn.GaussianNLLLoss,  # NotImplementedError: local_scalar_dense/item NYI for torch.bool
-    torch.nn.ReplicationPad1d,  # Cannot call sizes() on tensor with symbolic sizes/strides
-    torch.nn.ReplicationPad2d,  # Cannot call sizes() on tensor with symbolic sizes/strides
-    torch.nn.ReplicationPad3d,  # Cannot call sizes() on tensor with symbolic sizes/strides
-    torch.nn.ReflectionPad1d,  # Cannot call sizes() on tensor with symbolic sizes/strides
-    torch.nn.ReflectionPad3d,  # Cannot call sizes() on tensor with symbolic sizes/strides
     torch.nn.AdaptiveAvgPool3d,  # could not find kernel for aten._adaptive_avg_pool3d_backward.default at dispatch key
                                  # DispatchKey.Meta
     torch.nn.AdaptiveMaxPool1d,  # Cannot call sizes() on tensor with symbolic sizes/strides
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 862ca36..9c9acce 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1517,8 +1517,6 @@
     xfail('nn.functional.max_unpool3d', 'grad'),  # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.multi_margin_loss', ''),  # Could not run 'aten::multi_margin_loss' with arguments from the...
     xfail('nn.functional.multilabel_margin_loss', ''),  # Could not run 'aten::multilabel_margin_loss_forward' with ...
-    xfail('nn.functional.pad', 'reflect'),  # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo...
-    xfail('nn.functional.pad', 'replicate'),  # aten.replication_pad1d.default - couldn't find symbolic meta function/deco...
     xfail('nn.functional.pdist', ''),  # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
     xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 5277a7a..00d644c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1077,13 +1077,203 @@
     return det, LU, pivots
 
 
-# From aten/src/ATen/native/ReflectionPad.cpp
+def _padding_check_valid_input(input, padding, *, dim):
+    torch._check(
+        len(padding) == 2 * dim,
+        lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
+    )
+
+    input_dim = input.ndim
+
+    is_batch_mode = input_dim == (dim + 2)
+
+    valid_batch_mode = is_batch_mode
+    valid_non_batch_mode = not is_batch_mode
+
+    if is_batch_mode:
+        # allow batch size of 0-dim.
+        for d in range(1, input_dim):
+            valid_batch_mode = valid_batch_mode and input.size(d) != 0
+    else:
+        for d in range(0, input_dim):
+            valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
+
+    # allow empty batch size but not other dimensions.
+    torch._check(
+        valid_batch_mode or valid_non_batch_mode,
+        lambda: (
+            f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
+            f"and other non-zero dimensions for input, but got: {input.shape}"
+        ),
+    )
+
+
+def _pad1d_common(input, padding, *, is_reflection):
+    dim_plane = 0
+    dim_w = 1
+    nbatch = 1
+
+    if input.ndim == 3:
+        nbatch = input.size(0)
+        dim_w += 1
+        dim_plane += 1
+
+    _padding_check_valid_input(input, padding, dim=1)
+
+    pad_l, pad_r = padding
+
+    nplane = input.size(dim_plane)
+    input_w = input.size(dim_w)
+    output_w = input_w + pad_l + pad_r
+
+    if is_reflection:
+        torch._check(
+            pad_l < input_w and pad_r < input_w,
+            lambda: (
+                f"Argument #4: Padding size should be less than the corresponding input dimension, "
+                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+            ),
+        )
+
+    torch._check(
+        output_w >= 1,
+        lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
+    )
+
+    if input.ndim == 2:
+        return input.new_empty((nplane, output_w))
+    else:
+        return input.new_empty((nbatch, nplane, output_w))
+
+
+@register_meta(aten.reflection_pad1d)
+@out_wrapper()
+def meta_reflection_pad1d(input, padding):
+    return _pad1d_common(input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad1d)
+@out_wrapper()
+def meta_replication_pad1d(input, padding):
+    return _pad1d_common(input, padding, is_reflection=False)
+
+
+def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
+    dim_w = 1
+    if not is_reflection:
+        torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
+
+    if input.ndim == 3:
+        dim_w += 1
+
+    pad_l, pad_r = padding
+
+    input_w = input.size(dim_w)
+    output_w = input_w + pad_l + pad_r
+
+    if is_reflection:
+        torch._check(
+            pad_l < input_w and pad_r < input_w,
+            lambda: (
+                f"Argument #4: Padding size should be less than the corresponding input dimension, "
+                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+            ),
+        )
+
+    torch._check(
+        output_w == grad_output.size(dim_w),
+        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
+    )
+
+    return input.new_empty(input.shape)
+
+
+@register_meta(aten.reflection_pad1d_backward)
+@out_wrapper()
+def meta_reflection_pad1d_backward(grad_output, input, padding):
+    return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad1d_backward)
+@out_wrapper()
+def meta_replication_pad1d_backward(grad_output, input, padding):
+    return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
+
+
+def _pad2d_common(input, padding, *, is_reflection):
+    dim_w = 2
+    dim_h = 1
+    dim_slices = 0
+    nbatch = 1
+
+    _padding_check_valid_input(input, padding, dim=2)
+
+    ndim = input.ndim
+    if ndim == 4:
+        nbatch = input.size(0)
+        dim_w += 1
+        dim_h += 1
+        dim_slices += 1
+
+    pad_l, pad_r, pad_t, pad_b = padding
+
+    nplane = input.size(dim_slices)
+    input_h = input.size(dim_h)
+    input_w = input.size(dim_w)
+    output_h = input_h + pad_t + pad_b
+    output_w = input_w + pad_l + pad_r
+
+    if is_reflection:
+        torch._check(
+            pad_l < input_w and pad_r < input_w,
+            lambda: (
+                f"Argument #4: Padding size should be less than the corresponding input dimension, "
+                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+            ),
+        )
+        torch._check(
+            pad_t < input_h and pad_b < input_h,
+            lambda: (
+                f"Argument #6: Padding size should be less than the corresponding input dimension, "
+                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
+            ),
+        )
+
+    torch._check(
+        output_w >= 1 or output_h >= 1,
+        lambda: (
+            f"input (H: {input_h} W: {input_w}) is too small. "
+            f"Calculated output H: {output_h} W: {output_w}"
+        ),
+    )
+
+    if input.ndim == 3:
+        return input.new_empty((nplane, output_h, output_w))
+    else:
+        return input.new_empty((nbatch, nplane, output_h, output_w))
+
+
+@register_meta(aten.reflection_pad2d)
+@out_wrapper()
+def meta_reflection_pad2d(input, padding):
+    return _pad2d_common(input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad2d)
+@out_wrapper()
+def meta_replication_pad2d(input, padding):
+    return _pad2d_common(input, padding, is_reflection=False)
+
+
 @register_meta(
     [
         aten.reflection_pad2d_backward.default,
+        aten.reflection_pad2d_backward.grad_input,
         aten.replication_pad2d_backward.default,
+        aten.replication_pad2d_backward.grad_input,
     ]
 )
+@out_wrapper()
 def meta_pad2d_backward(grad_output, self, padding):
     dim_w = 2
     dim_h = 1
@@ -1097,10 +1287,7 @@
         dim_h += 1
         dim_plane += 1
 
-    pad_l = padding[0]
-    pad_r = padding[1]
-    pad_t = padding[2]
-    pad_b = padding[3]
+    pad_l, pad_r, pad_t, pad_b = padding
 
     nplane = self_shape[dim_plane]
     input_h = self_shape[dim_h]
@@ -1109,39 +1296,137 @@
     output_w = input_w + pad_l + pad_r
 
     torch._check(
-        output_w == grad_output.shape[dim_w],
-        lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
+        output_w == grad_output.size(dim_w),
+        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
     )
     torch._check(
-        output_h == grad_output.shape[dim_h],
-        lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
+        output_h == grad_output.size(dim_h),
+        lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
     )
     return self.new_empty(self.shape)
 
 
-@register_meta(aten.reflection_pad2d.default)
-def meta_pad2d(self, padding):
-    valid_dims = self.size(1) != 0 and self.size(2) != 0
-    torch._check(
-        (self.ndim == 3 and valid_dims)
-        or (self.ndim == 4 and valid_dims and self.size(3) != 0),
-        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
-    )
-    if self.ndim == 4:
-        nbatch, nplane, input_h, input_w = self.shape
-    else:
-        nbatch = 1
-        nplane, input_h, input_w = self.shape
+def _pad3d_common(input, padding, *, is_reflection):
+    dim_w = 3
+    dim_h = 2
+    dim_d = 1
+    dim_plane = 0
 
-    pad_l, pad_r, pad_t, pad_b = padding
+    _padding_check_valid_input(input, padding, dim=3)
 
+    batch_mode = input.ndim == 5
+    if batch_mode:
+        nbatch = input.size(0)
+        dim_w += 1
+        dim_h += 1
+        dim_d += 1
+        dim_plane += 1
+
+    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
+
+    nplane = input.size(dim_plane)
+    input_d = input.size(dim_d)
+    input_h = input.size(dim_h)
+    input_w = input.size(dim_w)
+    output_d = input_d + pad_f + pad_bk
     output_h = input_h + pad_t + pad_b
     output_w = input_w + pad_l + pad_r
 
-    if self.ndim == 3:
-        return self.new_empty((nplane, output_h, output_w))
+    if is_reflection:
+        torch._check(
+            pad_l < input_w and pad_r < input_w,
+            lambda: (
+                f"Argument #4: Padding size should be less than the corresponding input dimension, "
+                f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+            ),
+        )
+        torch._check(
+            pad_t < input_h and pad_b < input_h,
+            lambda: (
+                f"Argument #6: Padding size should be less than the corresponding input dimension, "
+                f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
+            ),
+        )
+        torch._check(
+            pad_f < input_d and pad_bk < input_d,
+            lambda: (
+                f"Argument #8: Padding size should be less than the corresponding input dimension, "
+                f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
+            ),
+        )
+
+    torch._check(
+        output_w >= 1 or output_h >= 1 or output_d >= 1,
+        lambda: (
+            f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
+            f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
+        ),
+    )
+
+    if batch_mode:
+        return input.new_empty((nbatch, nplane, output_d, output_h, output_w))
     else:
-        return self.new_empty((nbatch, nplane, output_h, output_w))
+        return input.new_empty((nplane, output_d, output_h, output_w))
+
+
+@register_meta(aten.reflection_pad3d)
+@out_wrapper()
+def meta_reflection_pad3d(input, padding):
+    return _pad3d_common(input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad3d)
+@out_wrapper()
+def meta_replication_pad3d(input, padding):
+    return _pad3d_common(input, padding, is_reflection=False)
+
+
+@register_meta(
+    [
+        aten.reflection_pad3d_backward.default,
+        aten.reflection_pad3d_backward.grad_input,
+        aten.replication_pad3d_backward.default,
+        aten.replication_pad3d_backward.grad_input,
+    ]
+)
+@out_wrapper()
+def meta_pad3d_backward(grad_output, input, padding):
+    torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
+    assert input.ndim > 3
+    assert grad_output.ndim == input.ndim
+
+    dim_w = 3
+    dim_h = 2
+    dim_d = 1
+
+    if input.ndim == 5:
+        dim_w += 1
+        dim_h += 1
+        dim_d += 1
+
+    pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
+
+    input_d = input.size(dim_d)
+    input_h = input.size(dim_h)
+    input_w = input.size(dim_w)
+    output_d = input_d + pad_f + pad_bk
+    output_h = input_h + pad_t + pad_b
+    output_w = input_w + pad_l + pad_r
+
+    torch._check(
+        output_w == grad_output.size(dim_w),
+        lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
+    )
+    torch._check(
+        output_h == grad_output.size(dim_h),
+        lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
+    )
+    torch._check(
+        output_d == grad_output.size(dim_d),
+        lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}",
+    )
+
+    return input.new_empty(input.shape)
 
 
 @register_meta([aten.baddbmm.default, aten.baddbmm.out])
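
As a quick sanity check (assumed, not part of the diff), the new meta kernels can be compared against the eager kernels for shape agreement:

```python
import torch

# The meta kernel mirrors the eager shape computation:
# output_w = input_w + pad_l + pad_r = 8 + 2 + 3 = 13.
pad = (2, 3)
eager_out = torch.ops.aten.reflection_pad1d(torch.randn(2, 4, 8), pad)
meta_out = torch.ops.aten.reflection_pad1d(
    torch.empty(2, 4, 8, device="meta"), pad
)
assert eager_out.shape == meta_out.shape == torch.Size([2, 4, 13])
```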