blob: 4caaf0de824a0fc43de4b8d16eed4b9fa829afc6 [file] [log] [blame]
from torch.autograd.function import Function
class Softsign(Function):
    """Legacy (instance-style) autograd Function for softsign: x / (1 + |x|).

    State is stashed on ``self`` between ``forward`` and ``backward``, as the
    old pre-``ctx`` autograd API did.
    """

    def forward(self, input):
        """Return ``input / (1 + |input|)`` without modifying ``input``.

        Caches the denominator ``1 + |x|`` on ``self`` for reuse in
        ``backward``.
        """
        # abs() allocates a fresh tensor, so the caller's input is untouched.
        self.buffer = input.abs().add_(1)
        # Tracks whether the cached denominator has been squared yet.
        self.buffer_squared = False
        return input.div(self.buffer)

    def backward(self, grad_output):
        """Return ``grad_output / (1 + |x|)**2``.

        The cached denominator is squared in place exactly once, so calling
        ``backward`` repeatedly keeps producing the same (correct) gradient.
        """
        if not self.buffer_squared:
            self.buffer_squared = True
            self.buffer.pow_(2)
        return grad_output.div(self.buffer)