import torch
from .Module import Module
from .utils import clear


class SoftSign(Module):
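    """SoftSign activation, applied element-wise: output = input / (1 + |input|)."""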

    def __init__(self):
        super(SoftSign, self).__init__()
        # Buffers allocated lazily on first use and reused between calls.
        self.temp = None
        self.tempgrad = None

    def updateOutput(self, input):
        if self.temp is None:
            self.temp = input.new()
        # temp = 1 + |input|
        self.temp.resize_as_(input).copy_(input).abs_().add_(1)
        # output = input / (1 + |input|)
        self.output.resize_as_(input).copy_(input).div_(self.temp)
        return self.output

    def updateGradInput(self, input, gradOutput):
        if self.tempgrad is None:
            self.tempgrad = input.new()
        # tempgrad = (1 + |input|)^2: build 1 + |input|, then square the buffer
        # in place by multiplying it with itself.
        self.tempgrad.resize_as_(self.output).copy_(input).abs_().add_(1).mul_(self.tempgrad)
        # gradInput = gradOutput / (1 + |input|)^2, the SoftSign derivative.
        self.gradInput.resize_as_(input).copy_(gradOutput).div_(self.tempgrad)
        return self.gradInput

    def clearState(self):
        # Release the cached buffers so they can be garbage-collected.
        clear(self, 'temp', 'tempgrad')
        return super(SoftSign, self).clearState()
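
# A minimal usage sketch, kept as comments since this module's relative
# imports only resolve inside the torch.legacy.nn package; the import path
# below is an assumption based on that layout:
#
#     import torch
#     from torch.legacy.nn import SoftSign
#
#     m = SoftSign()
#     x = torch.randn(3, 4)
#     y = m.updateOutput(x)                        # x / (1 + |x|)
#     g = m.updateGradInput(x, torch.ones(3, 4))   # dy/dx = 1 / (1 + |x|)^2
#     m.clearState()                               # drop the cached buffers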