import torch
from .Module import Module


class FlattenTable(Module):

    def __init__(self):
        super(FlattenTable, self).__init__()

        self.output = []
        self.input_map = []
        self.gradInput = []

    def _flatten(self, output, input):
        if isinstance(input, list):
            input_map = []
            # forward DFS order
            for i in range(len(input)):
                input_map.append(self._flatten(output, input[i]))
        else:
            input_map = len(output)
            output.append(input)

        return input_map
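
    # Illustrative sketch (not part of the original module): with tensors
    # a, b and c, flattening input = [a, [b, c]] into an empty output list
    # yields output == [a, b, c] and returns the index map [0, [1, 2]]:
    #
    #   m = FlattenTable()
    #   out = []
    #   assert m._flatten(out, [a, [b, c]]) == [0, [1, 2]]
    #   assert out == [a, b, c]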

    def _checkMapping(self, output, input, input_map):
        if isinstance(input, list):
            if len(input) != len(input_map):
                return False

            # forward DFS order
            for i in range(len(input)):
                if not self._checkMapping(output, input[i], input_map[i]):
                    return False

            return True
        else:
            # Compare by identity: the flat output must hold a reference to
            # the very same tensor object.  Using ``==`` here would return an
            # element-wise comparison tensor, which cannot be used as a bool.
            return output[input_map] is input
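
    # Illustrative sketch (not part of the original module): reusing a, b, c
    # from above, _checkMapping([a, b, c], [a, [b, c]], [0, [1, 2]]) is True,
    # while swapping any leaf for a different tensor object makes it False.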

    # During BPROP we have to build a gradInput with the same shape as the
    # input.  This is a recursive function to build up a gradInput.
    def _inverseFlatten(self, gradOutput, input_map):
        if isinstance(input_map, list):
            gradInput = []
            for i in range(len(input_map)):
                gradInput.append(self._inverseFlatten(gradOutput, input_map[i]))

            return gradInput
        else:
            return gradOutput[input_map]
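
    # Illustrative sketch (not part of the original module): with gradients
    # ga, gb, gc matching the flat output above,
    # _inverseFlatten([ga, gb, gc], [0, [1, 2]]) returns [ga, [gb, gc]],
    # i.e. a table with the nested shape of the original input.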

    def updateOutput(self, input):
        assert isinstance(input, list)
        # To avoid rebuilding the flattened table on every updateOutput call,
        # we do a DFS pass over the existing output table and the inputs to
        # see if it needs to be rebuilt.
        if not self._checkMapping(self.output, input, self.input_map):
            self.output = []
            self.input_map = self._flatten(self.output, input)

        return self.output

    def updateGradInput(self, input, gradOutput):
        assert isinstance(input, list)
        assert isinstance(gradOutput, list)
        # If the input changes between the updateOutput and updateGradInput
        # calls, we may have to rebuild the input_map!  However, let's assume
        # that the input_map is valid and that forward has already been
        # called.  We should still check that the gradInput is valid:
        if not self._checkMapping(gradOutput, self.gradInput, self.input_map):
            self.gradInput = self._inverseFlatten(gradOutput, self.input_map)

        return self.gradInput

    def type(self, type=None, tensorCache=None):
        if not type:
            return self._type
        # This function just stores references, so we don't need to do any
        # type conversions.  Just force the tables to be empty.
        self.clearState()

    def clearState(self):
        self.input_map = []
        return super(FlattenTable, self).clearState()
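

# Minimal usage sketch (an illustration, not part of the original file; it
# assumes this module is available as torch.legacy.nn.FlattenTable):
#
#   import torch
#   from torch.legacy import nn
#
#   m = nn.FlattenTable()
#   x = [torch.randn(1), [torch.randn(2), [torch.randn(3)]]]
#   out = m.forward(x)               # [x[0], x[1][0], x[1][1][0]]
#   grad = m.backward(x, [o.clone() for o in out])
#   # grad mirrors the nesting of x: [g0, [g1, [g2]]]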