Implement Softsign double backwards.
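
Port Softsign from the old-style Function (which stashed buffers on `self`) to a new-style
static Function that saves its input via `ctx.save_for_backward`. Because backward() is now
expressed with differentiable Variable ops (abs/add/mul/div), autograd can differentiate it
again: for f(x) = x / (1 + |x|) the backward returns grad_output / (1 + |x|)^2, and the
double-backward pass derives -2 * sign(x) / (1 + |x|)^3 from that graph automatically. This
also lets the check_gradgrad=False exemption be dropped from the module tests.

A minimal sanity check (not part of this patch), assuming gradcheck and gradgradcheck are
importable from torch.autograd as in the test utilities:

    import torch
    from torch.autograd import Variable, gradcheck, gradgradcheck
    import torch.nn.functional as F

    x = Variable(torch.randn(3, 2, 5).double(), requires_grad=True)
    # First derivative: 1 / (1 + |x|)^2
    assert gradcheck(F.softsign, (x,))
    # Second derivative now flows through the Variable-based backward()
    assert gradgradcheck(F.softsign, (x,))
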
diff --git a/test/common_nn.py b/test/common_nn.py
index 99647ec..857d49a 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -226,7 +226,6 @@
         module_name='Softsign',
         input_size=(3, 2, 5),
         reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
-        check_gradgrad=False,
     ),
     dict(
         module_name='Softmin',
diff --git a/torch/nn/_functions/activation.py b/torch/nn/_functions/activation.py
index 433e8ee..b99e199 100644
--- a/torch/nn/_functions/activation.py
+++ b/torch/nn/_functions/activation.py
@@ -3,15 +3,15 @@
 
 class Softsign(Function):
 
-    def forward(self, input):
-        self.buffer = input.clone().abs_().add_(1)
-        self.buffer_squared = False
-        output = input.clone().div_(self.buffer)
-        return output
+    @staticmethod
+    def forward(ctx, input):
+        ctx.save_for_backward(input)
+        buffer = input.clone().abs_().add_(1)
+        return input.clone().div_(buffer)
 
-    def backward(self, grad_output):
-        if not self.buffer_squared:
-            self.buffer.mul_(self.buffer)
-            self.buffer_squared = True
-        grad_input = grad_output.clone().div_(self.buffer)
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_variables
+        buffer = input.abs().add_(1)
+        grad_input = grad_output.div(buffer.mul(buffer))
         return grad_input
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 6e83c90..9acf75f 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -515,7 +515,7 @@
 
 
 def softsign(input):
-    return _functions.activation.Softsign()(input)
+    return _functions.activation.Softsign.apply(input)
 
 
 def softplus(input, beta=1, threshold=20):
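
Note: the functional entry point now goes through the new-style Softsign.apply instead of
instantiating the Function. A hedged end-to-end sketch of what the change enables, assuming
torch.autograd.grad with the create_graph flag is available in this build (not part of the
patch):

    import torch
    from torch import autograd
    from torch.autograd import Variable
    import torch.nn.functional as F

    x = Variable(torch.randn(5).double(), requires_grad=True)
    y = F.softsign(x).sum()
    # grad of f(x) = x / (1 + |x|) is 1 / (1 + |x|)^2
    g, = autograd.grad(y, x, create_graph=True)
    # differentiating the backward graph gives -2 * sign(x) / (1 + |x|)^3
    gg, = autograd.grad(g.sum(), x)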