blob: 23d5ad9b8e7043f7eb3bb48be8a9d23650db3525 [file] [log] [blame]
import torch
from .Module import Module
from .utils import clear
class Reshape(Module):
    """Module that views each input as ``(input.size(0),) + size``.

    The target per-sample size is fixed at construction time. The first
    dimension of the input is carried over unchanged and the remaining
    elements are viewed with the configured size.
    """

    def __init__(self, *args):
        """Store the target size.

        Args:
            *args: either a single ``torch.Size`` instance, or the target
                dimensions given as separate positional arguments
                (e.g. ``Reshape(2, 3)``).
        """
        super(Reshape, self).__init__()

        # A lone torch.Size argument is used as-is. The length must be
        # checked against 1 (not 0) so that args[0] is only inspected when
        # it actually exists.
        if len(args) == 1 and isinstance(args[0], torch.Size):
            self.size = args[0]
        else:
            self.size = torch.Size(args)

        # Product of the target dimensions.
        self.nelement = 1
        for s in self.size:
            self.nelement *= s

        # Scratch buffers holding contiguous copies of non-contiguous
        # tensors; allocated lazily on first use.
        self._input = None
        self._gradOutput = None

    def updateOutput(self, input):
        """Return ``input`` viewed as ``[input.size(0)] + list(self.size)``.

        A non-contiguous input is first copied into an internal buffer,
        and the view is taken on that copy.
        """
        if not input.is_contiguous():
            if self._input is None:
                self._input = input.new()
            self._input.resize_as_(input)
            self._input.copy_(input)
            input = self._input

        batchsize = [input.size(0)] + list(self.size)
        self.output = input.view(torch.Size(batchsize))
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Return ``gradOutput`` viewed with the same size as ``input``.

        A non-contiguous ``gradOutput`` is first copied into an internal
        buffer, and the view is taken on that copy.
        """
        if not gradOutput.is_contiguous():
            if self._gradOutput is None:
                self._gradOutput = gradOutput.new()
            self._gradOutput.resize_as_(gradOutput)
            self._gradOutput.copy_(gradOutput)
            gradOutput = self._gradOutput

        self.gradInput = gradOutput.view_as(input)
        return self.gradInput

    def __repr__(self):
        """Return the parent repr followed by the target size, e.g. ``(2x3)``."""
        return super(Reshape, self).__repr__() + \
            '({})'.format('x'.join(map(str, self.size)))

    def clearState(self):
        """Clear the scratch buffers, then defer to the parent ``clearState``."""
        # `clear` comes from .utils; presumably it resets the named
        # attributes -- confirm against its definition.
        clear(self, '_input', '_gradOutput')
        return super(Reshape, self).clearState()