import torch
from .Module import Module
from .utils import clear
class Sum(Module):
    """Sums the input over `dimension`; if `sizeAverage` is True, divides the
    result by the size of that dimension (i.e. computes the mean)."""

    def __init__(self, dimension=0, sizeAverage=False):
        super(Sum, self).__init__()
        self.dimension = dimension
        self.sizeAverage = sizeAverage
        self._gradOutput = None

    def _getPositiveDimension(self, input):
        # Resolve negative dimensions (e.g. -1) to their positive index.
        dimension = self.dimension
        if dimension < 0:
            dimension = input.dim() + dimension
        return dimension

    def updateOutput(self, input):
        dimension = self._getPositiveDimension(input)
        torch.sum(input, dimension, out=self.output, keepdim=True)
        if self.sizeAverage:
            self.output.div_(input.size(dimension))
        if self.output.dim() > 1:
            # Squeeze out the reduced dimension kept by keepdim=True.
            self.output.set_(self.output.select(dimension, 0))
        return self.output

    def updateGradInput(self, input, gradOutput):
        dimension = self._getPositiveDimension(input)
        # Zero strides don't work with MKL/BLAS, so don't set self.gradInput
        # to a zero-stride tensor. Instead, do a deep copy.
        size = list(input.size())
        size[dimension] = 1
        if not gradOutput.is_contiguous():
            if self._gradOutput is None:
                self._gradOutput = gradOutput.new()
            self._gradOutput.resize_as_(gradOutput).copy_(gradOutput)
            gradOutput = self._gradOutput

        gradOutput = gradOutput.view(*size)
        self.gradInput.resize_as_(input)
        self.gradInput.copy_(gradOutput.expand_as(input))
        if self.sizeAverage:
            self.gradInput.div_(input.size(dimension))

        return self.gradInput

    def clearState(self):
        clear(self, '_gradOutput')
        return super(Sum, self).clearState()
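
# Usage sketch (illustrative only, not part of this module): this shows the
# legacy Module-style forward/backward calls, assuming a PyTorch release that
# still ships torch.legacy.nn.
#
#   from torch.legacy.nn import Sum
#
#   m = Sum(dimension=1, sizeAverage=True)   # mean over dim 1
#   x = torch.randn(4, 5)
#   out = m.forward(x)                       # shape: (4,)
#   grad = m.backward(x, torch.ones(4))      # shape: (4, 5), every entry 1 / 5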