Backward support for unbind() with NJT (#128032)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
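
This change adds a backward formula for unbind() on jagged-layout nested tensors (NJT): derivatives.yaml now branches on the layout of self, a new unbind_backward_nested_jagged() helper scatters the output gradients into a zero-filled NJT, and torch/nested/_internal/ops.py gains an aten.zero_.default registration so that zeroing a jagged tensor in place works. A minimal usage sketch, mirroring the new test (device/dtype handling omitted):

    import torch

    nt = torch.nested.nested_tensor(
        [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)],
        layout=torch.jagged,
        requires_grad=True,
    )

    # unbind() now participates in autograd for the jagged layout.
    a, b, c = nt.unbind()
    b.sum().backward()

    # nt.grad is an NJT whose second component is all ones and whose
    # other components are zero.
    print(nt.grad)
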
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 50d6dee..fa33a13 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -5610,6 +5610,25 @@
         for dynamic in [False, True, None]:
             self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
 
+    @dtypes(torch.float32, torch.double, torch.half)
+    def test_unbind_backward(self, device, dtype):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 4, device=device),
+                torch.randn(5, 4, device=device),
+                torch.randn(3, 4, device=device),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        a, b, c = nt.unbind()
+        b.sum().backward()
+
+        expected_grad = torch.zeros_like(nt)
+        expected_grad.unbind()[1].add_(1.0)
+        torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)
+
 
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 76a7a0a..02a3e6c 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2847,7 +2847,7 @@
       self: unbind_backward(grads, dim)
       result: auto_linear
     AutogradNestedTensor:
-      self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())
+      self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())"
       result: auto_linear
 
 - name: stack(Tensor[] tensors, int dim=0) -> Tensor
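
The AutogradNestedTensor entry above now chooses the backward helper based on the layout of self: the jagged layout goes through the new unbind_backward_nested_jagged(), while strided nested tensors keep the existing unbind_backward_nested() path. Conceptually, the dispatch looks like the Python-level sketch below; the function and helper names here are placeholders standing in for the C++ calls in the YAML, not the actual autograd-generated code:

    def unbind_backward_for_nested(grads, self, dim):
        if self.layout == torch.jagged:
            # new path added by this PR
            return unbind_backward_nested_jagged(grads, self, dim)
        # pre-existing path for strided nested tensors;
        # nested_sizes()/options() stand in for get_nested_sizes()/self.options()
        return unbind_backward_nested(grads, nested_sizes(self), dim, options(self))
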
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 9d897c6..f51c2f0 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1014,6 +1014,23 @@
   return at::_nested_tensor_from_tensor_list(grads_tensors);
 }
 
+Tensor unbind_backward_nested_jagged(
+    const variable_list& grads,
+    const Tensor& self,
+    int64_t dim) {
+  TORCH_INTERNAL_ASSERT(
+      dim == 0, "unbind_backward_nested_jagged() only supports dim=0")
+  auto grad_nt = at::zeros_like(self);
+  auto unbound_grads = grad_nt.unbind();
+  for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
+    if (grads[i].defined()) {
+      unbound_grads[i].copy_(static_cast<Tensor>(grads[i]));
+    }
+  }
+
+  return grad_nt;
+}
+
 Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
   auto result = self;
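
The new helper builds the gradient for the whole NJT by zero-filling a tensor shaped like self and copying each defined output gradient into the matching unbound component; the zero_() registration later in this diff is presumably what lets at::zeros_like(self) produce a jagged NJT here. A rough Python equivalent of the C++ above (undefined gradients map to None in this sketch):

    def unbind_backward_nested_jagged_sketch(grads, self, dim):
        assert dim == 0, "unbind_backward_nested_jagged() only supports dim=0"
        grad_nt = torch.zeros_like(self)   # at::zeros_like(self) in the C++
        unbound = grad_nt.unbind()
        for i, g in enumerate(grads):
            if g is not None:              # skip components whose grad is undefined
                unbound[i].copy_(g)
        return grad_nt
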
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index dedff70..ecf99bd 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -244,6 +244,10 @@
     const Tensor& nt_sizes,
     int64_t dim,
     const at::TensorOptions& options);
+at::Tensor unbind_backward_nested_jagged(
+    const variable_list& grads,
+    const Tensor& self,
+    int64_t dim);
 at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
 at::Tensor unsqueeze_to(
     const at::Tensor& self,
diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py
index 6f1c47d..8458f03 100644
--- a/torch/nested/_internal/ops.py
+++ b/torch/nested/_internal/ops.py
@@ -472,6 +472,17 @@
 )(jagged_unary_pointwise)
 
 
+@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
+def zero__default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    func(inp._values)
+    return inp
+
+
 @register_jagged_func(
     torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
 )
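
The aten.zero_.default registration zeroes a jagged tensor in place by zeroing its underlying values buffer, which is presumably what allows the at::zeros_like(self) call in the new backward to produce a zero-filled NJT. A small illustration of the newly supported in-place op:

    import torch

    nt = torch.nested.nested_tensor(
        [torch.randn(2, 4), torch.randn(5, 4)],
        layout=torch.jagged,
    )

    nt.zero_()                       # dispatches to zero__default above
    print(nt.values().abs().sum())   # tensor(0.)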