[pt2] add metas for `pad` ops (#103815)
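
This patch adds Python meta registrations for the reflection/replication pad ops (1d/2d/3d, forward and backward), fixes a few typos in the corresponding C++ error messages, and drops the now-passing xfails. A minimal sketch (not part of the patch) of what those xfails exercised, assuming the `make_fx(..., tracing_mode="symbolic")` API already used by `test_proxy_tensor.py`; shapes and pad values are arbitrary:

```python
# Sketch only: with the new meta registrations, symbolic tracing of
# reflection/replication pads no longer fails with
# "couldn't find symbolic meta function/decomposition".
import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # (N, C, W) input + 2-element pad -> aten.reflection_pad1d
    return F.pad(x, (1, 2), mode="reflect")

def g(x):
    # (N, C, D, H, W) input + 6-element pad -> aten.replication_pad3d
    return F.pad(x, (1, 1, 1, 1, 1, 1), mode="replicate")

# Previously xfail'd in test_proxy_tensor.py / test_aotdispatch.py; the output
# shapes are now computed symbolically (output_w = input_w + pad_l + pad_r, etc.).
make_fx(f, tracing_mode="symbolic")(torch.randn(2, 3, 8))
make_fx(g, tracing_mode="symbolic")(torch.randn(2, 3, 4, 5, 6))
```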
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103815
Approved by: https://github.com/ezyang
diff --git a/aten/src/ATen/native/ReflectionPad.cpp b/aten/src/ATen/native/ReflectionPad.cpp
index 9252b36..25b3f8e 100644
--- a/aten/src/ATen/native/ReflectionPad.cpp
+++ b/aten/src/ATen/native/ReflectionPad.cpp
@@ -61,10 +61,9 @@
TORCH_CHECK(
output_w >= 1,
- 2,
"input (W: ",
input_w,
- ")is too small. Calculated output W: ",
+ ") is too small. Calculated output W: ",
output_w);
if (input.ndimension() == 2) {
@@ -202,7 +201,7 @@
TORCH_CHECK(output_h == grad_output.size(dim_h), "grad_output height unexpected."
" Expected: ", output_h, ", Got: ", grad_output.size(dim_h));
TORCH_CHECK(output_d == grad_output.size(dim_d), "grad_output depth unexpected."
- " Expected: ", output_h, ", Got: ", grad_output.size(dim_d));
+ " Expected: ", output_d, ", Got: ", grad_output.size(dim_d));
set_output_raw_strided(0, input.sizes(), {}, input.options());
}
@@ -244,15 +243,15 @@
TORCH_CHECK(pad_l < input_w && pad_r < input_w,
"Argument #4: Padding size should be less than the corresponding "
"input dimension, but got: padding (", pad_l, ", ", pad_r,
- ") at dimension ", dim_w, " of input ", ndim);
+ ") at dimension ", dim_w, " of input ", input.sizes());
TORCH_CHECK(pad_t < input_h && pad_b < input_h,
"Argument #6: Padding size should be less than the corresponding "
"input dimension, but got: padding (", pad_t, ", ", pad_b,
- ") at dimension ", dim_h, " of input ", ndim);
+ ") at dimension ", dim_h, " of input ", input.sizes());
TORCH_CHECK(output_w >= 1 || output_h >= 1,
- "input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated "
+ "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
"output H: ", output_h, " W: ", output_w);
/* resize output */
diff --git a/aten/src/ATen/native/cuda/ReflectionPad.cu b/aten/src/ATen/native/cuda/ReflectionPad.cu
index 4372e6c..4bb1bdc 100644
--- a/aten/src/ATen/native/cuda/ReflectionPad.cu
+++ b/aten/src/ATen/native/cuda/ReflectionPad.cu
@@ -311,7 +311,7 @@
int output_w = input_w + pad_l + pad_r;
TORCH_CHECK(output_w >= 1 || output_h >= 1,
- "input (H: ", input_h, ", W: ", input_w, ")is too small. Calculated "
+ "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
"output H: ", output_h, " W: ", output_w);
if (input_.ndimension() == 3) {
diff --git a/aten/src/ATen/native/mps/operations/Pad.mm b/aten/src/ATen/native/mps/operations/Pad.mm
index e61d681..b608e68 100644
--- a/aten/src/ATen/native/mps/operations/Pad.mm
+++ b/aten/src/ATen/native/mps/operations/Pad.mm
@@ -135,7 +135,7 @@
") at dimension ",
dim_w,
" of input ",
- ndims);
+ input_.sizes());
if (padding_dim > 1) {
TORCH_CHECK(pad_t < input_h && pad_b < input_h,
@@ -147,7 +147,7 @@
") at dimension ",
dim_h,
" of input ",
- ndims);
+ input_.sizes());
}
if (padding_dim > 2) {
TORCH_CHECK(pad_front < input_d && pad_back < input_d,
@@ -159,7 +159,7 @@
") at dimension ",
dim_d,
" of input ",
- ndims);
+ input_.sizes());
}
}
outputSizes.insert(outputSizes.begin(), output_w);
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index ddedd05..003dc69 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2845,8 +2845,6 @@
xfail('nn.functional.multi_margin_loss', ''), # could not find kernel
xfail('nn.functional.multilabel_margin_loss', ''), # could not find kernel
xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
- xfail('nn.functional.pad', 'reflect'), # aten.reflection_pad1d.default - couldn't find symbolic meta fu...
- xfail('nn.functional.pad', 'replicate'), # aten.replication_pad1d.default - couldn't find symbolic meta...
xfail('nn.functional.pdist', ''), # could not find kernel
xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
@@ -2982,11 +2980,6 @@
torch.nn.Transformer, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.TransformerEncoder, # DataDependentOutputException: aten.equal compares a mask input to a mask producing a bool
torch.nn.GaussianNLLLoss, # NotImplementedError: local_scalar_dense/item NYI for torch.bool
- torch.nn.ReplicationPad1d, # Cannot call sizes() on tensor with symbolic sizes/strides
- torch.nn.ReplicationPad2d, # Cannot call sizes() on tensor with symbolic sizes/strides
- torch.nn.ReplicationPad3d, # Cannot call sizes() on tensor with symbolic sizes/strides
- torch.nn.ReflectionPad1d, # Cannot call sizes() on tensor with symbolic sizes/strides
- torch.nn.ReflectionPad3d, # Cannot call sizes() on tensor with symbolic sizes/strides
torch.nn.AdaptiveAvgPool3d, # could not find kernel for aten._adaptive_avg_pool3d_backward.default at dispatch key
# DispatchKey.Meta
torch.nn.AdaptiveMaxPool1d, # Cannot call sizes() on tensor with symbolic sizes/strides
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 862ca36..9c9acce 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1517,8 +1517,6 @@
xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the...
xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ...
- xfail('nn.functional.pad', 'reflect'), # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo...
- xfail('nn.functional.pad', 'replicate'), # aten.replication_pad1d.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 5277a7a..00d644c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -1077,13 +1077,203 @@
return det, LU, pivots
-# From aten/src/ATen/native/ReflectionPad.cpp
+def _padding_check_valid_input(input, padding, *, dim):
+ torch._check(
+ len(padding) == 2 * dim,
+ lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
+ )
+
+ input_dim = input.ndim
+
+ is_batch_mode = input_dim == (dim + 2)
+
+ valid_batch_mode = is_batch_mode
+ valid_non_batch_mode = not is_batch_mode
+
+ if is_batch_mode:
+ # allow batch size of 0-dim.
+ for d in range(1, input_dim):
+ valid_batch_mode = valid_batch_mode and input.size(d) != 0
+ else:
+ for d in range(0, input_dim):
+ valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
+
+ # allow empty batch size but not other dimensions.
+ torch._check(
+ valid_batch_mode or valid_non_batch_mode,
+ lambda: (
+ f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
+ f"and other non-zero dimensions for input, but got: {input.shape}"
+ ),
+ )
+
+
+def _pad1d_common(input, padding, *, is_reflection):
+ dim_plane = 0
+ dim_w = 1
+ nbatch = 1
+
+ if input.ndim == 3:
+ nbatch = input.size(0)
+ dim_w += 1
+ dim_plane += 1
+
+ _padding_check_valid_input(input, padding, dim=1)
+
+ pad_l, pad_r = padding
+
+ nplane = input.size(dim_plane)
+ input_w = input.size(dim_w)
+ output_w = input_w + pad_l + pad_r
+
+ if is_reflection:
+ torch._check(
+ pad_l < input_w and pad_r < input_w,
+ lambda: (
+ f"Argument #4: Padding size should be less than the corresponding input dimension, "
+ f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+ ),
+ )
+
+ torch._check(
+ output_w >= 1,
+ lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
+ )
+
+ if input.ndim == 2:
+ return input.new_empty((nplane, output_w))
+ else:
+ return input.new_empty((nbatch, nplane, output_w))
+
+
+@register_meta(aten.reflection_pad1d)
+@out_wrapper()
+def meta_reflection_pad1d(input, padding):
+ return _pad1d_common(input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad1d)
+@out_wrapper()
+def meta_replication_pad1d(input, padding):
+ return _pad1d_common(input, padding, is_reflection=False)
+
+
+def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
+ dim_w = 1
+ if not is_reflection:
+ torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
+
+ if input.ndim == 3:
+ dim_w += 1
+
+ pad_l, pad_r = padding
+
+ input_w = input.size(dim_w)
+ output_w = input_w + pad_l + pad_r
+
+ if is_reflection:
+ torch._check(
+ pad_l < input_w and pad_r < input_w,
+ lambda: (
+ f"Argument #4: Padding size should be less than the corresponding input dimension, "
+ f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+ ),
+ )
+
+ torch._check(
+ output_w == grad_output.size(dim_w),
+ lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
+ )
+
+ return input.new_empty(input.shape)
+
+
+@register_meta(aten.reflection_pad1d_backward)
+@out_wrapper()
+def meta_reflection_pad1d_backward(grad_output, input, padding):
+ return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad1d_backward)
+@out_wrapper()
+def meta_replication_pad1d_backward(grad_output, input, padding):
+ return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
+
+
+def _pad2d_common(input, padding, *, is_reflection):
+ dim_w = 2
+ dim_h = 1
+ dim_slices = 0
+ nbatch = 1
+
+ _padding_check_valid_input(input, padding, dim=2)
+
+ ndim = input.ndim
+ if ndim == 4:
+ nbatch = input.size(0)
+ dim_w += 1
+ dim_h += 1
+ dim_slices += 1
+
+ pad_l, pad_r, pad_t, pad_b = padding
+
+ nplane = input.size(dim_slices)
+ input_h = input.size(dim_h)
+ input_w = input.size(dim_w)
+ output_h = input_h + pad_t + pad_b
+ output_w = input_w + pad_l + pad_r
+
+ if is_reflection:
+ torch._check(
+ pad_l < input_w and pad_r < input_w,
+ lambda: (
+ f"Argument #4: Padding size should be less than the corresponding input dimension, "
+ f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+ ),
+ )
+ torch._check(
+ pad_t < input_h and pad_b < input_h,
+ lambda: (
+ f"Argument #6: Padding size should be less than the corresponding input dimension, "
+ f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
+ ),
+ )
+
+ torch._check(
+ output_w >= 1 or output_h >= 1,
+ lambda: (
+ f"input (H: {input_h} W: {input_w}) is too small. "
+ f"Calculated output H: {output_h} W: {output_w}"
+ ),
+ )
+
+ if input.ndim == 3:
+ return input.new_empty((nplane, output_h, output_w))
+ else:
+ return input.new_empty((nbatch, nplane, output_h, output_w))
+
+
+@register_meta(aten.reflection_pad2d)
+@out_wrapper()
+def meta_reflection_pad2d(input, padding):
+ return _pad2d_common(input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad2d)
+@out_wrapper()
+def meta_replication_pad2d(input, padding):
+ return _pad2d_common(input, padding, is_reflection=False)
+
+
@register_meta(
[
aten.reflection_pad2d_backward.default,
+ aten.reflection_pad2d_backward.grad_input,
aten.replication_pad2d_backward.default,
+ aten.replication_pad2d_backward.grad_input,
]
)
+@out_wrapper()
def meta_pad2d_backward(grad_output, self, padding):
dim_w = 2
dim_h = 1
@@ -1097,10 +1287,7 @@
dim_h += 1
dim_plane += 1
- pad_l = padding[0]
- pad_r = padding[1]
- pad_t = padding[2]
- pad_b = padding[3]
+ pad_l, pad_r, pad_t, pad_b = padding
nplane = self_shape[dim_plane]
input_h = self_shape[dim_h]
@@ -1109,39 +1296,137 @@
output_w = input_w + pad_l + pad_r
torch._check(
- output_w == grad_output.shape[dim_w],
- lambda: f"gradOutput width unexpected. Expected: {output_w}, Got: {grad_output.shape[dim_w]}",
+ output_w == grad_output.size(dim_w),
+ lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
)
torch._check(
- output_h == grad_output.shape[dim_h],
- lambda: f"gradOutput height unexpected. Expected: {output_h}, Got: {grad_output.shape[dim_h]}",
+ output_h == grad_output.size(dim_h),
+ lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
)
return self.new_empty(self.shape)
-@register_meta(aten.reflection_pad2d.default)
-def meta_pad2d(self, padding):
- valid_dims = self.size(1) != 0 and self.size(2) != 0
- torch._check(
- (self.ndim == 3 and valid_dims)
- or (self.ndim == 4 and valid_dims and self.size(3) != 0),
- lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
- )
- if self.ndim == 4:
- nbatch, nplane, input_h, input_w = self.shape
- else:
- nbatch = 1
- nplane, input_h, input_w = self.shape
+def _pad3d_common(input, padding, *, is_reflection):
+ dim_w = 3
+ dim_h = 2
+ dim_d = 1
+ dim_plane = 0
- pad_l, pad_r, pad_t, pad_b = padding
+ _padding_check_valid_input(input, padding, dim=3)
+ batch_mode = input.ndim == 5
+ if batch_mode:
+ nbatch = input.size(0)
+ dim_w += 1
+ dim_h += 1
+ dim_d += 1
+ dim_plane += 1
+
+ pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
+
+ nplane = input.size(dim_plane)
+ input_d = input.size(dim_d)
+ input_h = input.size(dim_h)
+ input_w = input.size(dim_w)
+ output_d = input_d + pad_f + pad_bk
output_h = input_h + pad_t + pad_b
output_w = input_w + pad_l + pad_r
- if self.ndim == 3:
- return self.new_empty((nplane, output_h, output_w))
+ if is_reflection:
+ torch._check(
+ pad_l < input_w and pad_r < input_w,
+ lambda: (
+ f"Argument #4: Padding size should be less than the corresponding input dimension, "
+ f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
+ ),
+ )
+ torch._check(
+ pad_t < input_h and pad_b < input_h,
+ lambda: (
+ f"Argument #6: Padding size should be less than the corresponding input dimension, "
+ f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
+ ),
+ )
+ torch._check(
+ pad_f < input_d and pad_bk < input_d,
+ lambda: (
+ f"Argument #8: Padding size should be less than the corresponding input dimension, "
+ f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
+ ),
+ )
+
+ torch._check(
+ output_w >= 1 or output_h >= 1 or output_d >= 1,
+ lambda: (
+ f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
+ f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
+ ),
+ )
+
+ if batch_mode:
+ return input.new_empty((nbatch, nplane, output_d, output_h, output_w))
else:
- return self.new_empty((nbatch, nplane, output_h, output_w))
+ return input.new_empty((nplane, output_d, output_h, output_w))
+
+
+@register_meta(aten.reflection_pad3d)
+@out_wrapper()
+def meta_reflection_pad3d(input, padding):
+ return _pad3d_common(input, padding, is_reflection=True)
+
+
+@register_meta(aten.replication_pad3d)
+@out_wrapper()
+def meta_replication_pad3d(input, padding):
+ return _pad3d_common(input, padding, is_reflection=False)
+
+
+@register_meta(
+ [
+ aten.reflection_pad3d_backward.default,
+ aten.reflection_pad3d_backward.grad_input,
+ aten.replication_pad3d_backward.default,
+ aten.replication_pad3d_backward.grad_input,
+ ]
+)
+@out_wrapper()
+def meta_pad3d_backward(grad_output, input, padding):
+ torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
+ assert input.ndim > 3
+ assert grad_output.ndim == input.ndim
+
+ dim_w = 3
+ dim_h = 2
+ dim_d = 1
+
+ if input.ndim == 5:
+ dim_w += 1
+ dim_h += 1
+ dim_d += 1
+
+ pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
+
+ input_d = input.size(dim_d)
+ input_h = input.size(dim_h)
+ input_w = input.size(dim_w)
+ output_d = input_d + pad_f + pad_bk
+ output_h = input_h + pad_t + pad_b
+ output_w = input_w + pad_l + pad_r
+
+ torch._check(
+ output_w == grad_output.size(dim_w),
+ lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
+ )
+ torch._check(
+ output_h == grad_output.size(dim_h),
+ lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
+ )
+ torch._check(
+ output_d == grad_output.size(dim_d),
+ lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}",
+ )
+
+ return input.new_empty(input.shape)
@register_meta([aten.baddbmm.default, aten.baddbmm.out])