[NestedTensor] Integrate sum along the jagged dimension into NestedTensor (#130425)

Summary: Modify the existing `sum` operator in PyTorch, invoked by `torch.sum`, to allow for reductions along the ragged dimension of a nested tensor. This diff enables PyTorch users to invoke `torch.sum` on a nested tensor with `dim=1`, where `ragged_idx=1`.
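
For illustration, a minimal sketch of the newly supported call (behavior as described in this diff; the shapes below are arbitrary examples):

```
import torch

# Nested tensor of logical shape (B=3, j0, M=4) with ragged_idx == 1.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)],
    layout=torch.jagged,
)

# Newly supported: reduce only along the ragged dimension.
out = torch.sum(nt, dim=1)  # dense tensor of shape (3, 4)

# Each output row is the sum over that component's ragged entries.
expected = torch.stack([t.sum(dim=0) for t in nt.unbind()])
assert torch.allclose(out, expected)
```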

Functions modified in `caffe2/torch/nested/_internal/ops.py`:
- `sum_dim_IntList()`: The function requires that `ragged_idx == 1`; when `dim` (the dimension over which we reduce) contains the ragged dimension (`dim == 1`) but not the batch dimension, this diff uses the padding-based approach benchmarked in D58423489. Specifically, it pads the nested tensor, e.g. of logical shape `(B, *, M)`, using [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), then reduces across the `*` dimension (`dim == 1`) to produce a `(B, M)` output tensor; see the sketch after this list.
- `_wrap_jagged_dims()`: This diff adds special handling to allow for the case where `dim` contains `1` (the ragged dimension) but not `0` (the batch dimension), while continuing to disallow the reverse. The function now takes a `ragged_idx` argument and, rather than raising on a batch/ragged mismatch itself, returns the converted inner dims along with three flags indicating whether the reduction touches the batch, ragged, or non-batch dimensions; `sum_dim_IntList()` then uses these flags to decide which reductions are valid, making it clearer which conditions apply to which cases. See the sketch after this list.
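
As referenced in the bullets above, a rough sketch of the padding-based reduction, using the public `torch.nested.to_padded_tensor` as a stand-in for the internal `_jagged_to_padded_dense_forward` op (zero padding does not change the sum):

```
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(5, 4)], layout=torch.jagged
)  # logical shape (B=2, *, M=4)

# Pad the ragged dimension with zeros, then reduce across it (dim == 1).
padded = torch.nested.to_padded_tensor(nt, padding=0.0)  # (B, max_seqlen, M)
out = padded.sum(dim=1)  # (B, M); the zero padding contributes nothing
```

And, based on the diff below, the flags now returned by `_wrap_jagged_dims()` for a few representative reductions (a sketch of the contract, assuming `ndim == 4` and `ragged_idx == 1`):

```
# dims      operate_on_batch  operate_on_ragged  operate_on_non_batch
# (1,)      False             True               False
# (0, 1)    True              True               False
# (2, 3)    False             False              True
# (1, 2)    False             True               True   -> rejected by sum_dim_IntList()
```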

Functions modified in `caffe2/test/test_nestedtensor.py`:
- `test_sum_int_DimList()`: renamed to `test_sum_dim()` and extended to verify both output shapes and values, parametrized over dtype, `keepdim`, and gradient requirements, including the case where `dim` contains `1` but not `0`.
- New unit tests: `test_sum_dim_reduce_ragged_dim_1()` verifies the accuracy and feasibility of reducing along the jagged dimension of a nested tensor, while `test_sum_dim_reduce_ragged_and_non_batch()`, `test_sum_dim_reduce_batch_and_non_batch()`, `test_sum_dim_reduce_batch()`, `test_sum_dim_ragged_dim_not_1()`, and `test_sum_dim_with_lengths()` verify that unsupported reductions raise the expected errors.

Notes:
- This diff solely adds functionality for the case in which we reduce only along the ragged dimension. Cases in which we reduce along both the ragged and a non-batch dimension, like `dim == (1, 2)`, remain unsupported, as this set of diffs focuses primarily on the former; see the sketch after these notes.
- The `sum` operator is the only operator that uses `_wrap_jagged_dims()`; all other operators use `_wrap_jagged_dim()`. I would like to look into why that is and whether the two can be consolidated.
- I also modified some of the comments in the `sum` function and in the unit tests for clarity.
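
As noted above, reducing over the ragged dimension together with a non-batch dimension is rejected; a minimal sketch of the expected failure (error message taken from the diff below):

```
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(5, 4)], layout=torch.jagged
)  # logical shape (B, j0, 4)

try:
    torch.sum(nt, dim=(1, 2))  # ragged + non-batch: not supported
except RuntimeError as e:
    # Expected: "sum(): not supported along a ragged and non-batch dimension for NestedTensor"
    print(e)
```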

Test Plan:
Verify that the renamed (`test_sum_dim`) and new (`test_sum_dim_reduce_ragged_dim_1`, etc.) unit tests pass via the following command:

```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_sum_dim
```
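
In an open-source checkout, a roughly equivalent invocation (assuming pytest is available) would be:

```
pytest test/test_nestedtensor.py -k test_sum_dim
```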

Differential Revision: D59571209

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130425
Approved by: https://github.com/davidberard98
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index f593166..62fa67f 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -4009,34 +4009,32 @@
 
         self.assertEqual(res_dense, res_nt.values())
 
+    @dtypes(torch.float32)
     @parametrize("keepdim", [False, True])
-    def test_sum_int_DimList(self, device, keepdim):
-        # (B, j0, 3, 4)
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_sum_dim(
+        self, device, dtype, keepdim, requires_grad, components_require_grad
+    ):
         ts = self._get_list_for_jagged_tensor(
             ((2, 3, 4), 3, 4), device=device, requires_grad=True
-        )
+        )  # (B, j0, 3, 4)
 
-        # Check shape correctness
+        # verify correctness of shapes (assuming that ragged_idx == 1)
         reduce_dims = (
-            # dims, expected shape, expected keepdim shape
-            # j0 is represented as None
-            ((0, 1), (3, 4), (1, 1, 3, 4)),
-            ((1, 2), None, None),
-            ((2, 3), (3, None), (3, None, 1, 1)),
-            ((0, 1, 3), (3,), (1, 1, 3, 1)),
-            ((0, 1, 2), (4,), (1, 1, 1, 4)),
-            ((0, 1, 2, 3), (), (1, 1, 1, 1)),
-        )
-        for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
-            if (0 in rd) ^ (1 in rd):
-                with self.assertRaisesRegex(
-                    RuntimeError,
-                    "applying over the ragged dimension, but not the batch dimension",
-                ):
-                    nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
-                    out = torch.sum(nt, dim=rd, keepdim=keepdim)
-                continue
-
+            ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
+            ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
+            ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
+            ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
+            (
+                (0, 1, 2, 3),
+                (),
+                (1, 1, 1, 1),
+                (0, 1, 2),
+            ),  # batch, ragged, non-batch, non-batch
+            ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
+        )  # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
+        for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
             nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
             out = torch.sum(nt, dim=rd, keepdim=keepdim)
             ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
@@ -4047,21 +4045,289 @@
                 else:
                     self.assertTrue(isinstance(o, torch.SymInt))
 
-        # Check values correctness
-        # raggedness not reduced
-        nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
-        out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
-        out_ref = torch.sum(nt.values(), dim=(1, 2))
-        self.assertIsInstance(out, NestedTensor)
-        # flatten to avoid having to replicate unsqueeze logic depending on keepdim
-        self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
+        # verify correctness of values
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False,
+            include_requires_grad=components_require_grad,
+        )
+        for tensor_list, reduce_dim_tuple in itertools.product(
+            tensor_lists, reduce_dims
+        ):
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
 
-        # raggedness reduced away
-        nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
-        out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
-        out_ref = torch.sum(nt.values(), dim=(0,))
-        self.assertNotIsInstance(out, NestedTensor)
-        self.assertTrue(torch.allclose(out, out_ref))
+            reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
+
+            if nt.dim() > reduce_dim[-1]:
+                if nt._ragged_idx in reduce_dim:  # raggedness reduced away
+                    out_actual = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+                    out_expected = torch.sum(
+                        nt.values(), dim=reduce_dim_expected, keepdim=keepdim
+                    )
+                    self.assertTrue(torch.allclose(out_actual, out_expected))
+                else:  # raggedness preserved
+                    out_actual = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+                    out_expected = torch.sum(nt.values(), dim=reduce_dim_expected)
+                    self.assertTrue(
+                        torch.allclose(
+                            out_actual.values().view(-1), out_expected.view(-1)
+                        )
+                    )
+
+    @dtypes(torch.float32)
+    @parametrize("keepdim", [False, True])
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_sum_dim_reduce_ragged_dim_1(
+        self,
+        device,
+        dtype,
+        keepdim,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Sum on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False, include_requires_grad=components_require_grad
+        )
+        reduce_dim = (1,)  # ragged
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            out_actual = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+            out_expected = torch.cat(
+                [
+                    torch.sum(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
+                    for t in nt.unbind()
+                ]
+            )
+
+            self.assertFalse(
+                out_actual.is_nested,
+                "sum(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
+            )  # output is a dense tensor
+            self.assertTrue(torch.allclose(out_actual, out_expected))
+
+    @dtypes(torch.float32)
+    @parametrize("keepdim", [False, True])
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_sum_dim_reduce_ragged_and_non_batch(
+        self,
+        device,
+        dtype,
+        keepdim,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False, include_requires_grad=components_require_grad
+        )
+        reduce_dims = (
+            (1, 2),  # ragged, non-batch
+            (1, 3),  # ragged, non-batch
+        )
+
+        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if nt.dim() > reduce_dim[-1]:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported along a ragged and non-batch dimension for NestedTensor",
+                ):
+                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+
+    @dtypes(torch.float32)
+    @parametrize("keepdim", [False, True])
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_sum_dim_reduce_batch_and_non_batch(
+        self,
+        device,
+        dtype,
+        keepdim,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False, include_requires_grad=components_require_grad
+        )
+        reduce_dims = (
+            (0, 2),  # batch, non-batch
+            (0, 3),  # batch, non-batch
+        )
+
+        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if nt.dim() > reduce_dim[-1]:
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported along only the batch dimension for NestedTensor",
+                ):
+                    out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+
+    @dtypes(torch.float32)
+    @parametrize("keepdim", [False, True])
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_sum_dim_reduce_batch(
+        self,
+        device,
+        dtype,
+        keepdim,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Sum on NestedTensor fails when trying to reduce across batch dimension
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False, include_requires_grad=components_require_grad
+        )
+        reduce_dim = (0,)  # batch
+
+        for tensor_list in tensor_lists:
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "not supported along only the batch dimension for NestedTensor",
+            ):
+                out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+
+    @dtypes(torch.float32)
+    @parametrize("keepdim", [False, True])
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_sum_dim_ragged_dim_not_1(
+        self,
+        device,
+        dtype,
+        keepdim,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Sum on NestedTensor fails when trying to reduce a nested tensor with ragged_idx != 1
+        """
+        tensor_lists = self._get_example_tensor_lists(
+            include_list_of_lists=False, include_requires_grad=components_require_grad
+        )
+        reduce_dims = (
+            (1,),
+            (2, 3),
+        )
+
+        for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
+            nt = torch.nested.nested_tensor(
+                tensor_list,
+                device=device,
+                dtype=dtype,
+                layout=torch.jagged,
+                requires_grad=requires_grad,
+            )
+
+            if (
+                nt.dim() > 2 and nt.dim() > reduce_dim[-1]
+            ):  # ensure that we can transpose dims 1 and 2
+                nt_transposed = nt.transpose(1, 2)
+
+                with self.assertRaisesRegex(
+                    RuntimeError,
+                    "not supported when ragged_idx != 1 for NestedTensor",
+                ):
+                    out = torch.sum(nt_transposed, dim=reduce_dim, keepdim=keepdim)
+
+    @dtypes(torch.float32)
+    @parametrize("keepdim", [False, True])
+    @parametrize("requires_grad", [False, True])
+    @parametrize("components_require_grad", [False, True])
+    def test_sum_dim_with_lengths(
+        self,
+        device,
+        dtype,
+        keepdim,
+        requires_grad,
+        components_require_grad,
+    ):
+        """
+        Sum on NestedTensor fails when trying to reduce a nested tensor with lengths,
+        i.e. a nested tensor with holes, when reducing on the ragged dimension
+        """
+        reduce_dims = (
+            (1,),
+            (2,),
+            (2, 3),
+        )
+
+        lengths = torch.randint(5, 10, (20,), device=device)
+        offsets = torch.zeros((21,), device=device, dtype=torch.int)
+        torch.cumsum(lengths, dim=0, out=offsets[1:])
+
+        values = torch.randn(
+            (offsets[-1].item(), 20),
+            device=device,
+            dtype=dtype,
+            requires_grad=requires_grad,
+        )
+
+        nt_with_holes = torch.nested.nested_tensor_from_jagged(
+            values,
+            offsets,
+            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
+        )
+
+        for reduce_dim in reduce_dims:
+            if nt_with_holes.dim() > reduce_dim[-1]:
+                if nt_with_holes._ragged_idx in reduce_dim:
+                    with self.assertRaisesRegex(
+                        RuntimeError,
+                        "not supported where lengths is not None "
+                        + "if reducing across the ragged dimension for NestedTensor",
+                    ):
+                        out = torch.sum(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
+                else:
+                    out = torch.sum(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
 
     @dtypes(torch.float, torch.double, torch.half)
     @parametrize("requires_grad", [False, True])
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 6193f65..1fdd780 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -36,27 +36,23 @@
     return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
 
 
-def _wrap_jagged_dims(ndim, dims, op_name):
-    # ex: (2, 3, 4) -> (1, 2, 3)
-    # ex: (0, 1, 4) -> (0, 3)
+def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
     from torch._prims_common import canonicalize_dims
 
-    wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
-    # This logic needs to be done after we canonicalize dims but before we
-    # map to inner dims so we can print a nicer error message.
-    zero_in_dims = 0 in wrapped_dims
-    one_in_dims = 1 in wrapped_dims
-    if zero_in_dims ^ one_in_dims:
-        apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
-        raise RuntimeError(
-            f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
-            " dimension is not supported for NestedTensor"
-        )
-    return (
-        tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
-        zero_in_dims,
+    wrapped_dims = [
+        canonicalize_dims(ndim, d) for d in dims
+    ]  # convert all indices to non-negative values
+
+    operate_on_batch = 0 in wrapped_dims
+    operate_on_ragged = ragged_idx in wrapped_dims
+    operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)
+
+    outer_to_inner_dim = tuple(
+        _outer_to_inner_dim(ndim, d) for d in wrapped_dims if d != 0
     )
 
+    return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
+
 
 def check_schema(schema_str: str, func, *args, **kwargs) -> None:
     named_arg_types = schema_str.split(", ")
@@ -847,28 +843,82 @@
 
 
 @register_jagged_func(
-    torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
+    torch.ops.aten.sum.dim_IntList,
+    "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
 )
 def sum_dim_IntList(func, *args, **kwargs):
-    # sum_dim_IntList can produce a NT or a T depending on whether the ragged dims
-    # are reduced away.
+    """
+    Performs a sum along the provided tensor dimension.
+    Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
+    """
     _, new_kwargs = normalize_function(
         func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
     )
     inp = new_kwargs.pop("input")
-    assert inp._ragged_idx == 1
-    new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
-        inp.dim(), new_kwargs["dim"], "sum"
+
+    (
+        new_kwargs["dim"],
+        reduce_on_batch,
+        reduce_on_ragged,
+        reduce_on_non_batch,  # noqa: UFMT
+    ) = _wrap_jagged_dims(
+        inp.dim(),
+        new_kwargs["dim"],
+        "sum",
+        inp._ragged_idx,
     )
 
-    if not ragged_reduced_away:
-        return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
-    else:
-        # Don't wrap because we reduced away the raggedness
-        out = func(inp._values, **new_kwargs)
+    if inp._ragged_idx != 1:
+        raise RuntimeError("sum(): not supported when ragged_idx != 1 for NestedTensor")
+
+    if reduce_on_ragged and inp._lengths is not None:
+        raise RuntimeError(
+            "sum(): not supported where lengths is not None "
+            + "if reducing across the ragged dimension for NestedTensor"
+        )
+
+    if reduce_on_ragged:  # raggedness reduced away --> return dense tensor
+        if (
+            reduce_on_batch
+        ):  # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
+            out = func(
+                inp._values, **new_kwargs
+            )  # no need to read offsets --> apply sum directly on values
+        else:
+            if (
+                reduce_on_non_batch
+            ):  # invalid reduction cases: (ragged, non-batch), etc.
+                raise RuntimeError(
+                    "sum(): not supported along a ragged and non-batch dimension for NestedTensor"
+                )
+            # reduction cases: (ragged)
+            out = torch.sum(
+                torch.ops.aten._jagged_to_padded_dense_forward(
+                    inp._values.view(*inp._values.shape[: inp._ragged_idx], -1),
+                    [inp._offsets],
+                    max_lengths=[inp._max_seqlen],
+                ).view(
+                    *inp.shape[: inp._ragged_idx],
+                    inp._max_seqlen,
+                    *inp.shape[inp._ragged_idx + 1 :],
+                ),
+                dim=inp._ragged_idx,
+            )  # need to read offsets --> pad jagged dimension and apply sum
+
         if new_kwargs["keepdim"]:
             out = out.unsqueeze(0)
         return out
+    else:  # raggedness preserved --> return nested tensor
+        if (
+            reduce_on_batch
+        ):  # invalid reduction cases: (batch), (batch, non-batch), etc.
+            raise RuntimeError(
+                "sum(): not supported along only the batch dimension for NestedTensor"
+            )
+        # reduction cases: (non-batch), (non-batch, non-batch), etc.
+        return NestedTensor(
+            func(inp._values, **new_kwargs), **extract_kwargs(inp)
+        )  # apply sum directly on values
 
 
 @register_jagged_func(