[NT] Backward support for broadcasting binary ops (#112519)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112519
Approved by: https://github.com/jbschlosser
ghstack dependencies: #113031
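A minimal sketch of the behavior this stack enables, using the internal helpers `jagged_from_list` / `buffer_from_jagged` touched below (the public `torch.nested.nested_tensor(..., layout=torch.jagged)` factory follows the same path); shapes and values here are illustrative only:

```python
import torch
from torch.nested._internal.nested_tensor import buffer_from_jagged, jagged_from_list

# Two components with a ragged first dim -> a (B, j0, 3, 4) jagged NestedTensor.
ts = [torch.randn(2, 3, 4, requires_grad=True),
      torch.randn(5, 3, 4, requires_grad=True)]
nt, _ = jagged_from_list(ts, None)

# Broadcast a dense operand over the batch/ragged dims; the result stays jagged.
t = torch.randn(1, 3, 1, requires_grad=True)
out = nt + t                                  # (B, j0, 3, 4)

# Reduce to a dense buffer so we can call backward(); the dense operand's grad
# is summed back down to its own (broadcast) shape.
buffer_from_jagged(out).sum().backward()
print(t.grad.shape)                           # torch.Size([1, 3, 1])
```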
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index b6c15dd..50b32eb 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -31,7 +31,11 @@
     TestCase,
 )
 
-from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged
+from torch.nested._internal.nested_tensor import (
+    buffer_from_jagged,
+    jagged_from_list,
+    NestedTensor,
+)
 import contextlib
 
 # Tests are ported from pytorch/nestedtensor.
@@ -2897,6 +2901,16 @@
             unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
         )
 
+    # TODO: consolidate with the below
+    def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
+        Ds = nested_size[1:]
+        out = []
+        for s in nested_size[0]:
+            out.append(
+                torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64)
+            )
+        return out
+
     def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):
 
         def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
@@ -3032,6 +3046,83 @@
         ):
             torch.split(nt, [1, 2], 1)
 
+    def test_binary_pointwise_broadcasting(self, device):
+        # (B, j0, 3, 4)
+        ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True)
+        # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+        # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+        # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
+        # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
+        t_sizes = (
+            (4,),
+            (1, 4),
+            (3, 1),
+            (1, 3, 1),
+            (1, 1, 1, 4),
+            # (1, 1, 1, 1, 4), (unsupported today)
+        )
+
+        def grad_test_func(t, *ts):
+            nt, _ = jagged_from_list(ts, None)
+            out = nt + t
+            return buffer_from_jagged(out)
+
+        for t_size in t_sizes:
+            t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
+            gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
+
+    @parametrize("keepdim", [False, True])
+    def test_sum_int_DimList(self, device, keepdim):
+        # (B, j0, 3, 4)
+        ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True)
+
+        # Check shape correctness
+        reduce_dims = (
+            # dims, expected shape, expected keepdim shape
+            # j0 is represented as None
+            ((0, 1), (3, 4), (1, 1, 3, 4)),
+            ((1, 2), None, None),
+            ((2, 3), (3, None), (3, None, 1, 1)),
+            ((0, 1, 3), (3,), (1, 1, 3, 1)),
+            ((0, 1, 2), (4,), (1, 1, 1, 4)),
+            ((0, 1, 2, 3), tuple(), (1, 1, 1, 1)),
+        )
+        for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
+            if (0 in rd) ^ (1 in rd):
+                with self.assertRaisesRegex(
+                        RuntimeError,
+                        "applying over the ragged dimension, but not the batch dimension"):
+                    nt, _ = jagged_from_list(ts, None)
+                    out = torch.sum(nt, dim=rd, keepdim=keepdim)
+                continue
+
+            nt, _ = jagged_from_list(ts, None)
+            out = torch.sum(nt, dim=rd, keepdim=keepdim)
+            ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
+            self.assertEqual(len(out.shape), len(ref_shape))
+            for o, r in zip(out.shape, ref_shape):
+                if r is not None:
+                    self.assertEqual(o, r)
+                else:
+                    self.assertTrue(isinstance(o, torch.SymInt))
+
+        # Check values correctness
+        # raggedness not reduced
+        nt, _ = jagged_from_list(ts, None)
+        out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
+        out_ref = torch.sum(nt.values(), dim=(1, 2))
+        self.assertIsInstance(out, NestedTensor)
+        # flatten to avoid having to replicate unsqueeze logic depending on keepdim
+        self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
+
+        # raggedness reduced away
+        nt, _ = jagged_from_list(ts, None)
+        out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
+        out_ref = torch.sum(nt.values(), dim=(0,))
+        self.assertNotIsInstance(out, NestedTensor)
+        self.assertTrue(torch.allclose(out, out_ref))
+
+
     @dtypes(torch.float, torch.double, torch.half)
     @parametrize("requires_grad", [False, True])
     @parametrize("weights_only", [False, True])
diff --git a/torch/csrc/autograd/input_metadata.cpp b/torch/csrc/autograd/input_metadata.cpp
index 696bb61..723303f 100644
--- a/torch/csrc/autograd/input_metadata.cpp
+++ b/torch/csrc/autograd/input_metadata.cpp
@@ -16,6 +16,14 @@
   return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
 }
 
+bool is_python_dispatch(const at::Tensor& tensor) {
+  return tensor.unsafeGetTensorImpl()->is_python_dispatch();
+}
+
+bool is_cpp_nested_tensor(const at::Tensor& tensor) {
+  return tensor.is_nested() && !is_python_dispatch(tensor);
+}
+
 } // namespace
 
 InputMetadata::InputMetadata(
@@ -36,7 +44,7 @@
     : InputMetadata(
           t.options(),
           compute_variant_shape(t),
-          t.unsafeGetTensorImpl()->is_python_dispatch(),
+          is_python_dispatch(t),
           t.is_nested()) {}
 
 at::Tensor InputMetadata::zeros_like() const {
@@ -46,7 +54,9 @@
 }
 
 bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
-  check_nestedness_same(grad);
+  if (!is_nestedness_same(grad)) {
+    return false;
+  }
   if (is_cpp_nested_tensor()) {
     return grad._nested_tensor_size().is_same_size(shape_as_tensor());
   }
@@ -54,19 +64,15 @@
 }
 
 bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
-  // Currently NestedTensors are not expandable. If this support is added then
-  // updates to reduce_grad will be needed
-  check_nestedness_same(grad);
-  return grad.is_nested()
-      ? false
-      : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
+  if (!maybe_expandable_to(grad)) {
+    return false;
+  }
+  return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
 }
 
 at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
-  // Currently reduce_grad is only called if is_expandable_to_shape returns
-  // true For nested tensors this always returns False, so this check
-  // shouldn't fail
-  TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_)
+  // reduce_grad should only be called if is_expandable_to_shape returns true.
+  TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
   return at::sum_to(std::move(grad), shape_as_dim_vector());
 }
 
@@ -75,7 +81,7 @@
     const at::Tensor& grad) const {
   std::stringstream ss{};
   ss << "invalid gradient at index " << index << " - got ";
-  if (grad.is_nested() && !grad.unsafeGetTensorImpl()->is_python_dispatch()) {
+  if (::torch::autograd::is_cpp_nested_tensor(grad)) {
     ss << grad._nested_tensor_size();
   } else {
     ss << grad.sym_sizes();
@@ -106,21 +112,34 @@
   return std::get<SymIntSmallVec>(shape_);
 }
 
-void InputMetadata::check_nestedness_same(const at::Tensor& grad) const {
-  bool grad_is_subclass = grad.unsafeGetTensorImpl()->is_python_dispatch();
-  bool grad_is_nested = grad.is_nested();
-  bool grad_is_cpp_nested = grad_is_nested && !grad_is_subclass;
-  TORCH_CHECK(
-      grad_is_cpp_nested == is_cpp_nested_tensor() &&
-          grad_is_nested == is_nested_,
-      "grad and the input wrt the gradient that is being computed for need to be "
-      "either both nested or both non-nested tensors. Also note that nested "
-      "tensors with different layouts do not compose currently.");
+bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
+  return (
+      grad.is_nested() == is_nested_ &&
+      ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
 }
 
 at::Tensor InputMetadata::shape_as_tensor() const {
   return std::get<at::Tensor>(shape_);
 }
 
+bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
+  // Coarse pre-check, based on nestedness alone, of whether the tensor
+  // represented by this metadata could be expandable to grad. Only if this
+  // returns true does is_expandable_to_shape go on to compare actual shapes.
+  // We support the following 3 types of expansion:
+  bool grad_is_nested = grad.is_nested();
+  if (!is_nested_ && !grad_is_nested) {
+    // Normal case (no NestedTensors are involved)
+    // (1) plain Tensor -> plain Tensor
+    return true;
+  } else {
+    // (2) python NT -> python NT
+    // (3) plain Tensor -> python NT
+    return (
+        grad_is_nested && is_python_dispatch(grad) &&
+        (!is_nested_ || is_tensor_subclass_));
+  }
+}
+
 } // namespace autograd
 } // namespace torch
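A rough Python mirror of the new nestedness pre-check, for readability; the C++ above is authoritative and the parameter names here are illustrative only:

```python
def maybe_expandable_to(meta_is_nested, meta_is_subclass,
                        grad_is_nested, grad_is_subclass):
    # (1) plain Tensor -> plain Tensor: defer entirely to the shape check.
    if not meta_is_nested and not grad_is_nested:
        return True
    # (2) python NT -> python NT, (3) plain Tensor -> python NT:
    # the grad must be a python-dispatch NT; C++ NTs remain non-expandable.
    return (grad_is_nested and grad_is_subclass
            and (not meta_is_nested or meta_is_subclass))

assert not maybe_expandable_to(True, False, True, False)   # C++ NT grad: rejected
assert maybe_expandable_to(False, False, True, True)       # dense input, NT grad: OK
```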
diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h
index 60c2c8c..94ab6bd 100644
--- a/torch/csrc/autograd/input_metadata.h
+++ b/torch/csrc/autograd/input_metadata.h
@@ -98,7 +98,8 @@
 
  private:
   at::Tensor shape_as_tensor() const;
-  void check_nestedness_same(const at::Tensor& grad) const;
+  bool is_nestedness_same(const at::Tensor& grad) const;
+  bool maybe_expandable_to(const at::Tensor& grad) const;
 
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const at::TensorOptions options_;
diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py
index c0dba0f..6743ab0 100644
--- a/torch/nested/__init__.py
+++ b/torch/nested/__init__.py
@@ -177,7 +177,8 @@
 
         from torch.nested._internal.nested_tensor import jagged_from_list
 
-        nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
+        with torch.no_grad():
+            nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
 
         nt.requires_grad_(requires_grad)
         if pin_memory:
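Wrapping the construction in `no_grad()` makes the returned NestedTensor an autograd leaf, with `requires_grad_` applied afterwards, matching how dense factory functions behave. A small sketch of the intended semantics, assuming the public factory's jagged path shown above:

```python
import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)],
    layout=torch.jagged,
    requires_grad=True,
)
print(nt.is_leaf, nt.requires_grad)   # True True -- a leaf, not a view of the inputs
```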
diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py
index de566d7..aaf8f52 100644
--- a/torch/nested/_internal/nested_tensor.py
+++ b/torch/nested/_internal/nested_tensor.py
@@ -61,7 +61,7 @@
             torch.jagged,
             values.device,
             False,
-            False,
+            kwargs.get("requires_grad", False),
             "sizes",
             False,
             True,  # dispatch_layout
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index e16d6d9..1943d17 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -11,6 +11,13 @@
 JAGGED_OPS_TABLE: Dict[Any, Any] = {}
 
 
+# Simplifying assumption: the batch dim is always the left-most dim, and the
+# ragged dim is always the second dim.
+def _outer_to_inner_dim(ndim, dim):
+    assert dim >= 0 and dim < ndim
+    return 0 if dim < 2 else dim - 1
+
+
 def _wrap_jagged_dim(ndim, dim, op_name):
     from torch._prims_common import canonicalize_dims
 
@@ -19,7 +26,29 @@
         raise RuntimeError(
             f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
         )
-    return wrapped - 1
+    return _outer_to_inner_dim(ndim, wrapped)
+
+
+def _wrap_jagged_dims(ndim, dims, op_name):
+    # ex: (2, 3, 4) -> (1, 2, 3)
+    # ex: (0, 1, 4) -> (0, 3)
+    from torch._prims_common import canonicalize_dims
+
+    wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
+    # This logic needs to be done after we canonicalize dims but before we
+    # map to inner dims so we can print a nicer error message.
+    zero_in_dims = 0 in wrapped_dims
+    one_in_dims = 1 in wrapped_dims
+    if zero_in_dims ^ one_in_dims:
+        apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
+        raise RuntimeError(
+            f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
+            " dimension is not supported for NestedTensor"
+        )
+    return (
+        tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
+        zero_in_dims,
+    )
 
 
 def check_schema(schema_str: str, func, *args, **kwargs) -> None:
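To make the outer-to-inner mapping concrete: for a `(B, j0, D0, D1)` jagged NT whose `_values` buffer is `(sum(j0), D0, D1)`, the batch and ragged dims both collapse onto values dim 0 and every later dim shifts down by one. A small illustration, restating `_outer_to_inner_dim` above:

```python
def _outer_to_inner_dim(ndim, dim):
    assert 0 <= dim < ndim
    return 0 if dim < 2 else dim - 1

print([_outer_to_inner_dim(4, d) for d in range(4)])   # [0, 0, 1, 2]
# So reducing over outer dims (2, 3) becomes a reduction over inner dims (1, 2),
# while reducing over (0, 1) hits the packed dim 0 and removes the raggedness.
```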
@@ -79,6 +108,30 @@
     return list(nt._size[:end]) == list(size[:end])
 
 
+def squeeze_leading_ones(t):
+    # Note: [ Squeezing leading ones ]
+    #
+    # Squeeze leading ones from t.
+    #
+    # We want:
+    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
+    #
+    # 1) Squeeze extra ones and grab values from NT
+    #   (1, 1, ?, ?) -> (?, ?)   and   (B, j0, ?, ?) -> (sum(*), ?, ?)
+    # 2) Do dense broadcasting:
+    #   (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
+    # 3) Construct nested tensor
+    #   (sum(*), ?, ?) -> (B, j0, ?, ?)
+    #
+    # If unsqueezing on the 0th dim becomes supported, we would add a step (4)
+    # that unsqueezes the result, and this function would need to record how
+    # many leading ones it squeezed away.
+    while t.shape[0] == 1:
+        t = t.squeeze(0)
+    return t
+
+
 def register_func(tables, aten_ops, schema_str):
     if not isinstance(aten_ops, list):
         aten_ops = [aten_ops]
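A small shape-only sketch of the bookkeeping described in the note above, with hypothetical sizes (`sum(j0) = 9`, trailing dims `(3, 4)`):

```python
import torch

t = torch.randn(1, 1, 3, 1)            # dense operand with leading ones
while t.shape[0] == 1:                  # squeeze_leading_ones
    t = t.squeeze(0)
print(t.shape)                          # torch.Size([3, 1])

values = torch.randn(9, 3, 4)           # packed (sum(j0), D0, D1) buffer of the NT
out_values = values + t                 # step (2): plain dense broadcasting
print(out_values.shape)                 # torch.Size([9, 3, 4])
# Step (3): re-wrapping out_values with the NT's offsets yields a (B, j0, 3, 4) NT.
```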
@@ -163,15 +216,17 @@
     # === Handle broadcasting across the batch / ragged dims ===
 
     # Easy case: take advantage of pre-existing broadcasting logic
-    # when NT dim > non-NT dim
-    # ex: (B, j0, D_0, D_1) + (D_0, D_1) -> (B, j0, D_0, D_1)
-    # ex: (B, j0, D_0, D_1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
-    # ex: (B, j0, 1, 1) + (D_0, D_1) -> (B, j0, D_0, D_1)
-    # ex: (B, j0, 1, 1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
-    if (a_is_nt and a.dim() > b.dim()) or (not a_is_nt and b.dim() > a.dim()):
-        arg1 = a._values if a_is_nt else a
-        arg2 = b._values if not a_is_nt else b
-        return NestedTensor(func(arg1, arg2, *args[2:], **kwargs), **extracted_kwargs)
+    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    nt, t = (a, b) if a_is_nt else (b, a)
+    # See Note: [ Squeezing leading ones ]
+    if t.dim() > nt.dim():
+        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
+    t_squeezed = squeeze_leading_ones(t)
+    if nt.dim() >= t_squeezed.dim() + 2:
+        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
+        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
 
     # Harder case: do manual broadcasting over unbound components
     # when NT dim == non-NT dim
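The fast path above applies whenever, after stripping leading ones, the dense operand only spans the NT's trailing dims (hence `>= t_squeezed.dim() + 2`: the NT's batch and ragged dims together occupy a single dim of `_values`). A shape-only sketch of which operands qualify for a `(B, j0, 3, 4)` NT, using a hypothetical standalone helper:

```python
def takes_values_fast_path(nt_dim, t_shape):
    # Mirrors "nt.dim() >= t_squeezed.dim() + 2" after squeezing leading ones.
    while t_shape and t_shape[0] == 1:
        t_shape = t_shape[1:]
    return nt_dim >= len(t_shape) + 2

for shape in [(4,), (1, 4), (3, 1), (1, 3, 1), (1, 1, 1, 4), (2, 3, 4)]:
    print(shape, takes_values_fast_path(4, shape))
# All but (2, 3, 4) print True; a (2, 3, 4) operand does not qualify for the
# values-based fast path and falls through to the logic below.
```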
@@ -598,6 +653,31 @@
     return args[0]._size == args[1]._size
 
 
+@register_jagged_func(
+    torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
+)
+def sum_dim_IntList(func, *args, **kwargs):
+    # sum_dim_IntList can produce an NT or a plain Tensor depending on whether
+    # the ragged dim is reduced away.
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    assert inp._ragged_idx == 1
+    new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
+        inp.dim(), new_kwargs["dim"], "sum"
+    )
+
+    if not ragged_reduced_away:
+        return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+    else:
+        # Don't wrap because we reduced away the raggedness
+        out = func(inp._values, **new_kwargs)
+        if new_kwargs["keepdim"]:
+            out = out.unsqueeze(0)
+        return out
+
+
 @register_jagged_func(torch.ops.aten.transpose.int, "self: jt, dim0: any, dim1: any")
 def transpose_int(func, *args, **kwargs):
     _, new_kwargs = normalize_function(
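Putting the new reduction rules together, a small sketch (using the internal `jagged_from_list` helper; shapes are illustrative) of when `sum` keeps the jagged structure versus returning a dense tensor:

```python
import torch
from torch.nested._internal.nested_tensor import jagged_from_list

ts = [torch.randn(2, 3, 4), torch.randn(5, 3, 4)]
nt, _ = jagged_from_list(ts, None)        # (B, j0, 3, 4)

kept = torch.sum(nt, dim=(2, 3))          # raggedness kept -> still a NestedTensor
gone = torch.sum(nt, dim=(0, 1))          # raggedness reduced away -> plain Tensor
print(type(kept).__name__, kept.values().shape)   # NestedTensor torch.Size([7])
print(type(gone).__name__, gone.shape)            # Tensor torch.Size([3, 4])

# torch.sum(nt, dim=(1, 2)) raises: reducing over the ragged dim without the
# batch dim is unsupported.
```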