Remove F.pad Python implementation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73433
Approved by: https://github.com/albanD, https://github.com/jbschlosser
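In short: this change replaces the Python implementation of `F.pad` with the native `torch._C._nn.pad` operator. The docstring is attached via `_add_docstr`, the `_pad`/`_pad_circular` Python helpers are deleted, argument validation moves into native code (so invalid inputs raise `RuntimeError` instead of `AssertionError`), and ONNX export gains an `aten::pad` symbolic. A quick illustration of the user-visible surface (a sketch; exact reprs vary by build):

```python
import torch
import torch.nn.functional as F

# F.pad is now the documented native operator, not a Python wrapper.
# _add_docstr returns the object it documents, so the binding is direct.
print(F.pad is torch._C._nn.pad)       # True
print(F.pad.__module__)                # "torch.nn.functional" (patched in this PR)
print(F.pad.__doc__.splitlines()[1])   # pad(input, pad, mode="constant", value=None) -> Tensor
```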
diff --git a/test/onnx/expect/TestOperators.test_pad.expect b/test/onnx/expect/TestOperators.test_pad.expect
index 0a25fb0..293877a 100644
--- a/test/onnx/expect/TestOperators.test_pad.expect
+++ b/test/onnx/expect/TestOperators.test_pad.expect
@@ -161,7 +161,7 @@
}
}
node {
- input: "input"
+ input: "onnx::Pad_0"
input: "onnx::Pad_22"
output: "23"
name: "Pad_13"
@@ -186,7 +186,7 @@
raw_data: "\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
}
input {
- name: "input"
+ name: "onnx::Pad_0"
type {
tensor_type {
elem_type: 1
diff --git a/test/test_fx.py b/test/test_fx.py
index 4a5436b..6710c21 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -3810,6 +3810,7 @@
"linear": BUILT_IN_FUNC,
"logsigmoid": BUILT_IN_FUNC,
"one_hot": BUILT_IN_FUNC,
+ "pad": BUILT_IN_FUNC,
"pairwise_distance": BUILT_IN_FUNC,
"pdist": BUILT_IN_FUNC,
"pixel_shuffle": BUILT_IN_FUNC,
@@ -3827,7 +3828,6 @@
"adaptive_max_pool2d_with_indices": LEN_ERROR,
"adaptive_max_pool3d_with_indices": LEN_ERROR,
"instance_norm": CONTROL_FLOW,
- "pad": LEN_ERROR,
"adaptive_max_pool1d": PROXY_ITERABLE,
"adaptive_max_pool2d": PROXY_ITERABLE,
diff --git a/test/test_jit.py b/test/test_jit.py
index 37d9ec5..cd6bc43 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15513,7 +15513,7 @@
self.assertEqual(m.int64_min, imported.int64_min)
def test_script_scope(self):
- scripted = torch.jit.script(torch.nn.functional.pad)
+ scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss)
@unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows")
def test_serialization_sharing(self):
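`test_script_scope` previously used `F.pad` as an example of a scriptable Python function with a module scope; now that `pad` is a builtin, the test switches to `triplet_margin_loss`, which still has a Python body. `F.pad` itself remains directly callable from TorchScript, e.g. (a sketch, not part of the test):

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def pad_twice(x: torch.Tensor) -> torch.Tensor:
    # F.pad resolves to the aten::pad builtin inside TorchScript.
    return F.pad(F.pad(x, [1, 1]), [2, 2])

print(pad_twice(torch.zeros(1, 4)).shape)  # torch.Size([1, 10])
```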
diff --git a/test/test_nn.py b/test/test_nn.py
index f71c234..ed1251a 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5430,8 +5430,8 @@
def test_pad_scalar_error(self):
inputs = torch.tensor(0., requires_grad=True)
- self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
- self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))
+ self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1)))
+ self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1,)))
@unittest.skipIf(not TEST_NUMPY, "numpy not found")
@parametrize_test("average_attn_weights", [True, False])
@@ -14345,10 +14345,10 @@
# Assert that errors are raised for invalid circular padding values
inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
# Should raise error when trying to wrap around more than once
- self.assertRaises(AssertionError, lambda: F.pad(inputs, (5, 4), mode='circular'))
- self.assertRaises(AssertionError, lambda: F.pad(inputs, (3, 6), mode='circular'))
+ self.assertRaises(RuntimeError, lambda: F.pad(inputs, (5, 4), mode='circular'))
+ self.assertRaises(RuntimeError, lambda: F.pad(inputs, (3, 6), mode='circular'))
# Should raise error when negative padding results in negative output shape
- self.assertRaises(AssertionError, lambda: F.pad(inputs, (-3, -2), mode='circular'))
+ self.assertRaises(RuntimeError, lambda: F.pad(inputs, (-3, -2), mode='circular'))
# assert that reflection padding errors when pad >= input size
expected_err_msg = r"Padding size should be less than the corresponding input dimension"
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index e8734ff..cd42554 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4263,111 +4263,70 @@
return torch.affine_grid_generator(theta, size, align_corners)
-# NOTE: Do not edit. This code will be removed once the forward-compatibility
-# period is over for PR #73431
-def _pad(input: Tensor, pad: BroadcastingList1[int], mode: str = "constant", value: Union[int, float] = 0.0) -> Tensor:
- r"""Pads tensor.
+pad = _add_docstr(
+ torch._C._nn.pad,
+ r"""
+pad(input, pad, mode="constant", value=None) -> Tensor
- Padding size:
- The padding size by which to pad some dimensions of :attr:`input`
- are described starting from the last dimension and moving forward.
- :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
- of ``input`` will be padded.
- For example, to pad only the last dimension of the input tensor, then
- :attr:`pad` has the form
- :math:`(\text{padding\_left}, \text{padding\_right})`;
- to pad the last 2 dimensions of the input tensor, then use
- :math:`(\text{padding\_left}, \text{padding\_right},`
- :math:`\text{padding\_top}, \text{padding\_bottom})`;
- to pad the last 3 dimensions, use
- :math:`(\text{padding\_left}, \text{padding\_right},`
- :math:`\text{padding\_top}, \text{padding\_bottom}`
- :math:`\text{padding\_front}, \text{padding\_back})`.
+Pads tensor.
- Padding mode:
- See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
- :class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
- padding modes works. Constant padding is implemented for arbitrary dimensions.
- Replicate and reflection padding are implemented for padding the last 3
- dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D
- or 4D input tensor, or the last dimension of a 2D or 3D input tensor.
+Padding size:
+ The padding sizes by which to pad some dimensions of :attr:`input`
+ are described starting from the last dimension and moving forward.
+ :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
+ of ``input`` will be padded.
+ For example, to pad only the last dimension of the input tensor, then
+ :attr:`pad` has the form
+ :math:`(\text{padding\_left}, \text{padding\_right})`;
+ to pad the last 2 dimensions of the input tensor, then use
+ :math:`(\text{padding\_left}, \text{padding\_right},`
+ :math:`\text{padding\_top}, \text{padding\_bottom})`;
+ to pad the last 3 dimensions, use
+ :math:`(\text{padding\_left}, \text{padding\_right},`
+ :math:`\text{padding\_top}, \text{padding\_bottom}`
+ :math:`\text{padding\_front}, \text{padding\_back})`.
- Note:
- When using the CUDA backend, this operation may induce nondeterministic
- behaviour in its backward pass that is not easily switched off.
- Please see the notes on :doc:`/notes/randomness` for background.
+Padding mode:
+ See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
+ :class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
+ padding modes works. Constant padding is implemented for arbitrary dimensions.
+ Replicate and reflection padding are implemented for padding the last 3
+ dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D
+ or 4D input tensor, or the last dimension of a 2D or 3D input tensor.
- Args:
- input (Tensor): N-dimensional tensor
- pad (tuple): m-elements tuple, where
- :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
- mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
- Default: ``'constant'``
- value: fill value for ``'constant'`` padding. Default: ``0``
+Note:
+ When using the CUDA backend, this operation may induce nondeterministic
+ behaviour in its backward pass that is not easily switched off.
+ Please see the notes on :doc:`/notes/randomness` for background.
- Examples::
+Args:
+ input (Tensor): N-dimensional tensor
+ pad (tuple): m-element tuple, where
+ :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
+ mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ Default: ``'constant'``
+ value: fill value for ``'constant'`` padding. Default: ``0``
- >>> t4d = torch.empty(3, 3, 4, 2)
- >>> p1d = (1, 1) # pad last dim by 1 on each side
- >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding
- >>> print(out.size())
- torch.Size([3, 3, 4, 4])
- >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
- >>> out = F.pad(t4d, p2d, "constant", 0)
- >>> print(out.size())
- torch.Size([3, 3, 8, 4])
- >>> t4d = torch.empty(3, 3, 4, 2)
- >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
- >>> out = F.pad(t4d, p3d, "constant", 0)
- >>> print(out.size())
- torch.Size([3, 9, 7, 3])
+Examples::
- """
- if has_torch_function_unary(input):
- return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value)
- assert len(pad) % 2 == 0, "Padding length must be divisible by 2"
- assert len(pad) // 2 <= input.dim(), "Padding length too large"
- if mode == "constant":
- return _VF.constant_pad_nd(input, pad, value)
- else:
- assert value == 0.0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode)
- if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3):
- if mode == "reflect":
- return torch._C._nn.reflection_pad1d(input, pad)
- elif mode == "replicate":
- return torch._C._nn.replication_pad1d(input, pad)
- elif mode == "circular":
- return _pad_circular(input, pad)
- else:
- raise NotImplementedError
+ >>> t4d = torch.empty(3, 3, 4, 2)
+ >>> p1d = (1, 1) # pad last dim by 1 on each side
+ >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding
+ >>> print(out.size())
+ torch.Size([3, 3, 4, 4])
+ >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
+ >>> out = F.pad(t4d, p2d, "constant", 0)
+ >>> print(out.size())
+ torch.Size([3, 3, 8, 4])
+ >>> t4d = torch.empty(3, 3, 4, 2)
+ >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
+ >>> out = F.pad(t4d, p3d, "constant", 0)
+ >>> print(out.size())
+ torch.Size([3, 9, 7, 3])
- elif len(pad) == 4 and (input.dim() == 3 or input.dim() == 4):
- if mode == "reflect":
- return torch._C._nn.reflection_pad2d(input, pad)
- elif mode == "replicate":
- return torch._C._nn.replication_pad2d(input, pad)
- elif mode == "circular":
- return _pad_circular(input, pad)
- else:
- raise NotImplementedError
-
- elif len(pad) == 6 and (input.dim() == 4 or input.dim() == 5):
- if mode == "reflect":
- return torch._C._nn.reflection_pad3d(input, pad)
- elif mode == "replicate":
- return torch._C._nn.replication_pad3d(input, pad)
- elif mode == "circular":
- return _pad_circular(input, pad)
- else:
- raise NotImplementedError
- else:
- raise NotImplementedError("Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now")
-
-
-# We define this function as _pad because it takes an argument
-# named pad, which clobbers the recursive reference to the pad
-# function needed for __torch_function__ support
-pad = _pad
+""")
+# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
+pad.__module__ = "torch.nn.functional"
# distance
@@ -4684,174 +4643,6 @@
raise NotImplementedError("Input Error: Only unbatched (2D) or batched (3D) input Tensors"
f"are supported (got {input.dim()}D)")
-
-# NOTE: Do not edit. This code will be removed once the forward-compatibility
-# period is over for PR #73410
-def _pad_circular(input: Tensor, padding: List[int]) -> Tensor:
- """Circularly pads tensor.
-
- Tensor values at the beginning are used to pad the end, and values at the
- end are used to pad the beginning. For example, consider a single dimension
- with values [0, 1, 2, 3]. With circular padding of (1, 1) it would be
- padded to [3, 0, 1, 2, 3, 0], and with padding (1, 2) it would be padded to
- [3, 0, 1, 2, 3, 0, 1]. If negative padding is applied then the ends of the
- tensor get removed. With circular padding of (-1, -1) the previous example
- would become [1, 2]. Circular padding of (-1, 1) would produce
- [1, 2, 3, 1].
-
- The first and second dimensions of the tensor are not padded.
-
- Args:
- input: Tensor with shape :math:`(N, C, D[, H, W])`.
- padding: Tuple containing the number of elements to pad each side of
- the tensor. The length of padding must be twice the number of
- paddable dimensions. For example, the length of padding should be 4
- for a tensor of shape :math:`(N, C, H, W)`, and the length should
- be 6 for a tensor of shape :math:`(N, C, D, H, W)`.
-
- Examples::
-
- >>> x = torch.tensor([[[[0, 1, 2], [3, 4, 5]]]]) # Create tensor
- >>> # Example 1
- >>> padding = (1, 1, 1, 1)
- >>> y = F.pad(x, padding, mode='circular')
- >>> print(y)
- tensor([[[[5, 3, 4, 5, 3],
- [2, 0, 1, 2, 0],
- [5, 3, 4, 5, 3],
- [2, 0, 1, 2, 0]]]])
- >>> print(y.shape)
- torch.Size([1, 1, 4, 5])
- >>> # Example 2
- >>> padding = (1, 1, 2, 2)
- >>> z = F.pad(x, padding, mode='circular')
- >>> print(z)
- tensor([[[[2, 0, 1, 2, 0],
- [5, 3, 4, 5, 3],
- [2, 0, 1, 2, 0],
- [5, 3, 4, 5, 3],
- [2, 0, 1, 2, 0],
- [5, 3, 4, 5, 3]]]])
- >>> print(z.shape)
- torch.Size([1, 1, 6, 5])
- """
- in_shape = input.shape
- paddable_shape = in_shape[2:]
- ndim = len(paddable_shape)
-
- for idx, size in enumerate(paddable_shape):
- # Only supports wrapping around once
- assert padding[-(idx * 2 + 1)] <= size, "Padding value causes wrapping around more than once."
- assert padding[-(idx * 2 + 2)] <= size, "Padding value causes wrapping around more than once."
- # Negative padding should not result in negative sizes
- assert (
- padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0
- ), "Negative padding value is resulting in an empty dimension."
-
- # Get shape of padded tensor
- out_shape = in_shape[:2]
- for idx, size in enumerate(paddable_shape):
- out_shape += (size + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)],)
-
- out = input.new_empty(out_shape)
-
- # Put original array in padded array
- if ndim == 1:
- out_d0 = max(padding[-2], 0)
- out_d1 = out_shape[2] - max(padding[-1], 0)
-
- in_d0 = max(-padding[-2], 0)
- in_d1 = in_shape[2] - max(-padding[-1], 0)
-
- out[..., out_d0:out_d1] = input[..., in_d0:in_d1]
- elif ndim == 2:
- out_d0 = max(padding[-2], 0)
- out_d1 = out_shape[2] - max(padding[-1], 0)
-
- out_h0 = max(padding[-4], 0)
- out_h1 = out_shape[3] - max(padding[-3], 0)
-
- in_d0 = max(-padding[-2], 0)
- in_d1 = in_shape[2] - max(-padding[-1], 0)
-
- in_h0 = max(-padding[-4], 0)
- in_h1 = in_shape[3] - max(-padding[-3], 0)
-
- out[..., out_d0:out_d1, out_h0:out_h1] = input[..., in_d0:in_d1, in_h0:in_h1]
- elif ndim == 3:
- out_d0 = max(padding[-2], 0)
- out_d1 = out_shape[2] - max(padding[-1], 0)
-
- out_h0 = max(padding[-4], 0)
- out_h1 = out_shape[3] - max(padding[-3], 0)
-
- out_w0 = max(padding[-6], 0)
- out_w1 = out_shape[4] - max(padding[-5], 0)
-
- in_d0 = max(-padding[-2], 0)
- in_d1 = in_shape[2] - max(-padding[-1], 0)
-
- in_h0 = max(-padding[-4], 0)
- in_h1 = in_shape[3] - max(-padding[-3], 0)
-
- in_w0 = max(-padding[-6], 0)
- in_w1 = in_shape[4] - max(-padding[-5], 0)
-
- out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1]
-
- # The following steps first pad the beginning of the tensor (left side),
- # and then pad the end of the tensor (right side).
- # Note: Corners will be written more than once when ndim > 1.
-
- # Only in cases where padding values are > 0 are when additional copying
- # is required.
-
- # Pad first dimension (depth)
- if padding[-2] > 0:
- i0 = out_shape[2] - padding[-2] - max(padding[-1], 0)
- i1 = out_shape[2] - max(padding[-1], 0)
- o0 = 0
- o1 = padding[-2]
- out[:, :, o0:o1] = out[:, :, i0:i1]
- if padding[-1] > 0:
- i0 = max(padding[-2], 0)
- i1 = max(padding[-2], 0) + padding[-1]
- o0 = out_shape[2] - padding[-1]
- o1 = out_shape[2]
- out[:, :, o0:o1] = out[:, :, i0:i1]
-
- # Pad second dimension (height)
- if len(padding) > 2:
- if padding[-4] > 0:
- i0 = out_shape[3] - padding[-4] - max(padding[-3], 0)
- i1 = out_shape[3] - max(padding[-3], 0)
- o0 = 0
- o1 = padding[-4]
- out[:, :, :, o0:o1] = out[:, :, :, i0:i1]
- if padding[-3] > 0:
- i0 = max(padding[-4], 0)
- i1 = max(padding[-4], 0) + padding[-3]
- o0 = out_shape[3] - padding[-3]
- o1 = out_shape[3]
- out[:, :, :, o0:o1] = out[:, :, :, i0:i1]
-
- # Pad third dimension (width)
- if len(padding) > 4:
- if padding[-6] > 0:
- i0 = out_shape[4] - padding[-6] - max(padding[-5], 0)
- i1 = out_shape[4] - max(padding[-5], 0)
- o0 = 0
- o1 = padding[-6]
- out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1]
- if padding[-5] > 0:
- i0 = max(padding[-6], 0)
- i1 = max(padding[-6], 0) + padding[-5]
- o0 = out_shape[4] - padding[-5]
- o1 = out_shape[4]
- out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1]
-
- return out
-
#
# multihead attention
#
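The deleted `_pad_circular` docstring above fixes the wrap-around semantics the native op is expected to preserve. Re-checking its documented 1-D cases through `F.pad` (illustrative, assuming the native kernel matches the removed Python behavior):

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0, 1, 2, 3]).view(1, 1, 4)          # paddable dims start at index 2
print(F.pad(x, (1, 1), mode="circular").flatten())    # tensor([3, 0, 1, 2, 3, 0])
print(F.pad(x, (1, 2), mode="circular").flatten())    # tensor([3, 0, 1, 2, 3, 0, 1])
print(F.pad(x, (-1, -1), mode="circular").flatten())  # tensor([1, 2])
print(F.pad(x, (-1, 1), mode="circular").flatten())   # tensor([1, 2, 3, 1])
```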
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index dd14186..f945998 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -8,7 +8,7 @@
import warnings
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType, quantized_args
-from torch.onnx.symbolic_opset9 import expand, unused, mul, op_with_optional_float_cast
+from torch.onnx.symbolic_opset9 import expand, unused, mul, op_with_optional_float_cast, _pad_circular
from torch.onnx.symbolic_opset9 import linalg_vector_norm as lvn
from torch.nn.modules.utils import _single, _pair, _triple
from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
@@ -502,6 +502,20 @@
replication_pad3d = replication_pad
+def pad(g, input, pad, mode, value):
+ mode = sym_help._parse_arg(mode, "s")
+ if mode == "replicate":
+ return replication_pad(g, input, pad)
+ elif mode == "reflect":
+ return reflection_pad(g, input, pad)
+ elif mode == "constant":
+ return constant_pad_nd(g, input, pad, value)
+ elif mode == "circular":
+ return _pad_circular(g, input, pad)
+ else:
+ raise RuntimeError(f"Unrecognized padding mode {mode}")
+
+
def linalg_det(g, self):
return g.op("Det", self)
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 8453d63..aa08198 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1097,6 +1097,50 @@
paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding)
return op_with_optional_float_cast(g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11)
+def _pad_circular(g, input, pad):
+ padding = _convert_padding_node(pad)
+ assert len(padding) % 2 == 0
+ ndim = len(padding) // 2
+
+ cur = input
+ for idx in range(ndim):
+ pad_l = padding[-(2 * idx + 1)]
+ pad_r = padding[-(2 * idx + 2)]
+
+ tensors = []
+ if pad_l > 0:
+ left = sym_help._slice_helper(
+ g,
+ cur,
+ axes=[2 + idx],
+ starts=[-(pad_l + 1)],
+ ends=[-1])
+ tensors.append(left)
+
+ if pad_l < 0 or pad_r < 0:
+ middle = sym_help._slice_helper(
+ g,
+ cur,
+ axes=[2 + idx],
+ starts=[max(0, -pad_l)],
+ ends=[-(1 + max(0, -pad_r))])
+ tensors.append(middle)
+ else:
+ tensors.append(cur)
+
+ if pad_r > 0:
+ right = sym_help._slice_helper(
+ g,
+ cur,
+ axes=[2 + idx],
+ starts=[0],
+ ends=[pad_r])
+ tensors.append(right)
+
+ cur = g.op("Concat", *tensors, axis_i=(2 + idx))
+
+ return cur
+
def reflection_pad(g, input, padding):
mode = "reflect"
@@ -1120,6 +1164,19 @@
replication_pad3d = replication_pad
+def pad(g, input, pad, mode, value):
+ mode = sym_help._parse_arg(mode, "s")
+ if mode == "replicate":
+ return replication_pad(g, input, pad)
+ elif mode == "reflect":
+ return reflection_pad(g, input, pad)
+ elif mode == "constant":
+ return constant_pad_nd(g, input, pad, value)
+ elif mode == "circular":
+ return _pad_circular(g, input, pad)
+ else:
+ raise RuntimeError(f"Unrecognized padding mode {mode}")
+
def _interpolate(name, dim, interpolate_mode):
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
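For intuition, the per-dimension Slice/Concat construction in `_pad_circular` above corresponds to the following eager-mode sketch (illustrative only; `circular_pad_dim` is not a real API). Negative padding first trims the middle, and wrapped values are then taken from that trimmed middle, matching the behavior documented for the removed Python helper:

```python
import torch

def circular_pad_dim(x: torch.Tensor, dim: int, pad_l: int, pad_r: int) -> torch.Tensor:
    start = max(0, -pad_l)                           # negative left pad trims the front
    length = x.size(dim) - start - max(0, -pad_r)    # negative right pad trims the back
    mid = x.narrow(dim, start, length)
    pieces = []
    if pad_l > 0:
        pieces.append(mid.narrow(dim, mid.size(dim) - pad_l, pad_l))  # wrap end -> front
    pieces.append(mid)
    if pad_r > 0:
        pieces.append(mid.narrow(dim, 0, pad_r))                      # wrap front -> back
    return torch.cat(pieces, dim=dim)

x = torch.tensor([0, 1, 2, 3]).view(1, 1, 4)
print(circular_pad_dim(x, 2, 1, 1).flatten())   # tensor([3, 0, 1, 2, 3, 0])
print(circular_pad_dim(x, 2, -1, 1).flatten())  # tensor([1, 2, 3, 1])
```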