[NT] Backward support for broadcasting binary ops (#112519)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112519
Approved by: https://github.com/jbschlosser
ghstack dependencies: #113031
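
This adds backward support for broadcasting binary pointwise ops between a
jagged-layout NestedTensor and a dense tensor, plus a sum.dim_IntList kernel
that returns either a NestedTensor or a plain Tensor depending on whether the
ragged dim is reduced away. On the autograd side, InputMetadata now recognizes
that a dense input's shape may be expandable to an NT gradient's shape, so
reduce_grad can sum an NT gradient down to the dense input.

A minimal sketch of the new behavior (this uses the internal test helpers
jagged_from_list / buffer_from_jagged rather than public API; shapes are
illustrative):

    import torch
    from torch.nested._internal.nested_tensor import (
        buffer_from_jagged,
        jagged_from_list,
    )

    ts = [torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64),
          torch.randn(5, 3, 4, requires_grad=True, dtype=torch.float64)]
    nt, _ = jagged_from_list(ts, None)  # shape (B, j0, 3, 4)

    t = torch.rand(1, 3, 1, requires_grad=True, dtype=torch.float64)
    out = nt + t  # broadcasts to (B, j0, 3, 4); backward now works
    buffer_from_jagged(out).sum().backward()

    torch.sum(nt, dim=(2, 3))  # NestedTensor of shape (B, j0)
    torch.sum(nt, dim=(0, 1))  # plain Tensor of shape (3, 4)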
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index b6c15dd..50b32eb 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -31,7 +31,11 @@
TestCase,
)
-from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged
+from torch.nested._internal.nested_tensor import (
+ buffer_from_jagged,
+ jagged_from_list,
+ NestedTensor,
+)
import contextlib
# Tests are ported from pytorch/nestedtensor.
@@ -2897,6 +2901,16 @@
unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
)
+ # TODO: consolidate with _get_example_tensor_lists below
+ def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
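+ # nested_size is ((s_0, ..., s_{B-1}), D_1, D_2, ...); returns a list of
+ # B tensors, the i-th of shape (s_i, D_1, D_2, ...), i.e. one component
+ # per ragged entry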
+ Ds = nested_size[1:]
+ out = []
+ for s in nested_size[0]:
+ out.append(
+ torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64)
+ )
+ return out
+
def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):
def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
@@ -3032,6 +3046,83 @@
):
torch.split(nt, [1, 2], 1)
+ def test_binary_pointwise_broadcasting(self, device):
+ # (B, j0, 3, 4)
+ ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True)
+ # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+ # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+ # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
+ # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
+ t_sizes = (
+ (4,),
+ (1, 4),
+ (3, 1),
+ (1, 3, 1),
+ (1, 1, 1, 4),
+ # (1, 1, 1, 1, 4), (unsupported today)
+ )
+
+ def grad_test_func(t, *ts):
+ nt, _ = jagged_from_list(ts, None)
+ out = nt + t
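+ # flatten to the dense underlying buffer so gradcheck can consume the output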
+ return buffer_from_jagged(out)
+
+ for t_size in t_sizes:
+ t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
+ gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
+
+ @parametrize("keepdim", [False, True])
+ def test_sum_int_DimList(self, device, keepdim):
+ # (B, j0, 3, 4)
+ ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True)
+
+ # Check shape correctness
+ reduce_dims = (
+ # dims, expected shape, expected keepdim shape
+ # j0 is represented as None
+ ((0, 1), (3, 4), (1, 1, 3, 4)),
+ ((1, 2), None, None),
+ ((2, 3), (3, None), (3, None, 1, 1)),
+ ((0, 1, 3), (3,), (1, 1, 3, 1)),
+ ((0, 1, 2), (4,), (1, 1, 1, 4)),
+ ((0, 1, 2, 3), tuple(), (1, 1, 1, 1)),
+ )
+ for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
+ if (0 in rd) ^ (1 in rd):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "applying over the ragged dimension, but not the batch dimension"):
+ nt, _ = jagged_from_list(ts, None)
+ out = torch.sum(nt, dim=rd, keepdim=keepdim)
+ continue
+
+ nt, _ = jagged_from_list(ts, None)
+ out = torch.sum(nt, dim=rd, keepdim=keepdim)
+ ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
+ self.assertEqual(len(out.shape), len(ref_shape))
+ for o, r in zip(out.shape, ref_shape):
+ if r is not None:
+ self.assertEqual(o, r)
+ else:
+ self.assertTrue(isinstance(o, torch.SymInt))
+
+ # Check values correctness
+ # raggedness not reduced
+ nt, _ = jagged_from_list(ts, None)
+ out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
+ out_ref = torch.sum(nt.values(), dim=(1, 2))
+ self.assertIsInstance(out, NestedTensor)
+ # flatten to avoid having to replicate unsqueeze logic depending on keepdim
+ self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
+
+ # raggedness reduced away
+ nt, _ = jagged_from_list(ts, None)
+ out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
+ out_ref = torch.sum(nt.values(), dim=(0,))
+ self.assertNotIsInstance(out, NestedTensor)
+ self.assertTrue(torch.allclose(out, out_ref))
+
+
@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
@parametrize("weights_only", [False, True])
diff --git a/torch/csrc/autograd/input_metadata.cpp b/torch/csrc/autograd/input_metadata.cpp
index 696bb61..723303f 100644
--- a/torch/csrc/autograd/input_metadata.cpp
+++ b/torch/csrc/autograd/input_metadata.cpp
@@ -16,6 +16,14 @@
return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}
+bool is_python_dispatch(const at::Tensor& tensor) {
+ return tensor.unsafeGetTensorImpl()->is_python_dispatch();
+}
+
+bool is_cpp_nested_tensor(const at::Tensor& tensor) {
+ return tensor.is_nested() && !is_python_dispatch(tensor);
+}
+
} // namespace
InputMetadata::InputMetadata(
@@ -36,7 +44,7 @@
: InputMetadata(
t.options(),
compute_variant_shape(t),
- t.unsafeGetTensorImpl()->is_python_dispatch(),
+ is_python_dispatch(t),
t.is_nested()) {}
at::Tensor InputMetadata::zeros_like() const {
@@ -46,7 +54,9 @@
}
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
- check_nestedness_same(grad);
+ if (!is_nestedness_same(grad)) {
+ return false;
+ }
if (is_cpp_nested_tensor()) {
return grad._nested_tensor_size().is_same_size(shape_as_tensor());
}
@@ -54,19 +64,15 @@
}
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
- // Currently NestedTensors are not expandable. If this support is added then
- // updates to reduce_grad will be needed
- check_nestedness_same(grad);
- return grad.is_nested()
- ? false
- : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
+ if (!maybe_expandable_to(grad)) {
+ return false;
+ }
+ return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}
at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
- // Currently reduce_grad is only called if is_expandable_to_shape returns
- // true For nested tensors this always returns False, so this check
- // shouldn't fail
- TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_)
+ // reduce_grad should only be called if is_expandable_to_shape returns true.
+ TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
return at::sum_to(std::move(grad), shape_as_dim_vector());
}
@@ -75,7 +81,7 @@
const at::Tensor& grad) const {
std::stringstream ss{};
ss << "invalid gradient at index " << index << " - got ";
- if (grad.is_nested() && !grad.unsafeGetTensorImpl()->is_python_dispatch()) {
+ if (::torch::autograd::is_cpp_nested_tensor(grad)) {
ss << grad._nested_tensor_size();
} else {
ss << grad.sym_sizes();
@@ -106,21 +112,34 @@
return std::get<SymIntSmallVec>(shape_);
}
-void InputMetadata::check_nestedness_same(const at::Tensor& grad) const {
- bool grad_is_subclass = grad.unsafeGetTensorImpl()->is_python_dispatch();
- bool grad_is_nested = grad.is_nested();
- bool grad_is_cpp_nested = grad_is_nested && !grad_is_subclass;
- TORCH_CHECK(
- grad_is_cpp_nested == is_cpp_nested_tensor() &&
- grad_is_nested == is_nested_,
- "grad and the input wrt the gradient that is being computed for need to be "
- "either both nested or both non-nested tensors. Also note that nested "
- "tensors with different layouts do not compose currently.");
+bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
+ return (
+ grad.is_nested() == is_nested_ &&
+ ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
}
at::Tensor InputMetadata::shape_as_tensor() const {
return std::get<at::Tensor>(shape_);
}
+bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
+ // This is a fast initial check, based on nestedness information alone, for
+ // whether the tensor described by this InputMetadata can possibly be
+ // expanded to grad. If it returns true, is_expandable_to_shape performs the
+ // full shape check. We support the following 3 types of expansion:
+ bool grad_is_nested = grad.is_nested();
+ if (!is_nested_ && !grad_is_nested) {
+ // Normal case (no NestedTensors are involved)
+ // (1) plain Tensor -> plain Tensor
+ return true;
+ } else {
+ // (2) python NT -> python NT
+ // (3) plain Tensor -> python NT
+ return (
+ grad_is_nested && is_python_dispatch(grad) &&
+ (!is_nested_ || is_tensor_subclass_));
+ }
+}
+
} // namespace autograd
} // namespace torch
diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h
index 60c2c8c..94ab6bd 100644
--- a/torch/csrc/autograd/input_metadata.h
+++ b/torch/csrc/autograd/input_metadata.h
@@ -98,7 +98,8 @@
private:
at::Tensor shape_as_tensor() const;
- void check_nestedness_same(const at::Tensor& grad) const;
+ bool is_nestedness_same(const at::Tensor& grad) const;
+ bool maybe_expandable_to(const at::Tensor& grad) const;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::TensorOptions options_;
diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py
index c0dba0f..6743ab0 100644
--- a/torch/nested/__init__.py
+++ b/torch/nested/__init__.py
@@ -177,7 +177,8 @@
from torch.nested._internal.nested_tensor import jagged_from_list
- nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
+ with torch.no_grad():
+ nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
nt.requires_grad_(requires_grad)
if pin_memory:
diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py
index de566d7..aaf8f52 100644
--- a/torch/nested/_internal/nested_tensor.py
+++ b/torch/nested/_internal/nested_tensor.py
@@ -61,7 +61,7 @@
torch.jagged,
values.device,
False,
- False,
+ kwargs.get("requires_grad", False),
"sizes",
False,
True, # dispatch_layout
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index e16d6d9..1943d17 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -11,6 +11,13 @@
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
+# Simplifying assumption: we assume that the batch dim is always the left-most
+# dim, and the ragged dim is always the second dim.
+def _outer_to_inner_dim(ndim, dim):
+ assert dim >= 0 and dim < ndim
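+ # Both the batch dim (0) and the ragged dim (1) map to dim 0 of the packed
+ # values tensor; all later dims shift left by one.
+ # ex: ndim=4 maps outer dims (0, 1, 2, 3) to inner dims (0, 0, 1, 2)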
+ return 0 if dim < 2 else dim - 1
+
+
def _wrap_jagged_dim(ndim, dim, op_name):
from torch._prims_common import canonicalize_dims
@@ -19,7 +26,29 @@
raise RuntimeError(
f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
)
- return wrapped - 1
+ return _outer_to_inner_dim(ndim, wrapped)
+
+
+def _wrap_jagged_dims(ndim, dims, op_name):
+ # ex: (2, 3, 4) -> (1, 2, 3)
+ # ex: (0, 1, 4) -> (0, 3)
+ from torch._prims_common import canonicalize_dims
+
+ wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
+ # This logic needs to be done after we canonicalize dims but before we
+ # map to inner dims so we can print a nicer error message.
+ zero_in_dims = 0 in wrapped_dims
+ one_in_dims = 1 in wrapped_dims
+ if zero_in_dims ^ one_in_dims:
+ apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
+ raise RuntimeError(
+ f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
+ " dimension is not supported for NestedTensor"
+ )
+ return (
+ tuple(_outer_to_inner_dim(ndim, d) for d in wrapped_dims if d != 0),
+ zero_in_dims,
+ )
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
@@ -79,6 +108,30 @@
return list(nt._size[:end]) == list(size[:end])
+def squeeze_leading_ones(t):
+ # Note: [ Squeezing leading ones ]
+ #
+ # Squeeze leading ones from t.
+ #
+ # We want:
+ # (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+ # (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) (not yet supported)
+ #
+ # 1) Squeeze extra ones and grab values from NT
+ # (1, 1, ?, ?) -> (?, ?) and (B, j0, ?, ?) -> (sum(*), ?, ?)
+ # 2) Do dense broadcasting:
+ # (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
+ # 3) Construct nested tensor
+ # (sum(*), ?, ?) -> (B, j0, ?, ?)
+ #
+ # If unsqueezing on the 0th dim becomes supported, we would add a step (4)
+ # to unsqueeze the result, and this function would need to record how many
+ # leading ones it squeezed away.
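+ # ex: (1, 1, 3, 1) -> (3, 1); only *leading* ones are removed, so trailing
+ # ones still broadcast against the inner dims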
+ while t.dim() > 0 and t.shape[0] == 1:
+ t = t.squeeze(0)
+ return t
+
+
def register_func(tables, aten_ops, schema_str):
if not isinstance(aten_ops, list):
aten_ops = [aten_ops]
@@ -163,15 +216,17 @@
# === Handle broadcasting across the batch / ragged dims ===
# Easy case: take advantage of pre-existing broadcasting logic
- # when NT dim > non-NT dim
- # ex: (B, j0, D_0, D_1) + (D_0, D_1) -> (B, j0, D_0, D_1)
- # ex: (B, j0, D_0, D_1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
- # ex: (B, j0, 1, 1) + (D_0, D_1) -> (B, j0, D_0, D_1)
- # ex: (B, j0, 1, 1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
- if (a_is_nt and a.dim() > b.dim()) or (not a_is_nt and b.dim() > a.dim()):
- arg1 = a._values if a_is_nt else a
- arg2 = b._values if not a_is_nt else b
- return NestedTensor(func(arg1, arg2, *args[2:], **kwargs), **extracted_kwargs)
+ # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+ # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+ # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+ nt, t = (a, b) if a_is_nt else (b, a)
+ # See Note: [ Squeezing leading ones ]
+ if t.dim() > nt.dim():
+ raise NotImplementedError("NYI: broadcasting an NT with a dense tensor of higher dim")
+ t_squeezed = squeeze_leading_ones(t)
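+ # t_squeezed being at least two dims smaller than nt means it cannot line
+ # up with the batch or ragged dims, so broadcasting directly against the
+ # packed values tensor is safe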
+ if nt.dim() >= t_squeezed.dim() + 2:
+ lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
+ return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
# Harder case: do manual broadcasting over unbound components
# when NT dim == non-NT dim
@@ -598,6 +653,31 @@
return args[0]._size == args[1]._size
+@register_jagged_func(
+ torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
+)
+def sum_dim_IntList(func, *args, **kwargs):
+ # sum_dim_IntList can return either a NestedTensor or a plain Tensor,
+ # depending on whether the ragged dim is reduced away.
+ _, new_kwargs = normalize_function(
+ func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+ )
+ inp = new_kwargs.pop("input")
+ assert inp._ragged_idx == 1
+ new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
+ inp.dim(), new_kwargs["dim"], "sum"
+ )
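+ # ex: for a (B, j0, 3, 4) input, dim=(0, 1) maps to inner dims (0,) over
+ # values of shape (sum(j0), 3, 4), with ragged_reduced_away=True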
+
+ if not ragged_reduced_away:
+ return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+ else:
+ # Don't wrap because we reduced away the raggedness
+ out = func(inp._values, **new_kwargs)
+ if new_kwargs["keepdim"]:
+ out = out.unsqueeze(0)
+ return out
+
+
@register_jagged_func(torch.ops.aten.transpose.int, "self: jt, dim0: any, dim1: any")
def transpose_int(func, *args, **kwargs):
_, new_kwargs = normalize_function(