[NestedTensor] Integrate sum along the jagged dimension into NestedTensor (#130425)
Summary: Modify the existing `sum` operator in PyTorch, invoked via `torch.sum`, to support reductions along the ragged dimension of a nested tensor. This diff enables PyTorch users to invoke `torch.sum` on a jagged-layout nested tensor with `dim=1` when `ragged_idx=1`.
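A minimal usage sketch of the newly supported call (assuming this diff is applied; the tensor sizes are illustrative):
```
import torch

# Jagged-layout nested tensor of logical shape (B=2, j0, M=5), ragged_idx == 1
nt = torch.nested.nested_tensor(
    [torch.randn(3, 5), torch.randn(4, 5)],
    layout=torch.jagged,
)

# Reducing over the ragged dimension now returns a dense tensor
out = torch.sum(nt, dim=1)
print(out.shape)  # torch.Size([2, 5])
```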
Functions modified in `caffe2/torch/nested/_internal/ops.py`:
- `sum_dim_IntList()`: The function assumes that `ragged_idx=1`; when `dim=1` as well, where `dim` is the dimension along which we reduce, this diff applies the padding-based approach benchmarked in D58423489 (see the sketch after this list). Specifically, this diff pads the nested tensor, e.g. one of logical shape `(B, *, M)`, using [`torch.ops.aten._jagged_to_padded_dense_forward`](https://www.internalfb.com/code/fbsource/[92c2a067ab04e3eebc999254fed4ae2fbea6def3]/fbcode/deeplearning/fbgemm/fbgemm_gpu/fb/inductor_lowerings/elementwise_ops.py?lines=26), then reduces across the `*` dimension (`dim == 1`) to produce a `(B, M)` output tensor.
- `_wrap_jagged_dims()`: This diff adds special handling to allow the case where `dim` contains `1` but not `0`, while continuing to disallow the case where `dim` contains `0` but not `1`. While reworking this function, I added a helper, `_get_condition_for_invalid_jagged_reductions()`, which makes it clearer which conditions apply to which operators. Specifically, operators that support jagged reductions are listed at the top of the file in `SUPPORTED_JAGGED_REDUCTIONS` and have a different set of conditions to check, since reducing along `dim == 1` without `dim == 0` is now possible.
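For intuition, the pad-then-reduce path in `sum_dim_IntList()` is roughly equivalent to the following sketch using the public `torch.nested.to_padded_tensor` API; the actual implementation calls `torch.ops.aten._jagged_to_padded_dense_forward` on the flattened values directly, so this is an illustration rather than the code added by this diff:
```
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 5), torch.randn(4, 5)],  # logical shape (B, *, M)
    layout=torch.jagged,
)

# Pad the ragged dimension with zeros to a dense (B, max_seqlen, M) tensor,
# then reduce across the padded dimension; zero padding does not change the sum.
padded = torch.nested.to_padded_tensor(nt, 0.0)  # shape (2, 4, 5)
out_ref = padded.sum(dim=1)                      # shape (2, 5)
```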
Functions modified in `caffe2/test/test_nestedtensor.py`:
- `test_sum_int_DimList()`: This diff adds special handling in the `sum` unit test to allow for the case where `dim` contains `1` and not `0`, but to continue disallowing the case where `dim` contains `0` and not `1`.
- `test_sum_int_DimList_ragged_dim_1()`: This diff adds a new unit test which verifies the accuracy and feasibility of reducing along the jagged dimension of a nested tensor.
Notes:
- This diff solely adds functionality for the case in which we reduce only along the ragged dimension. Cases in which we reduce along both the ragged and a non-batch dimension, like `dim == (1, 2)`, remain disallowed, as this set of diffs focuses primarily on the former (see the error sketch after these notes).
- The `sum` operator is the only operator which uses the function `_wrap_jagged_dims()`; all other operators use `_wrap_jagged_dim()`. I would like to later look into why this is the case and if we can consolidate this!
- I modified some of the comments in the `sum` function as well as the unit tests for more clarity.
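For example, combining the ragged dimension with a non-batch dimension is rejected; a small sketch of the expected failure, based on the error message added in `sum_dim_IntList()`:
```
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(3, 5), torch.randn(4, 5)],
    layout=torch.jagged,
)

try:
    torch.sum(nt, dim=(1, 2))  # ragged + non-batch: not supported by this diff
except RuntimeError as e:
    # sum(): not supported along a ragged and non-batch dimension for NestedTensor
    print(e)
```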
Test Plan:
Verify that existing (`test_sum_int_DimList`) and new (`test_sum_int_DimList_ragged_dim_1`) unit tests pass via the following command:
```
buck2 run mode/{opt,inplace} //caffe2/test:nested -- --regex test_sum_int_DimList
```
Differential Revision: D59571209
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130425
Approved by: https://github.com/davidberard98
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index f593166..62fa67f 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -4009,34 +4009,32 @@
self.assertEqual(res_dense, res_nt.values())
+ @dtypes(torch.float32)
@parametrize("keepdim", [False, True])
- def test_sum_int_DimList(self, device, keepdim):
- # (B, j0, 3, 4)
+ @parametrize("requires_grad", [False, True])
+ @parametrize("components_require_grad", [False, True])
+ def test_sum_dim(
+ self, device, dtype, keepdim, requires_grad, components_require_grad
+ ):
ts = self._get_list_for_jagged_tensor(
((2, 3, 4), 3, 4), device=device, requires_grad=True
- )
+ ) # (B, j0, 3, 4)
- # Check shape correctness
+ # verify correctness of shapes (assuming that ragged_idx == 1)
reduce_dims = (
- # dims, expected shape, expected keepdim shape
- # j0 is represented as None
- ((0, 1), (3, 4), (1, 1, 3, 4)),
- ((1, 2), None, None),
- ((2, 3), (3, None), (3, None, 1, 1)),
- ((0, 1, 3), (3,), (1, 1, 3, 1)),
- ((0, 1, 2), (4,), (1, 1, 1, 4)),
- ((0, 1, 2, 3), (), (1, 1, 1, 1)),
- )
- for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
- if (0 in rd) ^ (1 in rd):
- with self.assertRaisesRegex(
- RuntimeError,
- "applying over the ragged dimension, but not the batch dimension",
- ):
- nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
- out = torch.sum(nt, dim=rd, keepdim=keepdim)
- continue
-
+ ((0, 1), (3, 4), (1, 1, 3, 4), (0,)), # batch, ragged
+ ((2, 3), (3, None), (3, None, 1, 1), (1, 2)), # non-batch, non-batch
+ ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)), # batch, ragged, non-batch
+ ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)), # batch, ragged, non-batch
+ (
+ (0, 1, 2, 3),
+ (),
+ (1, 1, 1, 1),
+ (0, 1, 2),
+ ), # batch, ragged, non-batch, non-batch
+ ((2,), (3, None, 4), (3, None, 1, 4), (1,)), # non-batch
+ ) # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
+ for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
out = torch.sum(nt, dim=rd, keepdim=keepdim)
ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
@@ -4047,21 +4045,289 @@
else:
self.assertTrue(isinstance(o, torch.SymInt))
- # Check values correctness
- # raggedness not reduced
- nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
- out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
- out_ref = torch.sum(nt.values(), dim=(1, 2))
- self.assertIsInstance(out, NestedTensor)
- # flatten to avoid having to replicate unsqueeze logic depending on keepdim
- self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
+ # verify correctness of values
+ tensor_lists = self._get_example_tensor_lists(
+ include_list_of_lists=False,
+ include_requires_grad=components_require_grad,
+ )
+ for tensor_list, reduce_dim_tuple in itertools.product(
+ tensor_lists, reduce_dims
+ ):
+ nt = torch.nested.nested_tensor(
+ tensor_list,
+ device=device,
+ dtype=dtype,
+ layout=torch.jagged,
+ requires_grad=requires_grad,
+ )
- # raggedness reduced away
- nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
- out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
- out_ref = torch.sum(nt.values(), dim=(0,))
- self.assertNotIsInstance(out, NestedTensor)
- self.assertTrue(torch.allclose(out, out_ref))
+ reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
+
+ if nt.dim() > reduce_dim[-1]:
+ if nt._ragged_idx in reduce_dim: # raggedness reduced away
+ out_actual = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+ out_expected = torch.sum(
+ nt.values(), dim=reduce_dim_expected, keepdim=keepdim
+ )
+ self.assertTrue(torch.allclose(out_actual, out_expected))
+ else: # raggedness preserved
+ out_actual = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+ out_expected = torch.sum(nt.values(), dim=reduce_dim_expected)
+ self.assertTrue(
+ torch.allclose(
+ out_actual.values().view(-1), out_expected.view(-1)
+ )
+ )
+
+ @dtypes(torch.float32)
+ @parametrize("keepdim", [False, True])
+ @parametrize("requires_grad", [False, True])
+ @parametrize("components_require_grad", [False, True])
+ def test_sum_dim_reduce_ragged_dim_1(
+ self,
+ device,
+ dtype,
+ keepdim,
+ requires_grad,
+ components_require_grad,
+ ):
+ """
+ Sum on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1
+ """
+ tensor_lists = self._get_example_tensor_lists(
+ include_list_of_lists=False, include_requires_grad=components_require_grad
+ )
+ reduce_dim = (1,) # ragged
+
+ for tensor_list in tensor_lists:
+ nt = torch.nested.nested_tensor(
+ tensor_list,
+ device=device,
+ dtype=dtype,
+ layout=torch.jagged,
+ requires_grad=requires_grad,
+ )
+
+ out_actual = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+ out_expected = torch.cat(
+ [
+ torch.sum(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
+ for t in nt.unbind()
+ ]
+ )
+
+ self.assertFalse(
+ out_actual.is_nested,
+ "sum(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
+ ) # output is a dense tensor
+ self.assertTrue(torch.allclose(out_actual, out_expected))
+
+ @dtypes(torch.float32)
+ @parametrize("keepdim", [False, True])
+ @parametrize("requires_grad", [False, True])
+ @parametrize("components_require_grad", [False, True])
+ def test_sum_dim_reduce_ragged_and_non_batch(
+ self,
+ device,
+ dtype,
+ keepdim,
+ requires_grad,
+ components_require_grad,
+ ):
+ """
+ Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions
+ """
+ tensor_lists = self._get_example_tensor_lists(
+ include_list_of_lists=False, include_requires_grad=components_require_grad
+ )
+ reduce_dims = (
+ (1, 2), # ragged, non-batch
+ (1, 3), # ragged, non-batch
+ )
+
+ for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
+ nt = torch.nested.nested_tensor(
+ tensor_list,
+ device=device,
+ dtype=dtype,
+ layout=torch.jagged,
+ requires_grad=requires_grad,
+ )
+
+ if nt.dim() > reduce_dim[-1]:
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "not supported along a ragged and non-batch dimension for NestedTensor",
+ ):
+ out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+
+ @dtypes(torch.float32)
+ @parametrize("keepdim", [False, True])
+ @parametrize("requires_grad", [False, True])
+ @parametrize("components_require_grad", [False, True])
+ def test_sum_dim_reduce_batch_and_non_batch(
+ self,
+ device,
+ dtype,
+ keepdim,
+ requires_grad,
+ components_require_grad,
+ ):
+ """
+ Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions
+ """
+ tensor_lists = self._get_example_tensor_lists(
+ include_list_of_lists=False, include_requires_grad=components_require_grad
+ )
+ reduce_dims = (
+ (0, 2), # batch, non-batch
+ (0, 3), # batch, non-batch
+ )
+
+ for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
+ nt = torch.nested.nested_tensor(
+ tensor_list,
+ device=device,
+ dtype=dtype,
+ layout=torch.jagged,
+ requires_grad=requires_grad,
+ )
+
+ if nt.dim() > reduce_dim[-1]:
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "not supported along only the batch dimension for NestedTensor",
+ ):
+ out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+
+ @dtypes(torch.float32)
+ @parametrize("keepdim", [False, True])
+ @parametrize("requires_grad", [False, True])
+ @parametrize("components_require_grad", [False, True])
+ def test_sum_dim_reduce_batch(
+ self,
+ device,
+ dtype,
+ keepdim,
+ requires_grad,
+ components_require_grad,
+ ):
+ """
+ Sum on NestedTensor fails when trying to reduce across batch dimension
+ """
+ tensor_lists = self._get_example_tensor_lists(
+ include_list_of_lists=False, include_requires_grad=components_require_grad
+ )
+ reduce_dim = (0,) # batch
+
+ for tensor_list in tensor_lists:
+ nt = torch.nested.nested_tensor(
+ tensor_list,
+ device=device,
+ dtype=dtype,
+ layout=torch.jagged,
+ requires_grad=requires_grad,
+ )
+
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "not supported along only the batch dimension for NestedTensor",
+ ):
+ out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
+
+ @dtypes(torch.float32)
+ @parametrize("keepdim", [False, True])
+ @parametrize("requires_grad", [False, True])
+ @parametrize("components_require_grad", [False, True])
+ def test_sum_dim_ragged_dim_not_1(
+ self,
+ device,
+ dtype,
+ keepdim,
+ requires_grad,
+ components_require_grad,
+ ):
+ """
+ Sum on NestedTensor fails when trying to reduce a nested tensor with ragged_idx != 1
+ """
+ tensor_lists = self._get_example_tensor_lists(
+ include_list_of_lists=False, include_requires_grad=components_require_grad
+ )
+ reduce_dims = (
+ (1,),
+ (2, 3),
+ )
+
+ for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
+ nt = torch.nested.nested_tensor(
+ tensor_list,
+ device=device,
+ dtype=dtype,
+ layout=torch.jagged,
+ requires_grad=requires_grad,
+ )
+
+ if (
+ nt.dim() > 2 and nt.dim() > reduce_dim[-1]
+ ): # ensure that we can transpose dims 1 and 2
+ nt_transposed = nt.transpose(1, 2)
+
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "not supported when ragged_idx != 1 for NestedTensor",
+ ):
+ out = torch.sum(nt_transposed, dim=reduce_dim, keepdim=keepdim)
+
+ @dtypes(torch.float32)
+ @parametrize("keepdim", [False, True])
+ @parametrize("requires_grad", [False, True])
+ @parametrize("components_require_grad", [False, True])
+ def test_sum_dim_with_lengths(
+ self,
+ device,
+ dtype,
+ keepdim,
+ requires_grad,
+ components_require_grad,
+ ):
+ """
+ Sum on NestedTensor fails when trying to reduce a nested tensor with lengths,
+ i.e. a nested tensor with holes, when reducing on the ragged dimension
+ """
+ reduce_dims = (
+ (1,),
+ (2,),
+ (2, 3),
+ )
+
+ lengths = torch.randint(5, 10, (20,), device=device)
+ offsets = torch.zeros((21,), device=device, dtype=torch.int)
+ torch.cumsum(lengths, dim=0, out=offsets[1:])
+
+ values = torch.randn(
+ (offsets[-1].item(), 20),
+ device=device,
+ dtype=dtype,
+ requires_grad=requires_grad,
+ )
+
+ nt_with_holes = torch.nested.nested_tensor_from_jagged(
+ values,
+ offsets,
+ lengths=offsets.diff() - 2, # arbitrary subtraction to create holes
+ )
+
+ for reduce_dim in reduce_dims:
+ if nt_with_holes.dim() > reduce_dim[-1]:
+ if nt_with_holes._ragged_idx in reduce_dim:
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "not supported where lengths is not None "
+ + "if reducing across the ragged dimension for NestedTensor",
+ ):
+ out = torch.sum(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
+ else:
+ out = torch.sum(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 6193f65..1fdd780 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -36,27 +36,23 @@
return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
-def _wrap_jagged_dims(ndim, dims, op_name):
- # ex: (2, 3, 4) -> (1, 2, 3)
- # ex: (0, 1, 4) -> (0, 3)
+def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
from torch._prims_common import canonicalize_dims
- wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
- # This logic needs to be done after we canonicalize dims but before we
- # map to inner dims so we can print a nicer error message.
- zero_in_dims = 0 in wrapped_dims
- one_in_dims = 1 in wrapped_dims
- if zero_in_dims ^ one_in_dims:
- apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
- raise RuntimeError(
- f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
- " dimension is not supported for NestedTensor"
- )
- return (
- tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
- zero_in_dims,
+ wrapped_dims = [
+ canonicalize_dims(ndim, d) for d in dims
+ ] # convert all indices to non-negative values
+
+ operate_on_batch = 0 in wrapped_dims
+ operate_on_ragged = ragged_idx in wrapped_dims
+ operate_on_non_batch = any(d != 0 and d != ragged_idx for d in wrapped_dims)
+
+ outer_to_inner_dim = tuple(
+ _outer_to_inner_dim(ndim, d) for d in wrapped_dims if d != 0
)
+ return outer_to_inner_dim, operate_on_batch, operate_on_ragged, operate_on_non_batch
+
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
named_arg_types = schema_str.split(", ")
@@ -847,28 +843,82 @@
@register_jagged_func(
- torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
+ torch.ops.aten.sum.dim_IntList,
+ "self: jt_all, dim: any?, keepdim: any?, dtype: any?",
)
def sum_dim_IntList(func, *args, **kwargs):
- # sum_dim_IntList can produce a NT or a T depending on whether the ragged dims
- # are reduced away.
+ """
+ Performs a sum along the provided tensor dimension.
+ Returns a dense tensor if the ragged dimension is reduced away, else returns a nested tensor.
+ """
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
- assert inp._ragged_idx == 1
- new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
- inp.dim(), new_kwargs["dim"], "sum"
+
+ (
+ new_kwargs["dim"],
+ reduce_on_batch,
+ reduce_on_ragged,
+ reduce_on_non_batch, # noqa: UFMT
+ ) = _wrap_jagged_dims(
+ inp.dim(),
+ new_kwargs["dim"],
+ "sum",
+ inp._ragged_idx,
)
- if not ragged_reduced_away:
- return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
- else:
- # Don't wrap because we reduced away the raggedness
- out = func(inp._values, **new_kwargs)
+ if inp._ragged_idx != 1:
+ raise RuntimeError("sum(): not supported when ragged_idx != 1 for NestedTensor")
+
+ if reduce_on_ragged and inp._lengths is not None:
+ raise RuntimeError(
+ "sum(): not supported where lengths is not None "
+ + "if reducing across the ragged dimension for NestedTensor"
+ )
+
+ if reduce_on_ragged: # raggedness reduced away --> return dense tensor
+ if (
+ reduce_on_batch
+ ): # reduction cases: (batch, ragged), (batch, ragged, non-batch), etc.
+ out = func(
+ inp._values, **new_kwargs
+ ) # no need to read offsets --> apply sum directly on values
+ else:
+ if (
+ reduce_on_non_batch
+ ): # invalid reduction cases: (ragged, non-batch), etc.
+ raise RuntimeError(
+ "sum(): not supported along a ragged and non-batch dimension for NestedTensor"
+ )
+ # reduction cases: (ragged)
+ out = torch.sum(
+ torch.ops.aten._jagged_to_padded_dense_forward(
+ inp._values.view(*inp._values.shape[: inp._ragged_idx], -1),
+ [inp._offsets],
+ max_lengths=[inp._max_seqlen],
+ ).view(
+ *inp.shape[: inp._ragged_idx],
+ inp._max_seqlen,
+ *inp.shape[inp._ragged_idx + 1 :],
+ ),
+ dim=inp._ragged_idx,
+ ) # need to read offsets --> pad jagged dimension and apply sum
+
if new_kwargs["keepdim"]:
out = out.unsqueeze(0)
return out
+ else: # raggedness preserved --> return nested tensor
+ if (
+ reduce_on_batch
+ ): # invalid reduction cases: (batch), (batch, non-batch), etc.
+ raise RuntimeError(
+ "sum(): not supported along only the batch dimension for NestedTensor"
+ )
+ # reduction cases: (non-batch), (non-batch, non-batch), etc.
+ return NestedTensor(
+ func(inp._values, **new_kwargs), **extract_kwargs(inp)
+ ) # apply sum directly on values
@register_jagged_func(