Add forward and backward support for silu to NestedTensors (#97181)
# Summary
Add forward and backward support for silu to NestedTensors
- Add forward support for silu
- Add forward support for silu_
- Add backward support for silu
- Add to NT docs
- Add tests
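
A minimal usage sketch of what this enables, mirroring the added tests (shapes and variable names below are illustrative, not part of the change):

```python
import torch

# Forward: silu on a nested tensor now dispatches to NestedTensor_silu.
a = torch.randn(1, 2, 4, dtype=torch.float64, requires_grad=True)
b = torch.randn(2, 2, 4, dtype=torch.float64, requires_grad=True)
c = torch.randn(3, 2, 4, dtype=torch.float64, requires_grad=True)
nt = torch.nested.as_nested_tensor([a, b, c])
out = torch.nn.functional.silu(nt)

# Backward: gradients flow back through silu_backward_nested.
torch.nested.to_padded_tensor(out, 0.0).sum().backward()

# In-place variant (forward only), shown on a nested tensor that does not require grad.
nt_plain = torch.nested.nested_tensor([a.detach(), b.detach(), c.detach()])
torch.nn.functional.silu(nt_plain, inplace=True)
```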
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97181
Approved by: https://github.com/cpuhrsch, https://github.com/jbschlosser
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a835b46..7810d0f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4859,10 +4859,14 @@
 - func: silu(Tensor self) -> Tensor
   structured_delegate: silu.out
   python_module: nn
+  dispatch:
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu

 - func: silu_(Tensor(a!) self) -> Tensor(a!)
   structured_delegate: silu.out
   python_module: nn
+  dispatch:
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu_

 - func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
@@ -4885,6 +4889,7 @@
   python_module: nn
   dispatch:
     CompositeImplicitAutograd: math_silu_backward
+    NestedTensorCPU, NestedTensorCUDA: silu_backward_nested

 - func: mish(Tensor self) -> Tensor
   structured_delegate: mish.out
diff --git a/aten/src/ATen/native/nested/NestedTensorBackward.cpp b/aten/src/ATen/native/nested/NestedTensorBackward.cpp
index 578ac4d..b748e96 100644
--- a/aten/src/ATen/native/nested/NestedTensorBackward.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorBackward.cpp
@@ -185,6 +185,12 @@
   return map_nt_binary(grad_output, input, partial_relu_backward);
 }

+// silu_backward(Tensor grad_output, Tensor self) -> Tensor
+Tensor silu_backward_nested(const Tensor& grad_output, const Tensor& self){
+  auto partial_silu_backward = [](auto && PH1, auto && PH2) { return at::silu_backward(std::forward<decltype(PH1)>(PH1), std::forward<decltype(PH2)>(PH2)); };
+  return map_nt_binary(grad_output, self, partial_silu_backward);
+}
+
 std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
     const Tensor& grad,
     const Tensor& input,
diff --git a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
index 8c9cc13..39fee59 100644
--- a/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp
@@ -76,5 +76,17 @@
   return self;
 }

+Tensor NestedTensor_silu(const Tensor& self){
+  return map_nt(self, at::silu);
+}
+
+Tensor& NestedTensor_silu_(Tensor& self){
+  auto self_ptr = get_nested_tensor_impl(self);
+  check_numel_equals_buffer_size(self_ptr);
+  auto buffer = self_ptr->get_buffer();
+  at::silu_(buffer);
+  return self;
+}
+
 } // namespace native
 } // namespace at
diff --git a/docs/source/nested.rst b/docs/source/nested.rst
index 5fe8df8..19ba636 100644
--- a/docs/source/nested.rst
+++ b/docs/source/nested.rst
@@ -196,6 +196,7 @@
:func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
:func:`torch.relu`; "Behavior is the same as on regular tensors."
:func:`torch.gelu`; "Behavior is the same as on regular tensors."
+ :func:`torch.silu`; "Behavior is the same as on regular tensors."
:func:`torch.neg`; "Behavior is the same as on regular tensors."
:func:`torch.add`; "Supports elementwise addition of two nested tensors.
Supports addition of a scalar to a nested tensor."
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index fac0477..dff40e1 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -1,6 +1,7 @@
 # Owner(s): ["module: nestedtensor"]

 import unittest
+from functools import partial

 import numpy as np
 import torch
@@ -845,7 +846,9 @@
subtest(torch._C._nn.gelu_, name='gelu_'),
subtest(torch.tanh, name='tanh'),
subtest(torch.tanh_, name='tanh_'),
- subtest(torch.neg, name='neg')])
+ subtest(torch.neg, name='neg'),
+ subtest(torch.nn.functional.silu, name='silu'),
+ subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'), ])
def test_activations(self, device, func):
nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32)
nested_result = func(nt)
@@ -2401,6 +2404,19 @@
         data = (a, b, c)
         assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

+    def test_silu_backward(self, device):
+        a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c):
+            nt = torch.nested.as_nested_tensor([a, b, c])
+            nt_silu = torch.nn.functional.silu(nt)
+            return torch.nested.to_padded_tensor(nt_silu, 0)
+
+        data = (a, b, c)
+        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
+
# Previously would error when input NT doesn't require grad
# NotImplementedError: Cannot access storage of UndefinedTensorImpl
def test_layer_norm_backward_edge_case(self, device):