changing trapz to trapezoid (#61475)

Summary:
This PR resolves issue https://github.com/pytorch/pytorch/issues/52606 while also adding support for complex numbers.
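
A rough sketch of the user-facing change (hypothetical values, assuming a build with this patch): `torch.trapezoid` is the new public name, `torch.trapz` forwards to it, and complex `y` inputs are now accepted while `dx` must remain real.

```python
import torch

# Complex-valued samples integrated over real sample points.
y = torch.tensor([1 + 1j, 2 + 2j, 4 + 4j])
x = torch.tensor([0.0, 1.0, 3.0])

print(torch.trapezoid(y, x))  # real and imaginary parts are integrated together
print(torch.trapz(y, x))      # `trapz` now dispatches to the same implementation

# A complex or bool `dx`, by contrast, is rejected with a TORCH_CHECK error.
```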

Stack from [ghstack](https://github.com/ezyang/ghstack):
* https://github.com/pytorch/pytorch/issues/61616
* https://github.com/pytorch/pytorch/issues/61615
* **https://github.com/pytorch/pytorch/issues/61475**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61475

Reviewed By: mruberry

Differential Revision: D29794958

Pulled By: NivekT

fbshipit-source-id: 60b9c07efd47fd85b9c8178768fc7828d7b57d29
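
The broadcasting behavior described in the Integration.cpp comments below can be seen from Python; a minimal sketch, assuming this patch is applied:

```python
import torch

# A 2-D `x` is broadcast against `y`, so each row of `y` gets its own spacing.
# (NumPy instead broadcasts the computed `dx`; see the comment in Integration.cpp.)
y = torch.ones(3, 3)
x = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 3.0, 5.0],
                  [1.0, 4.0, 7.0]])
print(torch.trapezoid(y, x))              # tensor([2., 4., 6.])

# Constant spacing via the scalar `dx` overload; `dim` selects the axis.
print(torch.trapezoid(y, dx=2.0, dim=0))  # tensor([4., 4., 4.])
```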
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index e7a0206..bc8cf63 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -265,6 +265,8 @@
   _(aten, linalg_householder_product)\
   _(aten, transpose)                 \
   _(aten, transpose_)                \
+  _(aten, trapz)                     \
+  _(aten, trapezoid)                 \
   _(aten, unsqueeze_)                \
   _(aten, __getitem__)               \
   _(aten, _set_item)                 \
diff --git a/aten/src/ATen/native/Integration.cpp b/aten/src/ATen/native/Integration.cpp
index 909a946..15dcccf 100644
--- a/aten/src/ATen/native/Integration.cpp
+++ b/aten/src/ATen/native/Integration.cpp
@@ -3,6 +3,7 @@
 #include <ATen/WrapDimUtils.h>
 #include <ATen/core/DimVector.h>
 #include <c10/util/Exception.h>
+#include <c10/core/ScalarType.h>
 
 namespace at {
 namespace native {
@@ -16,16 +17,17 @@
 //
 // TODO: if we extend TensorIterator to accept 3 inputs,
 // we can probably make this a bit more performant.
-Tensor do_trapz(const Tensor& y, const Tensor& dx, int64_t dim) {
+Tensor do_trapezoid(const Tensor& y, const Tensor& dx, int64_t dim) {
     Tensor left = y.slice(dim, 0, -1);
     Tensor right = y.slice(dim, 1);
-
+    // If the dimensions of 'dx' and '(left + right)' do not match,
+    // broadcasting is attempted here.
     return ((left + right) * dx).sum(dim) / 2.;
 }
 
 // When dx is constant, the above formula simplifies
 // to dx * [(\sum_{i=1}^n y_i) - (y_1 + y_n)/2]
-Tensor do_trapz(const Tensor& y, double dx, int64_t dim) {
+Tensor do_trapezoid(const Tensor& y, double dx, int64_t dim) {
     return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * (0.5)) * dx;
 }
 
@@ -38,35 +40,53 @@
 
 }
 
-Tensor trapz(const Tensor& y, const Tensor& x, int64_t dim) {
+Tensor trapezoid(const Tensor& y, const Tensor& x, int64_t dim) {
     dim = maybe_wrap_dim(dim, y);
     // asking for the integral with zero samples is a bit nonsensical,
     // but we'll return "0" to match numpy behavior.
     if (y.size(dim) == 0) {
         return zeros_like_except(y, dim);
     }
+    TORCH_CHECK(y.scalar_type() != kBool && x.scalar_type() != kBool, "trapezoid: received a bool input for `x` or `y`, but bool is not supported");
     Tensor x_viewed;
     if (x.dim() == 1) {
-        TORCH_CHECK(x.size(0) == y.size(dim), "trapz: There must be one `x` value for each sample point");
+        // This step takes 'x' with dimension (n,), and returns 'x_viewed' with
+        // dimension (1,1,...,n,...,1,1) based on dim and y.dim() so that 'x'
+        // can be broadcast later to match 'y'.
+        // Note: this behavior differs from numpy, which broadcasts 'dx';
+        // here, 'x' is broadcast to match 'y' instead.
+        TORCH_CHECK(x.size(0) == y.size(dim), "trapezoid: There must be one `x` value for each sample point");
         DimVector sizes(y.dim(), 1);
         sizes[dim] = x.size(0);
         x_viewed = x.view(sizes);
     } else {
         x_viewed = x;
     }
+    // Note: the .slice operations below reduce the size of dimension 'dim' by 1;
+    // the sizes of the other dimensions are untouched.
     Tensor x_left = x_viewed.slice(dim, 0, -1);
     Tensor x_right = x_viewed.slice(dim, 1);
 
     Tensor dx = x_right - x_left;
-    return do_trapz(y, dx, dim);
+    return do_trapezoid(y, dx, dim);
 }
 
-Tensor trapz(const Tensor& y, double dx, int64_t dim) {
+Tensor trapezoid(const Tensor& y, const Scalar& dx, int64_t dim) {
     // see above
     if (y.size(dim) == 0) {
         return zeros_like_except(y, dim);
     }
-    return do_trapz(y, dx, dim);
+    TORCH_CHECK(y.scalar_type() != kBool, "trapezoid: received a bool input for `y`, but bool is not supported");
+    TORCH_CHECK(!(dx.isComplex() || dx.isBoolean()), "trapezoid: currently, `dx` must be a real number");
+    return do_trapezoid(y, dx.toDouble(), dim);
+}
+
+Tensor trapz(const Tensor& y, const Tensor& x, int64_t dim) {
+    return at::native::trapezoid(y, x, dim);
+}
+
+Tensor trapz(const Tensor& y, double dx, int64_t dim) {
+    return at::native::trapezoid(y, dx, dim);
 }
 
 }} // namespace at::native
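
As a sanity check on the constant-`dx` simplification used by `do_trapezoid` above (`dx * [sum(y) - (y_first + y_last)/2]`), a small Python verification, assuming a build with this patch:

```python
import torch

y = torch.randn(5)
dx = 0.5

# Direct trapezoidal sum over each adjacent pair of samples.
direct = ((y[:-1] + y[1:]) * dx).sum() / 2

# Closed form used by do_trapezoid for constant spacing.
simplified = (y.sum() - (y[0] + y[-1]) * 0.5) * dx

print(torch.allclose(direct, simplified))                 # True
print(torch.allclose(direct, torch.trapezoid(y, dx=dx)))  # True
```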
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4d7e76b..92d5dd9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4289,6 +4289,10 @@
   dispatch:
     CompositeExplicitAutograd: rot90
 
+- func: trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+
+- func: trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor
+
 - func: trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
 
 - func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 1612717..8020713 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -571,6 +571,7 @@
     symeig
     lobpcg
     trapz
+    trapezoid
     triangular_solve
     vdot
 
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 00f56b6..ba374cf 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2789,15 +2789,6 @@
                               lambda a, b: torch.cat((a, b)),
                               True, f_args_variable, f_args_tensor, check_forward_ad=True)
 
-    def test_trapz(self):
-        f_args_variable = (torch.randn(2, 3, dtype=torch.double, requires_grad=True),
-                           torch.tensor([[1.0, 2.0, 5.5], [2.3, 0.5, 6.2]], dtype=torch.double, requires_grad=True))
-        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
-        run_functional_checks(self, "test_trapz", "trapz",
-                              lambda y, x: torch.trapz(y, x),
-                              True, f_args_variable, f_args_tensor)
-
-
     def test_var_mean_differentiable(self):
         dim = [2, 4]
         keepdim = False
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 6c7fcef..d916a5f 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -2536,17 +2536,17 @@
             _test_atan2(1, -1, math.pi / -4 , device, dtype)
             _test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype)
 
-    def test_trapz(self, device):
+    def test_trapezoid(self, device):
         def test_dx(sizes, dim, dx, device):
             t = torch.randn(sizes, device=device)
-            actual = torch.trapz(t, dx=dx, dim=dim)
+            actual = torch.trapezoid(t, dx=dx, dim=dim)
             expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim)
             self.assertEqual(expected.shape, actual.shape)
             self.assertEqual(expected, actual, exact_dtype=False)
 
         def test_x(sizes, dim, x, device):
             t = torch.randn(sizes, device=device)
-            actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim)
+            actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim)
             expected = np.trapz(t.cpu().numpy(), x=x, axis=dim)
             self.assertEqual(expected.shape, actual.shape)
             self.assertEqual(expected, actual.cpu(), exact_dtype=False)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 37326aa..f90c478 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -101,7 +101,7 @@
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
-    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'conj_physical_', '_neg_view'
+    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'conj_physical_', '_neg_view'
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 363bb47..01d342c 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -10724,54 +10724,123 @@
             [2, 2],
             [2, 3],
             [3, 3]])
+
+""")
+
+add_docstr(torch.trapezoid,
+           r"""
+trapezoid(y, x=None, *, dx=None, dim=-1) -> Tensor
+
+Computes the `trapezoidal rule <https://en.wikipedia.org/wiki/Trapezoidal_rule>`_ along
+:attr:`dim`. By default the spacing between elements is assumed to be 1, but
+:attr:`dx` can be used to specify a different constant spacing, and :attr:`x` can be
+used to specify arbitrary spacing along :attr:`dim`.
+
+Assuming :attr:`y` is a one-dimensional tensor with elements :math:`\{y_0, y_1, ..., y_{n-1}\}`,
+the default computation is
+
+.. math::
+    \begin{aligned}
+        \sum_{i = 1}^{n-1} \frac{1}{2} (y_i + y_{i-1})
+    \end{aligned}
+
+When :attr:`dx` is specified the computation becomes
+
+.. math::
+    \begin{aligned}
+        \sum_{i = 1}^{n-1} \frac{\Delta x}{2} (y_i + y_{i-1})
+    \end{aligned}
+
+effectively multiplying the result by :attr:`dx`. When :attr:`x` is specified,
+assuming :attr:`x` is also a one-dimensional tensor with
+elements :math:`\{x_0, x_1, ..., x_{n-1}\}`, the computation becomes
+
+.. math::
+    \begin{aligned}
+        \sum_{i = 1}^{n-1} \frac{(x_i - x_{i-1})}{2} (y_i + y_{i-1})
+    \end{aligned}
+
+When :attr:`y` has two or more dimensions, this computation is performed independently
+along dimension :attr:`dim`. If :attr:`x` is also specified and is one-dimensional,
+then its values define the spacing for every computation along :attr:`dim`.
+If :attr:`x` is also specified and is not one-dimensional, then it is broadcast to
+the shape of :attr:`y` and the corresponding sizes are used for each computation.
+See the examples below for details.
+
+.. note::
+    The trapezoidal rule is a technique for approximating the definite integral of a function
+    by averaging its left and right Riemann sums. The approximation becomes more accurate as
+    the resolution of the partition increases.
+
+Arguments:
+    y (Tensor): Values to use when computing the trapezoidal rule.
+    x (Tensor): If specified, defines the spacing between values as described above.
+
+Keyword arguments:
+    dx (float): constant spacing between values. If neither :attr:`x` nor :attr:`dx`
+        is specified then this defaults to 1. Effectively multiplies the result by its value.
+    dim (int): The dimension along which to compute the trapezoidal rule.
+        The last (inner-most) dimension by default.
+
+Examples::
+
+    >>> # Computes the trapezoidal rule in 1D, spacing is implicitly 1
+    >>> y = torch.tensor([1, 5, 10])
+    >>> torch.trapezoid(y)
+    tensor(10.5)
+
+    >>> # Computes the same trapezoidal rule directly to verify
+    >>> (1 + 10 + 10) / 2
+    10.5
+
+    >>> # Computes the trapezoidal rule in 1D with constant spacing of 2
+    >>> # NOTE: the result is the same as before, but multiplied by 2
+    >>> torch.trapezoid(y, dx=2)
+    tensor(21.0)
+
+    >>> # Computes the trapezoidal rule in 1D with arbitrary spacing
+    >>> x = torch.tensor([1, 3, 6])
+    >>> torch.trapezoid(y, x)
+    tensor(28.5)
+
+    >>> # Computes the same trapezoidal rule directly to verify
+    >>> ((3 - 1) * (1 + 5) + (6 - 3) * (5 + 10)) / 2
+    28.5
+
+    >>> # Computes the trapezoidal rule for each row of a 3x3 matrix
+    >>> y = torch.arange(9).reshape(3, 3)
+    >>> y
+    tensor([[0, 1, 2],
+            [3, 4, 5],
+            [6, 7, 8]])
+    >>> torch.trapezoid(y)
+    tensor([ 2., 8., 14.])
+
+    >>> # Computes the trapezoidal rule for each column of the matrix
+    >>> torch.trapezoid(y, dim=0)
+    tensor([ 6., 8., 10.])
+
+    >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix
+    >>> #   with the same arbitrary spacing
+    >>> y = torch.ones(3, 3)
+    >>> x = torch.tensor([1, 3, 6])
+    >>> torch.trapezoid(y, x)
+    tensor([5., 5., 5.])
+
+    >>> # Computes the trapezoidal rule for each row of a 3x3 ones matrix
+    >>> #   with different arbitrary spacing per row
+    >>> y = torch.ones(3, 3)
+    >>> x = torch.tensor([[1, 2, 3], [1, 3, 5], [1, 4, 7]])
+    >>> torch.trapezoid(y, x)
+    tensor([2., 4., 6.])
 """)
 
 add_docstr(torch.trapz,
            r"""
 trapz(y, x, *, dim=-1) -> Tensor
 
-Estimate :math:`\int y\,dx` along `dim`, using the trapezoid rule.
+Alias for :func:`torch.trapezoid`.
 
-Arguments:
-    y (Tensor): The values of the function to integrate
-    x (Tensor): The points at which the function `y` is sampled.
-        If `x` is not in ascending order, intervals on which it is decreasing
-        contribute negatively to the estimated integral (i.e., the convention
-        :math:`\int_a^b f = -\int_b^a f` is followed).
-    dim (int): The dimension along which to integrate.
-        By default, use the last dimension.
-
-Returns:
-    A Tensor with the same shape as the input, except with `dim` removed.
-    Each element of the returned tensor represents the estimated integral
-    :math:`\int y\,dx` along `dim`.
-
-Example::
-
-    >>> y = torch.randn((2, 3))
-    >>> y
-    tensor([[-2.1156,  0.6857, -0.2700],
-            [-1.2145,  0.5540,  2.0431]])
-    >>> x = torch.tensor([[1, 3, 4], [1, 2, 3]])
-    >>> torch.trapz(y, x)
-    tensor([-1.2220,  0.9683])
-
-.. function:: trapz(y, *, dx=1, dim=-1) -> Tensor
-
-As above, but the sample points are spaced uniformly at a distance of `dx`.
-
-Arguments:
-    y (Tensor): The values of the function to integrate
-
-Keyword args:
-    dx (float): The distance between points at which `y` is sampled.
-    dim (int): The dimension along which to integrate.
-        By default, use the last dimension.
-
-Returns:
-    A Tensor with the same shape as the input, except with `dim` removed.
-    Each element of the returned tensor represents the estimated integral
-    :math:`\int y\,dx` along `dim`.
 """)
 
 add_docstr(torch.repeat_interleave,
diff --git a/torch/overrides.py b/torch/overrides.py
index 13b6767..0286801 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -944,6 +944,7 @@
         torch.trace: lambda input: -1,
         torch.transpose: lambda input, dim0, dim1: -1,
         torch.trapz: lambda y, x=None, dim=-1: -1,
+        torch.trapezoid: lambda y, x=None, dx=None, dim=-1: -1,
         torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
         torch.tril: lambda input, diagonal=0, out=None: -1,
         torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index e13fc94..6134c67 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2477,6 +2477,32 @@
 
     return list(generator())
 
+def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs):
+    y_shape_x_shape_and_kwargs = [
+        ((2, 3), (2, 3), {}),
+        ((2, 3), (2, 3), {'dim': 1}),
+        ((6,), (6,), {}),
+        ((6,), None, {}),
+        # When 'trapezoid' is called with an empty input, it does not produce an output with requires_grad set.
+        # See issue #61619.
+        # ((6,0), (6,0), {}),
+        ((2, 3), (1, 3), {}),
+        ((3, 3), (3, 3), {}),
+        ((3, 3), (3, 3), {'dim': -2}),
+        ((5,), None, {'dx': 2.0}),
+        ((2, 2), None, {'dx': 3.0})
+    ]
+    samples = []
+    for y_shape, x_shape, kwarg in y_shape_x_shape_and_kwargs:
+        y_tensor = make_tensor(y_shape, device, dtype, low=None, high=None,
+                               requires_grad=requires_grad)
+        if x_shape is not None:
+            x_tensor = make_tensor(x_shape, device, dtype, low=None, high=None,
+                                   requires_grad=requires_grad)
+            samples.append(SampleInput(y_tensor, args=(x_tensor,), kwargs=kwarg))
+        else:
+            samples.append(SampleInput(y_tensor, kwargs=kwarg))
+    return samples
 
 def sample_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
     shapes_and_axes = [
@@ -7427,6 +7453,14 @@
                   dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
                   supports_out=False,
                   sample_inputs_func=sample_repeat_tile),
+    OpInfo('trapz',  # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid'
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           sample_inputs_func=sample_trapezoid),
+    OpInfo('trapezoid',
+           dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
+           supports_out=False,
+           sample_inputs_func=sample_trapezoid),
     OpInfo('unsqueeze',
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,