blob: 48d8a03f55bb6eb614aa2b90235ca00975997726 [file] [log] [blame]
import torch
from .Module import Module
from .utils import clear, recursiveResizeAs, recursiveFill
class NarrowTable(Module):
    """Select a contiguous run of entries from a table (sequence) input.

    The forward pass returns the ``length`` entries of ``input`` starting at
    index ``offset``. The backward pass routes the incoming gradients to the
    selected positions and produces zero-filled gradients for every other
    entry of ``input``.
    """

    def __init__(self, offset, length=1):
        """Store the selection window.

        Args:
            offset: index of the first selected entry of the input table.
            length: number of consecutive entries to select (default 1).
        """
        super(NarrowTable, self).__init__()
        self.offset = offset
        self.length = length
        # Both buffers are lists that are updated in place on every call.
        self.output = []
        self.gradInput = []

    def updateOutput(self, input):
        """Return ``input[offset], ..., input[offset + length - 1]``.

        ``self.output`` is refilled in place (slice assignment), so the list
        object stays the same across calls; the selected entries themselves
        are referenced, not copied.
        """
        self.output[:] = [input[self.offset + i] for i in range(self.length)]
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Build the gradient table with respect to ``input``.

        Args:
            input: the table that was passed to ``updateOutput``.
            gradOutput: table of gradients for the selected entries; it must
                contain exactly ``self.length`` elements.

        Returns:
            ``self.gradInput``, a list with one entry per element of
            ``input``: entries inside the selected window reference the
            matching element of ``gradOutput``, all others are zero-filled
            buffers resized like the corresponding input entry.

        Raises:
            AssertionError: if ``len(gradOutput) != self.length``.
        """
        # Re-create the slots whenever the input table changes size.
        if len(self.gradInput) != len(input):
            self.gradInput[:] = [None] * len(input)

        assert len(gradOutput) == self.length

        # Selected positions simply forward the incoming gradients.
        for i in range(self.length):
            self.gradInput[self.offset + i] = gradOutput[i]

        window_end = self.offset + self.length
        for i in range(len(input)):
            if self.offset <= i < window_end:
                continue
            # Entries outside the window did not contribute to the output,
            # so their gradient is a zero buffer shaped like the input entry.
            gi = self.gradInput[i]
            if gi is None:
                # NOTE(review): .new() presumably requires input[i] to be a
                # tensor rather than a nested table -- confirm with callers.
                gi = input[i].new()
            # recursiveResizeAs returns a pair; its first element is kept as
            # the resized gradient buffer.
            self.gradInput[i] = recursiveResizeAs(gi, input[i])[0]
            recursiveFill(self.gradInput[i], 0)
        return self.gradInput

    def type(self, type=None, tensorCache=None):
        """Get the current type, or convert the module to ``type``.

        With a falsy ``type`` the stored ``self._type`` is returned.
        Otherwise the ``output`` and ``gradInput`` buffers are cleared and the
        conversion is delegated to the parent class.
        """
        if not type:
            return self._type
        clear(self, 'output', 'gradInput')
        # super() already binds ``self``; passing it again would shift every
        # argument by one (``type`` would receive the module itself) and
        # supply one positional argument too many.
        return super(NarrowTable, self).type(type, tensorCache)