# blob: b8bae62f9bc739ce7fd71c955381d9996ded0bf8 [file] [log] [blame]
import torch
from .Module import Module
class HardTanh(Module):
    """Layer that delegates to the backend ``HardTanh_*`` routines.

    The constructor stores a lower bound, an upper bound and an in-place
    flag; both passes hand those three values, unchanged, to the backend.

    NOTE(review): ``self._backend``, ``self.output`` and ``self.gradInput``
    are not created in this class; they are presumably provided by
    ``Module`` -- confirm against the base class.
    """

    def __init__(self, min_value=-1, max_value=1, inplace=False):
        """Store the bounds and the in-place flag.

        Args:
            min_value: lower bound, kept as ``self.min_val``.
            max_value: upper bound, kept as ``self.max_val``; must be
                strictly greater than ``min_value``.
            inplace: flag forwarded verbatim to the backend calls.

        Raises:
            AssertionError: if ``max_value`` is not greater than
                ``min_value``.
        """
        super(HardTanh, self).__init__()
        self.min_val = min_value
        self.max_val = max_value
        self.inplace = inplace
        # Raised explicitly instead of via an ``assert`` statement: asserts
        # are stripped under ``python -O``, which would let an invalid range
        # reach the backend. The exception type is kept for existing callers.
        if not self.max_val > self.min_val:
            raise AssertionError('max_value must be larger than min_value')

    def updateOutput(self, input):
        """Run the backend forward routine and return ``self.output``.

        Args:
            input: value passed through to ``HardTanh_updateOutput``.

        Returns:
            ``self.output``, which the backend call receives as its output
            argument.
        """
        self._backend.HardTanh_updateOutput(
            self._backend.library_state,
            input,
            self.output,
            self.min_val,
            self.max_val,
            self.inplace
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Run the backend backward routine and return ``self.gradInput``.

        Args:
            input: value passed through to ``HardTanh_updateGradInput``.
            gradOutput: gradient argument passed through to the backend.

        Returns:
            ``self.gradInput``, which the backend call receives as its
            gradient-input argument.
        """
        self._backend.HardTanh_updateGradInput(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.min_val,
            self.max_val,
            self.inplace
        )
        return self.gradInput