import torch
from .Module import Module
from .utils import clear


class Reshape(Module):
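    """Reshapes each sample of a batched input to ``size``.

    The first dimension of the input is treated as the batch dimension
    and is preserved; the remaining elements of each sample are viewed
    with the target shape. A minimal usage sketch (the sizes here are
    illustrative):

        m = Reshape(2, 2)
        out = m.updateOutput(torch.randn(8, 4))  # shape: (8, 2, 2)
    """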

    def __init__(self, *args):
        super(Reshape, self).__init__()

        # Accept either a single torch.Size or individual dimension ints.
        if len(args) == 1 and isinstance(args[0], torch.Size):
            self.size = args[0]
        else:
            self.size = torch.Size(args)
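
        # Number of elements per sample implied by the target size.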
        self.nelement = 1
        for s in self.size:
            self.nelement *= s
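
        # Lazily-allocated buffers for making non-contiguous tensors
        # contiguous before calling view().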
        self._input = None
        self._gradOutput = None

    def updateOutput(self, input):
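        # view() requires contiguous memory, so copy a non-contiguous
        # input into a cached buffer first.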
        if not input.is_contiguous():
            if self._input is None:
                self._input = input.new()
            self._input.resize_as_(input)
            self._input.copy_(input)
            input = self._input
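
        # Keep the batch dimension and view the rest as self.size.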
        batchsize = [input.size(0)] + list(self.size)
        self.output = input.view(torch.Size(batchsize))

        return self.output

    def updateGradInput(self, input, gradOutput):
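        # view_as() also needs contiguous memory; copy the incoming
        # gradient into a cached buffer if necessary.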
        if not gradOutput.is_contiguous():
            if self._gradOutput is None:
                self._gradOutput = gradOutput.new()
            self._gradOutput.resize_as_(gradOutput)
            self._gradOutput.copy_(gradOutput)
            gradOutput = self._gradOutput
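
        # Reshaping is linear, so the gradient of the input is the
        # incoming gradient viewed with the input's original shape.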
        self.gradInput = gradOutput.view_as(input)
        return self.gradInput

    def __repr__(self):
        return super(Reshape, self).__repr__() + \
            '({})'.format('x'.join(map(str, self.size)))

    def clearState(self):
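        # Free the cached contiguity buffers.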
        clear(self, '_input', '_gradOutput')
        return super(Reshape, self).clearState()