blob: b8f46d030a4337dfd5c935f7083a15452777a9cf [file] [log] [blame]
import torch
from .Module import Module
class SoftPlus(Module):
    """SoftPlus transfer function: ``f(x) = 1/beta * log(1 + exp(beta * x))``.

    Both the forward and the backward computation are delegated to the
    backend kernels ``SoftPlus_updateOutput`` and
    ``SoftPlus_updateGradInput``; this class only stores the two
    hyper-parameters and forwards them together with the tensors.

    Args:
        beta: controls the sharpness of the transfer function (default 1).
        threshold: forwarded to the backend kernels to avoid floating point
            issues with ``exp(x)`` for large ``x`` (default 20, the value
            that used to be hard-coded).
            NOTE(review): exactly how the kernel treats values above the
            threshold is not visible here -- confirm against the backend.
    """

    def __init__(self, beta=1, threshold=20):
        super(SoftPlus, self).__init__()
        # NOTE(review): self.output, self.gradInput and self._backend are not
        # created here; presumably the Module base class provides them.
        self.beta = beta  # Beta controls sharpness of transfer function
        # Avoid floating point issues with exp(x), x > threshold
        self.threshold = threshold

    def updateOutput(self, input):
        """Run the forward pass through the backend kernel.

        Args:
            input: value handed unchanged to ``SoftPlus_updateOutput``.

        Returns:
            ``self.output``, which is passed to the backend call as the
            destination argument.
        """
        # f(x) = 1/beta * log(1 + exp(beta * x))
        self._backend.SoftPlus_updateOutput(
            self._backend.library_state,
            input,
            self.output,
            self.beta,
            self.threshold
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Run the backward pass through the backend kernel.

        Args:
            input: the forward input, handed unchanged to the backend.
            gradOutput: gradient with respect to this module's output.

        Returns:
            ``self.gradInput``, which is passed to the backend call as the
            destination argument.
        """
        # d/dx[log(1+exp(k*x))/k] = exp(kx) / (exp(kx) + 1)
        # SINCE
        # y = (1/k)*log(1+exp(k*x))  =>  x = (1/k)*log(exp(k*y)-1)
        # THEREFORE:
        # d/dx(f(x)) = (exp(k*y) - 1) / exp(k*y)
        #
        # The derivative can therefore be written in terms of the forward
        # output y, which is presumably why self.output is passed to the
        # backend alongside the input.
        self._backend.SoftPlus_updateGradInput(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.output,
            self.beta,
            self.threshold
        )
        return self.gradInput