blob: f6eed8abdeca304de7df213e556b1bb7fe5bd59c [file] [log] [blame]
import torch
from .Module import Module
class AddConstant(Module):
    """Add a fixed constant to every element of the input.

    Args:
        constant_scalar: value passed to ``add_`` on the forward pass and
            subtracted back out (in in-place mode) on the backward pass.
        inplace: if true, ``input`` itself is modified and ``self.output``
            is made to share its storage (via ``set_``); otherwise the result
            is written into the separate ``self.output`` buffer.
    """

    def __init__(self, constant_scalar, inplace=False):
        super(AddConstant, self).__init__()
        self.constant_scalar = constant_scalar
        self.inplace = inplace

    def updateOutput(self, input):
        """Compute ``input + constant_scalar`` and return ``self.output``."""
        shift = self.constant_scalar
        if not self.inplace:
            # Out-of-place: size the output buffer like the input, copy,
            # then shift the copy so ``input`` is left untouched.
            result = self.output
            result.resizeAs_(input)
            result.copy_(input)
            result.add_(shift)
            return result

        # In-place: mutate ``input`` and make the output alias it.
        input.add_(shift)
        self.output.set_(input)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Return the input gradient, which equals ``gradOutput``.

        In in-place mode this also undoes the addition made by
        ``updateOutput`` so that ``input`` is shifted back by the constant.
        """
        grad = self.gradInput
        if not self.inplace:
            grad.resizeAs_(gradOutput)
            grad.copy_(gradOutput)
            return grad

        grad.set_(gradOutput)
        # Restore the caller's input after the in-place forward shift. For
        # floating-point tensors, add-then-subtract may not reproduce the
        # original values bit-for-bit.
        input.add_(-self.constant_scalar)
        return grad