Remove F.pad Python implementation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73433

Approved by: https://github.com/albanD, https://github.com/jbschlosser
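
`F.pad` is now bound to the native `torch._C._nn.pad` operator instead of a
Python implementation. Argument validation therefore happens in C++ and
surfaces as `RuntimeError` rather than `AssertionError` (see the `test_nn.py`
changes below), and `torch.fx` traces `pad` as a built-in function (see
`test_fx.py`). A minimal sketch of the user-visible difference; the exact
error messages come from the native kernel:

```python
import torch
import torch.nn.functional as F

# Valid calls are unchanged: pad the last dim of a 4D tensor by 1 on each side.
t4d = torch.empty(3, 3, 4, 2)
assert F.pad(t4d, (1, 1)).shape == torch.Size([3, 3, 4, 4])

# Invalid calls now raise RuntimeError from the native op, where the removed
# Python implementation raised AssertionError.
try:
    F.pad(torch.tensor(0.0), (1, 1))  # a 0-dim tensor cannot be padded
except RuntimeError:
    pass
```

Since TorchScript graphs now contain `aten::pad` directly instead of its
Python decomposition, the ONNX exporter gains `pad` symbolics for opsets 9
and 11 that dispatch on the mode string, including a graph-level
`_pad_circular` built from Slice/Concat. A hedged export sketch; the module
and file names are illustrative only:

```python
class CircularPad(torch.nn.Module):
    def forward(self, x):
        return F.pad(x, (1, 1, 1, 1), mode="circular")

torch.onnx.export(CircularPad(), torch.randn(1, 1, 3, 3),
                  "circular_pad.onnx", opset_version=11)
```
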
diff --git a/test/onnx/expect/TestOperators.test_pad.expect b/test/onnx/expect/TestOperators.test_pad.expect
index 0a25fb0..293877a 100644
--- a/test/onnx/expect/TestOperators.test_pad.expect
+++ b/test/onnx/expect/TestOperators.test_pad.expect
@@ -161,7 +161,7 @@
     }
   }
   node {
-    input: "input"
+    input: "onnx::Pad_0"
     input: "onnx::Pad_22"
     output: "23"
     name: "Pad_13"
@@ -186,7 +186,7 @@
     raw_data: "\002\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
   }
   input {
-    name: "input"
+    name: "onnx::Pad_0"
     type {
       tensor_type {
         elem_type: 1
diff --git a/test/test_fx.py b/test/test_fx.py
index 4a5436b..6710c21 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -3810,6 +3810,7 @@
         "linear": BUILT_IN_FUNC,
         "logsigmoid": BUILT_IN_FUNC,
         "one_hot": BUILT_IN_FUNC,
+        "pad": BUILT_IN_FUNC,
         "pairwise_distance": BUILT_IN_FUNC,
         "pdist": BUILT_IN_FUNC,
         "pixel_shuffle": BUILT_IN_FUNC,
@@ -3827,7 +3828,6 @@
         "adaptive_max_pool2d_with_indices": LEN_ERROR,
         "adaptive_max_pool3d_with_indices": LEN_ERROR,
         "instance_norm": CONTROL_FLOW,
-        "pad": LEN_ERROR,
 
         "adaptive_max_pool1d": PROXY_ITERABLE,
         "adaptive_max_pool2d": PROXY_ITERABLE,
diff --git a/test/test_jit.py b/test/test_jit.py
index 37d9ec5..cd6bc43 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15513,7 +15513,7 @@
         self.assertEqual(m.int64_min, imported.int64_min)
 
     def test_script_scope(self):
-        scripted = torch.jit.script(torch.nn.functional.pad)
+        scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss)
 
     @unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows")
     def test_serialization_sharing(self):
diff --git a/test/test_nn.py b/test/test_nn.py
index f71c234..ed1251a 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5430,8 +5430,8 @@
 
     def test_pad_scalar_error(self):
         inputs = torch.tensor(0., requires_grad=True)
-        self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1)))
-        self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,)))
+        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1)))
+        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1,)))
 
     @unittest.skipIf(not TEST_NUMPY, "numpy not found")
     @parametrize_test("average_attn_weights", [True, False])
@@ -14345,10 +14345,10 @@
         # Assert errors are raised for invalid circular padding values
         inputs = torch.randn(1, 1, 4, device=device, dtype=dtype, requires_grad=True)
         # Should raise error when trying to wrap around more than once
-        self.assertRaises(AssertionError, lambda: F.pad(inputs, (5, 4), mode='circular'))
-        self.assertRaises(AssertionError, lambda: F.pad(inputs, (3, 6), mode='circular'))
+        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (5, 4), mode='circular'))
+        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (3, 6), mode='circular'))
         # Should raise error when negative padding results in negative output shape
-        self.assertRaises(AssertionError, lambda: F.pad(inputs, (-3, -2), mode='circular'))
+        self.assertRaises(RuntimeError, lambda: F.pad(inputs, (-3, -2), mode='circular'))
 
         # assert that reflection padding errors when pad >= input size
         expected_err_msg = r"Padding size should be less than the corresponding input dimension"
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index e8734ff..cd42554 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4263,111 +4263,70 @@
     return torch.affine_grid_generator(theta, size, align_corners)
 
 
-# NOTE: Do not edit. This code will be removed once the forward-compatibility
-#       period is over for PR #73431
-def _pad(input: Tensor, pad: BroadcastingList1[int], mode: str = "constant", value: Union[int, float] = 0.0) -> Tensor:
-    r"""Pads tensor.
+pad = _add_docstr(
+    torch._C._nn.pad,
+    r"""
+pad(input, pad, mode="constant", value=None) -> Tensor
 
-    Padding size:
-        The padding size by which to pad some dimensions of :attr:`input`
-        are described starting from the last dimension and moving forward.
-        :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
-        of ``input`` will be padded.
-        For example, to pad only the last dimension of the input tensor, then
-        :attr:`pad` has the form
-        :math:`(\text{padding\_left}, \text{padding\_right})`;
-        to pad the last 2 dimensions of the input tensor, then use
-        :math:`(\text{padding\_left}, \text{padding\_right},`
-        :math:`\text{padding\_top}, \text{padding\_bottom})`;
-        to pad the last 3 dimensions, use
-        :math:`(\text{padding\_left}, \text{padding\_right},`
-        :math:`\text{padding\_top}, \text{padding\_bottom}`
-        :math:`\text{padding\_front}, \text{padding\_back})`.
+Pads tensor.
 
-    Padding mode:
-        See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
-        :class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
-        padding modes works. Constant padding is implemented for arbitrary dimensions.
-        Replicate and reflection padding are implemented for padding the last 3
-        dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D
-        or 4D input tensor, or the last dimension of a 2D or 3D input tensor.
+Padding size:
+    The padding sizes by which to pad some dimensions of :attr:`input`
+    are described starting from the last dimension and moving forward.
+    :math:`\left\lfloor\frac{\text{len(pad)}}{2}\right\rfloor` dimensions
+    of ``input`` will be padded.
+    For example, to pad only the last dimension of the input tensor, then
+    :attr:`pad` has the form
+    :math:`(\text{padding\_left}, \text{padding\_right})`;
+    to pad the last 2 dimensions of the input tensor, then use
+    :math:`(\text{padding\_left}, \text{padding\_right},`
+    :math:`\text{padding\_top}, \text{padding\_bottom})`;
+    to pad the last 3 dimensions, use
+    :math:`(\text{padding\_left}, \text{padding\_right},`
+    :math:`\text{padding\_top}, \text{padding\_bottom}`
+    :math:`\text{padding\_front}, \text{padding\_back})`.
 
-    Note:
-        When using the CUDA backend, this operation may induce nondeterministic
-        behaviour in its backward pass that is not easily switched off.
-        Please see the notes on :doc:`/notes/randomness` for background.
+Padding mode:
+    See :class:`torch.nn.ConstantPad2d`, :class:`torch.nn.ReflectionPad2d`, and
+    :class:`torch.nn.ReplicationPad2d` for concrete examples on how each of the
+    padding modes works. Constant padding is implemented for arbitrary dimensions.
+    Replicate and reflection padding are implemented for padding the last 3
+    dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D
+    or 4D input tensor, or the last dimension of a 2D or 3D input tensor.
 
-    Args:
-        input (Tensor): N-dimensional tensor
-        pad (tuple): m-elements tuple, where
-            :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
-        mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
-            Default: ``'constant'``
-        value: fill value for ``'constant'`` padding. Default: ``0``
+Note:
+    When using the CUDA backend, this operation may induce nondeterministic
+    behaviour in its backward pass that is not easily switched off.
+    Please see the notes on :doc:`/notes/randomness` for background.
 
-    Examples::
+Args:
+    input (Tensor): N-dimensional tensor
+    pad (tuple): m-element tuple, where
+        :math:`\frac{m}{2} \leq` input dimensions and :math:`m` is even.
+    mode: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+        Default: ``'constant'``
+    value: fill value for ``'constant'`` padding. Default: ``0``
 
-        >>> t4d = torch.empty(3, 3, 4, 2)
-        >>> p1d = (1, 1) # pad last dim by 1 on each side
-        >>> out = F.pad(t4d, p1d, "constant", 0)  # effectively zero padding
-        >>> print(out.size())
-        torch.Size([3, 3, 4, 4])
-        >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
-        >>> out = F.pad(t4d, p2d, "constant", 0)
-        >>> print(out.size())
-        torch.Size([3, 3, 8, 4])
-        >>> t4d = torch.empty(3, 3, 4, 2)
-        >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
-        >>> out = F.pad(t4d, p3d, "constant", 0)
-        >>> print(out.size())
-        torch.Size([3, 9, 7, 3])
+Examples::
 
-    """
-    if has_torch_function_unary(input):
-        return handle_torch_function(_pad, (input,), input, pad, mode=mode, value=value)
-    assert len(pad) % 2 == 0, "Padding length must be divisible by 2"
-    assert len(pad) // 2 <= input.dim(), "Padding length too large"
-    if mode == "constant":
-        return _VF.constant_pad_nd(input, pad, value)
-    else:
-        assert value == 0.0, 'Padding mode "{}"" doesn\'t take in value argument'.format(mode)
-        if len(pad) == 2 and (input.dim() == 2 or input.dim() == 3):
-            if mode == "reflect":
-                return torch._C._nn.reflection_pad1d(input, pad)
-            elif mode == "replicate":
-                return torch._C._nn.replication_pad1d(input, pad)
-            elif mode == "circular":
-                return _pad_circular(input, pad)
-            else:
-                raise NotImplementedError
+    >>> t4d = torch.empty(3, 3, 4, 2)
+    >>> p1d = (1, 1) # pad last dim by 1 on each side
+    >>> out = F.pad(t4d, p1d, "constant", 0)  # effectively zero padding
+    >>> print(out.size())
+    torch.Size([3, 3, 4, 4])
+    >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2)
+    >>> out = F.pad(t4d, p2d, "constant", 0)
+    >>> print(out.size())
+    torch.Size([3, 3, 8, 4])
+    >>> t4d = torch.empty(3, 3, 4, 2)
+    >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3)
+    >>> out = F.pad(t4d, p3d, "constant", 0)
+    >>> print(out.size())
+    torch.Size([3, 9, 7, 3])
 
-        elif len(pad) == 4 and (input.dim() == 3 or input.dim() == 4):
-            if mode == "reflect":
-                return torch._C._nn.reflection_pad2d(input, pad)
-            elif mode == "replicate":
-                return torch._C._nn.replication_pad2d(input, pad)
-            elif mode == "circular":
-                return _pad_circular(input, pad)
-            else:
-                raise NotImplementedError
-
-        elif len(pad) == 6 and (input.dim() == 4 or input.dim() == 5):
-            if mode == "reflect":
-                return torch._C._nn.reflection_pad3d(input, pad)
-            elif mode == "replicate":
-                return torch._C._nn.replication_pad3d(input, pad)
-            elif mode == "circular":
-                return _pad_circular(input, pad)
-            else:
-                raise NotImplementedError
-        else:
-            raise NotImplementedError("Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now")
-
-
-# We define this function as _pad because it takes an argument
-# named pad, which clobbers the recursive reference to the pad
-# function needed for __torch_function__ support
-pad = _pad
+""")
+# TODO: Fix via https://github.com/pytorch/pytorch/issues/75798
+pad.__module__ = "torch.nn.functional"
 
 # distance
 
@@ -4684,174 +4643,6 @@
         raise NotImplementedError("Input Error: Only unbatched (2D) or batched (3D) input Tensors"
                                   f"are supported (got {input.dim()}D)")
 
-
-# NOTE: Do not edit. This code will be removed once the forward-compatibility
-#       period is over for PR #73410
-def _pad_circular(input: Tensor, padding: List[int]) -> Tensor:
-    """Circularly pads tensor.
-
-    Tensor values at the beginning are used to pad the end, and values at the
-    end are used to pad the beginning. For example, consider a single dimension
-    with values [0, 1, 2, 3]. With circular padding of (1, 1) it would be
-    padded to [3, 0, 1, 2, 3, 0], and with padding (1, 2) it would be padded to
-    [3, 0, 1, 2, 3, 0, 1]. If negative padding is applied then the ends of the
-    tensor get removed. With circular padding of (-1, -1) the previous example
-    would become [1, 2]. Circular padding of (-1, 1) would produce
-    [1, 2, 3, 1].
-
-    The first and second dimensions of the tensor are not padded.
-
-    Args:
-        input: Tensor with shape :math:`(N, C, D[, H, W])`.
-        padding: Tuple containing the number of elements to pad each side of
-            the tensor. The length of padding must be twice the number of
-            paddable dimensions. For example, the length of padding should be 4
-            for a tensor of shape :math:`(N, C, H, W)`, and the length should
-            be 6 for a tensor of shape :math:`(N, C, D, H, W)`.
-
-    Examples::
-
-        >>> x = torch.tensor([[[[0, 1, 2], [3, 4, 5]]]])  # Create tensor
-        >>> # Example 1
-        >>> padding = (1, 1, 1, 1)
-        >>> y = F.pad(x, padding, mode='circular')
-        >>> print(y)
-        tensor([[[[5, 3, 4, 5, 3],
-                  [2, 0, 1, 2, 0],
-                  [5, 3, 4, 5, 3],
-                  [2, 0, 1, 2, 0]]]])
-        >>> print(y.shape)
-        torch.Size([1, 1, 4, 5])
-        >>> # Example 2
-        >>> padding = (1, 1, 2, 2)
-        >>> z = F.pad(x, padding, mode='circular')
-        >>> print(z)
-        tensor([[[[2, 0, 1, 2, 0],
-                  [5, 3, 4, 5, 3],
-                  [2, 0, 1, 2, 0],
-                  [5, 3, 4, 5, 3],
-                  [2, 0, 1, 2, 0],
-                  [5, 3, 4, 5, 3]]]])
-        >>> print(z.shape)
-        torch.Size([1, 1, 6, 5])
-    """
-    in_shape = input.shape
-    paddable_shape = in_shape[2:]
-    ndim = len(paddable_shape)
-
-    for idx, size in enumerate(paddable_shape):
-        # Only supports wrapping around once
-        assert padding[-(idx * 2 + 1)] <= size, "Padding value causes wrapping around more than once."
-        assert padding[-(idx * 2 + 2)] <= size, "Padding value causes wrapping around more than once."
-        # Negative padding should not result in negative sizes
-        assert (
-            padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)] + size >= 0
-        ), "Negative padding value is resulting in an empty dimension."
-
-    # Get shape of padded tensor
-    out_shape = in_shape[:2]
-    for idx, size in enumerate(paddable_shape):
-        out_shape += (size + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)],)
-
-    out = input.new_empty(out_shape)
-
-    # Put original array in padded array
-    if ndim == 1:
-        out_d0 = max(padding[-2], 0)
-        out_d1 = out_shape[2] - max(padding[-1], 0)
-
-        in_d0 = max(-padding[-2], 0)
-        in_d1 = in_shape[2] - max(-padding[-1], 0)
-
-        out[..., out_d0:out_d1] = input[..., in_d0:in_d1]
-    elif ndim == 2:
-        out_d0 = max(padding[-2], 0)
-        out_d1 = out_shape[2] - max(padding[-1], 0)
-
-        out_h0 = max(padding[-4], 0)
-        out_h1 = out_shape[3] - max(padding[-3], 0)
-
-        in_d0 = max(-padding[-2], 0)
-        in_d1 = in_shape[2] - max(-padding[-1], 0)
-
-        in_h0 = max(-padding[-4], 0)
-        in_h1 = in_shape[3] - max(-padding[-3], 0)
-
-        out[..., out_d0:out_d1, out_h0:out_h1] = input[..., in_d0:in_d1, in_h0:in_h1]
-    elif ndim == 3:
-        out_d0 = max(padding[-2], 0)
-        out_d1 = out_shape[2] - max(padding[-1], 0)
-
-        out_h0 = max(padding[-4], 0)
-        out_h1 = out_shape[3] - max(padding[-3], 0)
-
-        out_w0 = max(padding[-6], 0)
-        out_w1 = out_shape[4] - max(padding[-5], 0)
-
-        in_d0 = max(-padding[-2], 0)
-        in_d1 = in_shape[2] - max(-padding[-1], 0)
-
-        in_h0 = max(-padding[-4], 0)
-        in_h1 = in_shape[3] - max(-padding[-3], 0)
-
-        in_w0 = max(-padding[-6], 0)
-        in_w1 = in_shape[4] - max(-padding[-5], 0)
-
-        out[..., out_d0:out_d1, out_h0:out_h1, out_w0:out_w1] = input[..., in_d0:in_d1, in_h0:in_h1, in_w0:in_w1]
-
-    # The following steps first pad the beginning of the tensor (left side),
-    # and then pad the end of the tensor (right side).
-    # Note: Corners will be written more than once when ndim > 1.
-
-    # Only in cases where padding values are > 0 are when additional copying
-    # is required.
-
-    # Pad first dimension (depth)
-    if padding[-2] > 0:
-        i0 = out_shape[2] - padding[-2] - max(padding[-1], 0)
-        i1 = out_shape[2] - max(padding[-1], 0)
-        o0 = 0
-        o1 = padding[-2]
-        out[:, :, o0:o1] = out[:, :, i0:i1]
-    if padding[-1] > 0:
-        i0 = max(padding[-2], 0)
-        i1 = max(padding[-2], 0) + padding[-1]
-        o0 = out_shape[2] - padding[-1]
-        o1 = out_shape[2]
-        out[:, :, o0:o1] = out[:, :, i0:i1]
-
-    # Pad second dimension (height)
-    if len(padding) > 2:
-        if padding[-4] > 0:
-            i0 = out_shape[3] - padding[-4] - max(padding[-3], 0)
-            i1 = out_shape[3] - max(padding[-3], 0)
-            o0 = 0
-            o1 = padding[-4]
-            out[:, :, :, o0:o1] = out[:, :, :, i0:i1]
-        if padding[-3] > 0:
-            i0 = max(padding[-4], 0)
-            i1 = max(padding[-4], 0) + padding[-3]
-            o0 = out_shape[3] - padding[-3]
-            o1 = out_shape[3]
-            out[:, :, :, o0:o1] = out[:, :, :, i0:i1]
-
-    # Pad third dimension (width)
-    if len(padding) > 4:
-        if padding[-6] > 0:
-            i0 = out_shape[4] - padding[-6] - max(padding[-5], 0)
-            i1 = out_shape[4] - max(padding[-5], 0)
-            o0 = 0
-            o1 = padding[-6]
-            out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1]
-        if padding[-5] > 0:
-            i0 = max(padding[-6], 0)
-            i1 = max(padding[-6], 0) + padding[-5]
-            o0 = out_shape[4] - padding[-5]
-            o1 = out_shape[4]
-            out[:, :, :, :, o0:o1] = out[:, :, :, :, i0:i1]
-
-    return out
-
 #
 # multihead attention
 #
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index dd14186..f945998 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -8,7 +8,7 @@
 import warnings
 
 from torch.onnx.symbolic_helper import parse_args, _unimplemented, _is_tensor_list, ScalarType, quantized_args
-from torch.onnx.symbolic_opset9 import expand, unused, mul, op_with_optional_float_cast
+from torch.onnx.symbolic_opset9 import expand, unused, mul, op_with_optional_float_cast, _pad_circular
 from torch.onnx.symbolic_opset9 import linalg_vector_norm as lvn
 from torch.nn.modules.utils import _single, _pair, _triple
 from torch.onnx.utils import _add_block, _add_input_to_block, _add_output_to_block
@@ -502,6 +502,20 @@
 replication_pad3d = replication_pad
 
 
+def pad(g, input, pad, mode, value):
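+    # aten::pad now reaches the exporter un-decomposed (its Python
+    # implementation is gone), so dispatch on the mode string here.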
+    mode = sym_help._parse_arg(mode, "s")
+    if mode == "replicate":
+        return replication_pad(g, input, pad)
+    elif mode == "reflect":
+        return reflection_pad(g, input, pad)
+    elif mode == "constant":
+        return constant_pad_nd(g, input, pad, value)
+    elif mode == "circular":
+        return _pad_circular(g, input, pad)
+    else:
+        raise RuntimeError(f"Unrecognized padding mode {mode}")
+
+
 def linalg_det(g, self):
     return g.op("Det", self)
 
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 8453d63..aa08198 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1097,6 +1097,50 @@
     paddings = _prepare_onnx_paddings(sym_help._get_tensor_rank(input), padding)
     return op_with_optional_float_cast(g, "Pad", input, pads_i=paddings, mode_s=mode, value_f=value, opset_before=11)
 
+
+def _pad_circular(g, input, pad):
+    padding = _convert_padding_node(pad)
+    assert len(padding) % 2 == 0
+    ndim = len(padding) // 2
+
+    cur = input
+    for idx in range(ndim):
+        # The flattened padding list runs from the last dimension backwards:
+        # (left_n, right_n, ..., left_1, right_1), so the pair for the first
+        # paddable axis (2 + idx with idx == 0) sits at the end of the list.
+        pad_l = padding[-(2 * idx + 2)]
+        pad_r = padding[-(2 * idx + 1)]
+
+        # Negative padding trims the corresponding side before wrapping.
+        if pad_l < 0 or pad_r < 0:
+            cur = sym_help._slice_helper(
+                g,
+                cur,
+                axes=[2 + idx],
+                starts=[max(0, -pad_l)],
+                ends=[pad_r if pad_r < 0 else 9223372036854775807])  # INT64_MAX
+            pad_l = max(0, pad_l)
+            pad_r = max(0, pad_r)
+
+        tensors = []
+        if pad_l > 0:
+            # Wrap around: the left padding copies the last pad_l elements.
+            left = sym_help._slice_helper(
+                g,
+                cur,
+                axes=[2 + idx],
+                starts=[-pad_l],
+                ends=[9223372036854775807])  # INT64_MAX, i.e. slice to the end
+            tensors.append(left)
+
+        tensors.append(cur)
+
+        if pad_r > 0:
+            # Wrap around: the right padding copies the first pad_r elements.
+            right = sym_help._slice_helper(
+                g,
+                cur,
+                axes=[2 + idx],
+                starts=[0],
+                ends=[pad_r])
+            tensors.append(right)
+
+        cur = g.op("Concat", *tensors, axis_i=(2 + idx))
+
+    return cur
+
 
 def reflection_pad(g, input, padding):
     mode = "reflect"
@@ -1120,6 +1164,19 @@
 replication_pad3d = replication_pad
 
 
+def pad(g, input, pad, mode, value):
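+    # Same dispatch as in opset 11: route each aten::pad mode to its
+    # opset 9 symbolic.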
+    mode = sym_help._parse_arg(mode, "s")
+    if mode == "replicate":
+        return replication_pad(g, input, pad)
+    elif mode == "reflect":
+        return reflection_pad(g, input, pad)
+    elif mode == "constant":
+        return constant_pad_nd(g, input, pad, value)
+    elif mode == "circular":
+        return _pad_circular(g, input, pad)
+    else:
+        raise RuntimeError(f"Unrecognized padding mode {mode}")
+
+
 def _interpolate(name, dim, interpolate_mode):
     def symbolic_fn(g, input, output_size, *args):
         scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)