from torch.autograd.function import Function


class Softsign(Function):
    # Legacy-style (pre-0.4) autograd Function: forward/backward are
    # instance methods and intermediate state is stashed on self.

    def forward(self, input):
        # Save the denominator of softsign(x) = x / (1 + |x|) for backward.
        self.buffer = input.clone().abs_().add_(1)
        self.buffer_squared = False
        output = input.clone().div_(self.buffer)
        return output

    def backward(self, grad_output):
        # d/dx softsign(x) = 1 / (1 + |x|)^2: square the saved denominator
        # in place once, then reuse it if backward is called again.
        if not self.buffer_squared:
            self.buffer.mul_(self.buffer)
            self.buffer_squared = True
        grad_input = grad_output.clone().div_(self.buffer)
        return grad_input
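

# --- Usage sketch (an addition, not part of the original file) ---
# A minimal check of the math above, calling forward/backward directly as
# plain methods on an example input. Note that instantiating a Function with
# non-static forward/backward is the legacy (pre-0.4) autograd API; recent
# PyTorch versions warn on instantiation and reject it inside autograd.
if __name__ == "__main__":
    import torch

    fn = Softsign()
    x = torch.tensor([-2.0, 0.0, 3.0])
    y = fn.forward(x)                    # x / (1 + |x|)
    g = fn.backward(torch.ones_like(x))  # 1 / (1 + |x|)**2
    print(y)  # approx. tensor([-0.6667, 0.0000, 0.7500])
    print(g)  # approx. tensor([0.1111, 1.0000, 0.0625])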