Backward support for unbind() with NJT (#128032)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 50d6dee..fa33a13 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -5610,6 +5610,25 @@
         for dynamic in [False, True, None]:
             self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
 
+    @dtypes(torch.float32, torch.double, torch.half)
+    def test_unbind_backward(self, device, dtype):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 4, device=device, dtype=dtype),
+                torch.randn(5, 4, device=device, dtype=dtype),
+                torch.randn(3, 4, device=device, dtype=dtype),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        a, b, c = nt.unbind()
+        b.sum().backward()
+
+        expected_grad = torch.zeros_like(nt)
+        expected_grad.unbind()[1].add_(1.0)
+        torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)
+
 
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
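
For context, a minimal standalone sketch of the user-facing behavior the new test exercises (assuming a build that includes this change; shapes and values are illustrative):

    import torch

    # Jagged-layout (NJT) nested tensor that participates in autograd.
    nt = torch.nested.nested_tensor(
        [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)],
        layout=torch.jagged,
        requires_grad=True,
    )

    # unbind() now has a backward for the jagged layout: gradients flow
    # only into the components that were actually used.
    a, b, c = nt.unbind()
    b.sum().backward()

    # nt.grad is itself a jagged nested tensor; the second component is
    # all ones (the gradient of b.sum()), the others stay zero.
    print(nt.grad.unbind()[1])
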
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 76a7a0a..02a3e6c 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2847,7 +2847,7 @@
       self: unbind_backward(grads, dim)
       result: auto_linear
     AutogradNestedTensor:
-      self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())
+      self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())"
       result: auto_linear
 
 - name: stack(Tensor[] tensors, int dim=0) -> Tensor
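
The formula under the AutogradNestedTensor key now branches on layout: jagged-layout inputs are routed to the new unbind_backward_nested_jagged() helper, while strided nested tensors keep the existing unbind_backward_nested() path. The c10::kJagged check corresponds to torch.jagged on the Python side, as this quick (illustrative) check shows:

    import torch

    jagged = torch.nested.nested_tensor(
        [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
    )
    strided = torch.nested.nested_tensor(
        [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.strided
    )

    # The backward formula keys off this layout at runtime.
    print(jagged.layout)   # torch.jagged
    print(strided.layout)  # torch.strided
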
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 9d897c6..f51c2f0 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1014,6 +1014,23 @@
   return at::_nested_tensor_from_tensor_list(grads_tensors);
 }
 
+Tensor unbind_backward_nested_jagged(
+    const variable_list& grads,
+    const Tensor& self,
+    int64_t dim) {
+  TORCH_INTERNAL_ASSERT(
+      dim == 0, "unbind_backward_nested_jagged() only supports dim=0")
+  auto grad_nt = at::zeros_like(self);
+  auto unbound_grads = grad_nt.unbind();
+  for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
+    if (grads[i].defined()) {
+      unbound_grads[i].copy_(static_cast<Tensor>(grads[i]));
+    }
+  }
+
+  return grad_nt;
+}
+
 Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
   auto result = self;
 
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index dedff70..ecf99bd 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -244,6 +244,10 @@
     const Tensor& nt_sizes,
     int64_t dim,
     const at::TensorOptions& options);
+at::Tensor unbind_backward_nested_jagged(
+    const variable_list& grads,
+    const Tensor& self,
+    int64_t dim);
 at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
 at::Tensor unsqueeze_to(
     const at::Tensor& self,
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 6f1c47d..8458f03 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -472,6 +472,17 @@
 )(jagged_unary_pointwise)
 
 
+@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
+def zero__default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    func(inp._values)
+    return inp
+
+
 @register_jagged_func(
     torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
 )
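
The final hunk registers a jagged-layout kernel for in-place aten.zero_, which simply zeroes the NJT's underlying values buffer and returns the input nested tensor. This covers call paths where zeros_like is decomposed into empty_like followed by zero_, as presumably happens for the zeros_like call in the new backward helper. A quick usage sketch:

    import torch

    nt = torch.nested.nested_tensor(
        [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
    )

    # In-place zero_ on a jagged NJT clears the shared values buffer.
    nt.zero_()
    print(nt.values().abs().sum())  # tensor(0.)

    # zeros_like on an NJT also works; the new C++ backward helper calls
    # it to build the gradient before filling in the used components.
    grad_template = torch.zeros_like(nt)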