blob: 36ef0f1689b11ab552895dffc0fb516c49ea610d [file] [log] [blame]
import torch
from .Module import Module
from .Tanh import Tanh
class TanhShrink(Module):
    """Element-wise tanh-shrink layer: ``output = input - tanh(input)``.

    The layer wraps a :class:`Tanh` sub-module (assumed to compute element-wise
    ``tanh``; its implementation is not shown here — confirm against Tanh).
    - Forward pass: ``input`` minus the Tanh output.
    - Backward pass: ``gradOutput`` minus the Tanh module's input gradient.
      Under the assumption above, this equals ``gradOutput * tanh(input) ** 2``.
    """

    def __init__(self):
        super(TanhShrink, self).__init__()
        # Created once and shared by updateOutput / updateGradInput.
        self.tanh = Tanh()

    def updateOutput(self, input):
        """Compute ``input - Tanh(input)`` into ``self.output``.

        Args:
            input: tensor to transform.

        Returns:
            ``self.output``, resized to the shape of ``input``.
        """
        th = self.tanh.updateOutput(input)
        self.output.resize_as_(input).copy_(input)
        # ``sub_(t)`` replaces the deprecated ``add_(-1, t)`` overload.
        self.output.sub_(th)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute the gradient w.r.t. ``input`` into ``self.gradInput``.

        The result is ``gradOutput - dth``, where ``dth`` is the input
        gradient produced by the wrapped Tanh module for the same
        ``gradOutput``.

        NOTE(review): Tanh.updateGradInput presumably relies on state cached
        by a preceding updateOutput call, so updateOutput should run first.
        Confirm against Tanh.

        Args:
            input: the tensor that was passed to updateOutput.
            gradOutput: gradient w.r.t. this layer's output.

        Returns:
            ``self.gradInput``, resized to the shape of ``input``.
        """
        dth = self.tanh.updateGradInput(input, gradOutput)
        self.gradInput.resize_as_(input).copy_(gradOutput)
        # ``sub_(t)`` replaces the deprecated ``add_(-1, t)`` overload.
        self.gradInput.sub_(dth)
        return self.gradInput