import torch
from .Container import Container
class Concat(Container):
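    """Applies each sub-module to the same input and concatenates their
    outputs along `dimension` into a single output tensor.
    """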

    def __init__(self, dimension):
        super(Concat, self).__init__()
        self.outputSize = torch.Size()
        self.dimension = dimension

    def updateOutput(self, input):
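        # Two passes: first run every sub-module on the full input and
        # accumulate the total size along self.dimension, then copy each
        # result into the matching narrow()-ed slice of self.output.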
        outs = []
        for i in range(len(self.modules)):
            currentOutput = self.modules[i].updateOutput(input)
            outs.append(currentOutput)
            if i == 0:
                size = list(currentOutput.size())
            else:
                size[self.dimension] += currentOutput.size(self.dimension)

        self.outputSize = torch.Size(size)
        self.output.resize_(self.outputSize)

        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = outs[i]
            self.output.narrow(self.dimension, offset, currentOutput.size(self.dimension)).copy_(currentOutput)
            offset = offset + currentOutput.size(self.dimension)

        return self.output

    def updateGradInput(self, input, gradOutput):
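        # Each sub-module receives the slice of gradOutput that corresponds
        # to its own output; the resulting gradInputs all have the shape of
        # the shared input and are summed into self.gradInput.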
        self.gradInput.resize_as_(input)

        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            currentGradInput = module.updateGradInput(input, gradOutput.narrow(
                self.dimension, offset, currentOutput.size(self.dimension)))

            # If the module does not produce a gradInput (for example the first layer), ignore it and move on.
            if currentGradInput is not None:
                if i == 0:
                    self.gradInput.copy_(currentGradInput)
                else:
                    self.gradInput.add_(currentGradInput)
            offset = offset + currentOutput.size(self.dimension)

        return self.gradInput

    def accGradParameters(self, input, gradOutput, scale=1):
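        # Accumulate parameter gradients per sub-module, again routing each
        # one its matching slice of gradOutput.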
        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            module.accGradParameters(
                input,
                gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
                scale)
            offset = offset + currentOutput.size(self.dimension)

    def backward(self, input, gradOutput, scale=1):
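        # Fused backward: same gradOutput slicing as updateGradInput, but
        # each sub-module also accumulates its parameter gradients (scaled
        # by `scale`) in the same call.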
        self.gradInput.resize_as_(input)

        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            currentGradInput = module.backward(input, gradOutput.narrow(
                self.dimension, offset, currentOutput.size(self.dimension)), scale)

            # If the module does not produce a gradInput (for example the first layer), ignore it and move on.
            if currentGradInput is not None:
                if i == 0:
                    self.gradInput.copy_(currentGradInput)
                else:
                    self.gradInput.add_(currentGradInput)
            offset = offset + currentOutput.size(self.dimension)

        return self.gradInput

    def accUpdateGradParameters(self, input, gradOutput, lr):
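        # Variant of accGradParameters that applies the update with learning
        # rate `lr` directly instead of accumulating gradients.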
        offset = 0
        for i, module in enumerate(self.modules):
            currentOutput = module.output
            module.accUpdateGradParameters(
                input,
                gradOutput.narrow(self.dimension, offset, currentOutput.size(self.dimension)),
                lr)
            offset = offset + currentOutput.size(self.dimension)

    def __repr__(self):
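        # Renders an ASCII diagram of the container: the shared input fans
        # out to every sub-module, whose outputs join into a single output.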
        tab = '  '
        line = '\n'
        next = '  |`-> '
        ext = '  |    '
        extlast = '       '
        last = '   ... -> '
        res = torch.typename(self)
        res += ' {' + line + tab + 'input'
        for i in range(len(self.modules)):
            if i == len(self.modules) - 1:
                res += line + tab + next + '(' + str(i) + '): ' + str(self.modules[i]).replace(line, line + tab + extlast)
            else:
                res += line + tab + next + '(' + str(i) + '): ' + str(self.modules[i]).replace(line, line + tab + ext)
        res += line + tab + last + 'output'
        res += line + '}'
        return res
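

# A minimal usage sketch (assuming this module lives alongside the Linear
# module of the same legacy nn package; the import path is an assumption):
#
#     from torch.legacy import nn
#     net = nn.Concat(1)          # concatenate along the feature dimension
#     net.add(nn.Linear(10, 5))
#     net.add(nn.Linear(10, 7))
#     out = net.forward(torch.randn(8, 10))  # out has shape (8, 12)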