# blob: b9d07e7d3bb0a925218dbc5d65029c0c686d8480 [file] [log] [blame]
import torch
from .Container import Container
class ConcatTable(Container):
    """Apply every child module to the same input and collect their outputs.

    ``updateOutput`` feeds the same ``input`` to each module in
    ``self.modules`` and returns a list with one output per module, in order.
    In the backward direction, each child receives its own entry of
    ``gradOutput``. The gradients the children return with respect to the
    shared input are summed into ``self.gradInput``. When ``input`` is a
    (possibly nested) list, that summation is done element-wise over the
    matching list structure.
    """

    def __init__(self):
        super(ConcatTable, self).__init__()
        self.modules = []
        self.output = []

    def updateOutput(self, input):
        """Run every child module on ``input`` and return their outputs.

        ``self.output`` is updated in place (existing slots reused, new ones
        appended, stale trailing ones removed), so it ends up holding exactly
        one entry per module.
        """
        for i, module in enumerate(self.modules):
            out = module.updateOutput(input)
            # Reuse the slot at index ``i`` if it already exists.
            if i < len(self.output):
                self.output[i] = out
            else:
                self.output.append(out)
        # Drop leftovers from an earlier call made with more modules.
        del self.output[len(self.modules):]
        return self.output

    def _map_list(self, l1, l2, f):
        """Mirror the (possibly nested) list structure of ``l2`` into ``l1``.

        For each non-list leaf ``v`` at index ``i`` of ``l2``,
        ``f(l1, i, v)`` is called to store or accumulate ``v`` into ``l1``.
        Nested lists are recursed into, with sub-lists created in ``l1`` when
        missing or when the existing entry is not a list. Entries of ``l1``
        past ``len(l2)`` are removed.

        Returns:
            ``l1``, which is modified in place (sub-lists may be replaced).
        """
        for i, v in enumerate(l2):
            if isinstance(v, list):
                existing = l1[i] if i < len(l1) else None
                # Only recurse into an existing entry that is itself a list;
                # otherwise start a fresh sub-list.
                target = existing if isinstance(existing, list) else []
                res = self._map_list(target, v, f)
                if i >= len(l1):
                    assert i == len(l1)
                    l1.append(res)
                else:
                    l1[i] = res
            else:
                f(l1, i, v)
        del l1[len(l2):]
        return l1

    @staticmethod
    def _child_backward(module, method, input, gradOutput, scale):
        """Call ``method`` on ``module``; ``updateGradInput`` takes no ``scale``."""
        if method == 'updateGradInput':
            return module.updateGradInput(input, gradOutput)
        return getattr(module, method)(input, gradOutput, scale)

    def _backward(self, method, input, gradOutput, scale=1):
        """Shared implementation of ``updateGradInput`` and ``backward``.

        Args:
            method: ``'updateGradInput'`` or ``'backward'``, the method
                invoked on every child module.
            input: the value given to ``updateOutput``, either a tensor or a
                (possibly nested) list of tensors.
            gradOutput: sequence with one gradient per child module, indexed
                in the same order as ``self.modules``.
            scale: forwarded to the children's ``backward``; not passed to
                ``updateGradInput``, whose signature has no ``scale``.

        Returns:
            ``self.gradInput``, the sum of the children's gradients with
            respect to ``input``.

        Raises:
            RuntimeError: if ``input`` is a list and a child returns a
                non-list gradient, or one whose length differs from
                ``input``'s.
        """
        isTable = isinstance(input, list)
        wasTable = isinstance(self.gradInput, list)
        if isTable:
            def copy_leaf(l, i, v):
                # First child: overwrite (or create) the stored gradient.
                if i >= len(l):
                    assert len(l) == i
                    l.append(v.clone())
                else:
                    l[i].resizeAs_(v)
                    l[i].copy_(v)

            def add_leaf(l, i, v):
                # Later children: accumulate into the stored gradient.
                if i < len(l):
                    l[i].add_(v)
                else:
                    assert len(l) == i
                    l.append(v.clone())

            for i, module in enumerate(self.modules):
                currentGradInput = self._child_backward(
                    module, method, input, gradOutput[i], scale)
                if not isinstance(currentGradInput, list):
                    raise RuntimeError("currentGradInput is not a table!")
                if len(input) != len(currentGradInput):
                    raise RuntimeError("table size mismatch")
                if i == 0:
                    self.gradInput = self.gradInput if wasTable else []
                    self._map_list(self.gradInput, currentGradInput, copy_leaf)
                else:
                    self._map_list(self.gradInput, currentGradInput, add_leaf)
        else:
            # A previous list-shaped gradInput cannot be reused for a tensor
            # input; replace it with a tensor buffer.
            self.gradInput = self.gradInput if not wasTable else input.clone()
            for i, module in enumerate(self.modules):
                currentGradInput = self._child_backward(
                    module, method, input, gradOutput[i], scale)
                if i == 0:
                    self.gradInput.resizeAs_(currentGradInput).copy_(currentGradInput)
                else:
                    self.gradInput.add_(currentGradInput)
        return self.gradInput

    def updateGradInput(self, input, gradOutput):
        """Compute the summed input gradient of all children."""
        return self._backward('updateGradInput', input, gradOutput)

    def backward(self, input, gradOutput, scale=1):
        """Call ``backward`` on every child and sum their input gradients."""
        return self._backward('backward', input, gradOutput, scale)

    def accGradParameters(self, input, gradOutput, scale=1):
        """Call each child's ``accGradParameters`` with its own ``gradOutput``."""
        for i, module in enumerate(self.modules):
            module.accGradParameters(input, gradOutput[i], scale)

    def accUpdateGradParameters(self, input, gradOutput, lr):
        """Call each child's ``accUpdateGradParameters`` with its own ``gradOutput``."""
        for i, module in enumerate(self.modules):
            module.accUpdateGradParameters(input, gradOutput[i], lr)

    def __repr__(self):
        """Render the container and its children as an indented tree."""
        tab = ' '
        line = '\n'
        branch = ' |`-> '
        ext = ' | '
        extlast = ' '
        last = ' +. -> '
        parts = [torch.typename(self) + ' {' + line + tab + 'input']
        final = len(self.modules) - 1
        for i, module in enumerate(self.modules):
            # Indent a child's continuation lines under its branch marker.
            indent = extlast if i == final else ext
            parts.append(line + tab + branch + '(' + str(i) + '): '
                         + str(module).replace(line, line + tab + indent))
        parts.append(line + tab + last + 'output')
        parts.append(line + '}')
        return ''.join(parts)