[NestedTensor] Support ragged_idx != 1 in pointwise ops (#118157)
This PR allows pointwise ops to operate on tensors with ragged_idx != 1. It does this by passing the ragged_idx metadata into the construction of the returned NestedTensor when computing pointwise ops. The assumption is that: pointwise ops can operate directly on the values tensors, and the resulting tensor should have all the same metadata properties as the input tensors. For binary ops, a test is added to verify that adding two tensors with different ragged_idx cannot be added.
Previously:
* unary pointwise ops would error out when performed on nested tensors with ragged_idx != 1
* binary pointwise ops would produce tensors with nonsense shapes
Differential Revision: [D53032641](https://our.internmc.facebook.com/intern/diff/D53032641)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118157
Approved by: https://github.com/jbschlosser
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index fcdf6fe..942185f 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -3054,6 +3054,32 @@
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
+ def test_unary_pointwise_transposed_inputs(self, device):
+ a, b, c = (
+ torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)
+ )
+
+ nt, _ = jagged_from_list([a.detach(), b.detach(), c.detach()], None)
+ nt_t = nt.transpose(1, 2)
+ self.assertFalse(nt_t.is_contiguous())
+ out = torch.nn.functional.silu(nt_t.sin().cos())
+ self.assertEqual(out.is_contiguous(), torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous())
+
+ self.assertEqual(nt_t.shape, out.shape)
+
+ a, b, c = (
+ torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)
+ )
+
+ def grad_test_func(a, b, c):
+ nt, _ = jagged_from_list([a, b, c], None)
+ nt_t = nt.transpose(1, 2)
+ out = torch.nn.functional.silu(nt_t.sin().cos())
+ return buffer_from_jagged(out)
+
+ gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
+
+
def test_binary_pointwise(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
@@ -3078,6 +3104,43 @@
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
+ def test_binary_pointwise_transposed(self, device):
+ a, b, c = (
+ torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
+ )
+
+ nt1, offsets = jagged_from_list([a, b, c], None)
+ nt2, offsets = jagged_from_list([a, b, c], offsets)
+
+ nt1_t = nt1.transpose(1, 2)
+ nt2_t = nt2.transpose(1, 2)
+
+ out = nt1_t * nt2_t
+ self.assertFalse(nt1_t.is_contiguous())
+ self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
+ self.assertEqual(out.shape, nt1_t.shape)
+
+ self.assertRaisesRegex(
+ RuntimeError,
+ "cannot call binary pointwise function mul.Tensor with inputs of shapes",
+ lambda: nt1 * nt2_t,
+ )
+
+ a, b, c = (
+ torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)
+ )
+
+ # Correct usage: chain the calls using the same offsets tensor object
+ def grad_test_func(a, b, c):
+ nt1, offsets = jagged_from_list([a, b, c], None)
+ nt2, offsets = jagged_from_list([a, b, c], offsets)
+ nt1_t = nt1.transpose(1, 2)
+ nt2_t = nt2.transpose(1, 2)
+ out = nt1_t * nt2_t
+ return buffer_from_jagged(out)
+
+ gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
+
def test_split(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
@@ -3555,9 +3618,25 @@
nt, _ = jagged_from_list([a, b, c], None)
# transpose ragged dim
transposed = nt.transpose(1, 2)
- # pointwise ops are not supported on ragged dim transposed jagged layout NTs
- with self.assertRaisesRegex(ValueError, "expected .* to be a contiguous jagged layout"):
- clone = transposed.clone()
+ self.assertFalse(transposed.is_contiguous())
+ clone = transposed.clone()
+
+ def check_nt_equality(x, y):
+ self.assertEqual(x.values(), y.values())
+ self.assertEqual(x.offsets(), y.offsets())
+ self.assertEqual(x._ragged_idx, y._ragged_idx)
+ self.assertEqual(x.shape, y.shape)
+
+ self.assertFalse(clone.is_contiguous())
+ check_nt_equality(clone, transposed)
+
+ clone_contig = transposed.clone(memory_format=torch.contiguous_format)
+ self.assertTrue(clone_contig.is_contiguous())
+ check_nt_equality(clone_contig, transposed)
+
+ detached = transposed.detach()
+ self.assertFalse(clone.is_contiguous())
+ check_nt_equality(detached, transposed)
# Note 1: Math fallback doesn't work with bfloat16 on CUDA
# Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
@@ -3679,7 +3758,6 @@
if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
check_forward_backward()
-
# This requires NT -> NT views to work in inductor, which is a TODO
@unittest.expectedFailure # noqa: E301
@onlyCUDA
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 3aacd3e..8d68b68 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -191,7 +191,7 @@
# Assume there aren't additional tensors that aren't the "unary/binary" args
num_tensor_args = sum([isinstance(x, torch.Tensor) for x in args])
if num_tensor_args == 1:
- check_schema("self: jt, ...", func, *args, **kwargs)
+ check_schema("self: jt_all, ...", func, *args, **kwargs)
return functools.partial(jagged_unary_pointwise, func)
elif num_tensor_args == 2:
check_schema("lhs: any, rhs: any", func, *args, **kwargs)
@@ -204,6 +204,7 @@
kwargs = {
"offsets": arg.offsets(),
"_metadata_cache": arg._metadata_cache,
+ "_ragged_idx": arg._ragged_idx,
}
return kwargs
@@ -374,10 +375,6 @@
if inp.lengths() is not None:
return False
- # If jagged dim is not 1 it's not contiguous
- if inp._ragged_idx != 1:
- return False
-
new_kwargs["memory_format"] = new_kwargs.get(
"memory_format", torch.contiguous_format
)
@@ -446,7 +443,7 @@
torch.ops.aten.randn_like.default,
torch.ops.aten.detach.default,
],
- "self: jt",
+ "self: jt_all",
)(jagged_unary_pointwise)
@@ -791,13 +788,14 @@
to_dim = dim1
else:
to_dim = dim0
+ inp_kwargs = extract_kwargs(inp)
+ inp_kwargs["_ragged_idx"] = to_dim
return NestedTensor(
inp.values().transpose(
_outer_to_inner_dim(len(inp._size), dim0),
_outer_to_inner_dim(len(inp._size), dim1),
),
- **extract_kwargs(inp),
- _ragged_idx=to_dim,
+ **inp_kwargs,
)
new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")