blob: c3175946ca373576b7583e1934ba394638206cd7 [file] [log] [blame]
import torch
from .Module import Module
class LeakyReLU(Module):
    """Leaky ReLU layer whose computation is delegated to the backend.

    The module stores two settings and forwards them, together with the
    input/output buffers, to the backend's ``LeakyReLU_updateOutput`` and
    ``LeakyReLU_updateGradInput`` routines.

    Args:
        negval: Coefficient handed to the backend (presumably the slope
            applied to negative inputs -- confirm against the backend
            implementation). Defaults to ``0.01``. If a ``bool`` is passed
            here it is interpreted as the ``inplace`` flag and ``negval``
            falls back to its default.
        inplace: Flag handed to the backend requesting in-place operation.
            It is forced to ``False`` whenever ``negval`` is negative.
    """

    def __init__(self, negval=0.01, inplace=False):
        super(LeakyReLU, self).__init__()
        if isinstance(negval, bool):
            # Calling LeakyReLU(True) means "default slope, in-place".
            inplace = negval
            # Written as a float literal: ``1/100`` is integer division
            # (== 0) on Python 2 without ``from __future__ import division``.
            negval = 0.01
        self.negval = negval
        # default for inplace is False
        self.inplace = inplace
        if self.negval < 0:
            if self.inplace:
                # Local import keeps this block self-contained.
                import warnings
                warnings.warn(
                    "LeakyReLU: in-place mode is not supported with a "
                    "negative negval; falling back to inplace=False",
                    stacklevel=2
                )
            self.inplace = False

    def updateOutput(self, input):
        """Run the backend forward routine and return ``self.output``.

        ``self.output`` is not created in this class; it is presumably
        provided by ``Module`` -- TODO confirm.
        """
        self._backend.LeakyReLU_updateOutput(
            self._backend.library_state,
            input,
            self.output,
            self.negval,
            self.inplace
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Run the backend backward routine and return ``self.gradInput``.

        ``self.gradInput`` is not created in this class; it is presumably
        provided by ``Module`` -- TODO confirm.
        """
        self._backend.LeakyReLU_updateGradInput(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.negval,
            self.inplace
        )
        return self.gradInput

    def __repr__(self):
        """Return the class's ``str(type)`` followed by ``negval`` to 4 decimals."""
        return str(type(self)) + '({:.4f})'.format(self.negval)