blob: 313f3332ebe342e348b5f1c7bd5f40d68614dd96 [file] [log] [blame]
import torch
from .Module import Module
class SplitTable(Module):
    """Split a tensor into a list of slices along one dimension.

    Forward: returns one ``input.select(dimension, i)`` per index ``i`` along
    ``dimension``. Backward: copies the per-slice gradients back into a
    single tensor sized like ``input``.
    """

    def __init__(self, dimension):
        """Remember the dimension to split along.

        Args:
            dimension: index of the dimension to split; a negative value
                counts from the last dimension (resolved per input in
                ``_getPositiveDimension``).
        """
        super(SplitTable, self).__init__()
        self.dimension = dimension

    def _getPositiveDimension(self, input):
        """Return ``self.dimension`` as a non-negative index for ``input``.

        A negative dimension is offset by ``input.dim()``, so ``-1`` maps
        to the last dimension.
        """
        dimension = self.dimension
        if dimension < 0:
            dimension = input.dim() + dimension
        return dimension

    def updateOutput(self, input):
        """Split ``input`` along the configured dimension.

        Args:
            input: tensor to split.

        Returns:
            A list with one entry per index along the split dimension, each
            produced by ``input.select(dimension, i)``. The list is also
            stored in ``self.output``.
        """
        dimension = self._getPositiveDimension(input)
        self.output = [input.select(dimension, i)
                       for i in range(input.size(dimension))]
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Reassemble per-slice gradients into a gradient for ``input``.

        Args:
            input: the tensor passed to ``updateOutput``.
            gradOutput: indexable by slice position (mirroring the list
                returned by ``updateOutput``); ``gradOutput[i]`` is copied
                into slice ``i`` of ``self.gradInput``.

        Returns:
            ``self.gradInput``, resized to match ``input``; ``None`` when
            ``self.gradInput`` is ``None``.
        """
        # Explicit None check. Truth-testing a tensor is not a presence test.
        # Depending on the torch version, an empty tensor is falsy (so the old
        # `not self.gradInput` skipped the computation while gradInput was
        # still empty) or truth-testing raises outright.
        if self.gradInput is None:
            return self.gradInput
        dimension = self._getPositiveDimension(input)
        # NOTE(review): newer PyTorch spells this `resize_as_`; kept as-is to
        # match the torch API the rest of this file targets.
        self.gradInput.resizeAs_(input)
        for i in range(input.size(dimension)):
            self.gradInput.select(dimension, i).copy_(gradOutput[i])
        return self.gradInput