blob: 6652ffbaac45e70a99411fee8ae568b30503b832 [file] [log] [blame]
import torch
from .Module import Module
class MulConstant(Module):
    """Multiply the input by a fixed scalar.

    The forward pass computes ``output = input * constant_scalar``. The
    backward pass computes ``gradInput = gradOutput * constant_scalar``.

    Args:
        constant_scalar: scalar factor applied in both passes.
        inplace: if True, ``updateOutput`` scales ``input`` in place, and
            ``output`` is made to share its storage. ``updateGradInput`` then
            scales ``gradOutput`` in place, and ``gradInput`` shares that
            storage. Afterwards ``input`` is divided by the constant to
            restore its pre-forward values (up to floating-point rounding).

    Raises:
        ValueError: if ``inplace`` is True and ``constant_scalar`` is zero.
            The in-place forward would zero the caller's input irrecoverably,
            and the backward "restore" would divide by zero, leaving NaN/inf
            in the caller's tensor.
    """

    def __init__(self, constant_scalar, inplace=False):
        super(MulConstant, self).__init__()
        # An in-place multiply by zero cannot be undone by the division in
        # updateGradInput, so reject that combination up front instead of
        # silently corrupting the caller's input later.
        if inplace and constant_scalar == 0:
            raise ValueError(
                "MulConstant: inplace=True requires a non-zero constant_scalar "
                "(the input could not be restored in updateGradInput)")
        self.constant_scalar = constant_scalar
        self.inplace = inplace

    def updateOutput(self, input):
        """Compute ``input * constant_scalar`` into ``self.output`` and return it.

        In in-place mode, ``input`` itself is modified, and ``self.output``
        shares its storage.
        """
        if self.inplace:
            input.mul_(self.constant_scalar)
            self.output.set_(input)
        else:
            # Out-of-place: size the output buffer like the input, copy the
            # input into it, then scale that copy.
            self.output.resize_as_(input)
            self.output.copy_(input)
            self.output.mul_(self.constant_scalar)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute ``gradOutput * constant_scalar`` into ``self.gradInput``.

        Returns ``self.gradInput``. If ``self.gradInput`` is None, it returns
        None and does nothing, including not restoring ``input`` in in-place
        mode. In in-place mode, ``gradOutput`` is scaled in place, and
        ``self.gradInput`` shares its storage.
        """
        if self.gradInput is None:
            return
        if self.inplace:
            gradOutput.mul_(self.constant_scalar)
            self.gradInput.set_(gradOutput)
            # Undo the in-place scaling applied to `input` in updateOutput.
            # __init__ guarantees the constant is non-zero in this mode.
            input.div_(self.constant_scalar)
        else:
            self.gradInput.resize_as_(gradOutput)
            self.gradInput.copy_(gradOutput)
            self.gradInput.mul_(self.constant_scalar)
        return self.gradInput