import torch
from .Module import Module
from .Identity import Identity
from .LookupTable import LookupTable
from .Sequential import Sequential
from .ParallelTable import ParallelTable
from .MM import MM


class PartialLinear(Module):
    """
    PartialLinear is a Linear layer that allows the user to set a collection of
    column indices. When the column indices are set, the layer will behave like
    a Linear layer that only has those columns. Meanwhile, all parameters are
    preserved, so resetting the PartialLinear layer will result in a module that
    behaves just like a regular Linear layer.

    This module is useful, for instance, when you want to do forward-backward
    on only a subset of a Linear layer during training but use the full Linear
    layer at test time.
    """

    def __init__(self, inputsize, outputsize, bias=True):
        super(PartialLinear, self).__init__()

        # define the layer as a small network:
        pt = ParallelTable()
        pt.add(Identity()).add(LookupTable(outputsize, inputsize))
        self.network = Sequential().add(pt).add(MM(False, True))
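        # (The LookupTable holds the full outputsize x inputsize weight matrix and
        # selects the rows indexed by self.partition; MM(False, True) then computes
        # input @ selected_weights^T, i.e. a Linear layer restricted to those columns.)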
        if bias:
            self.bias = torch.zeros(1, outputsize)
            self.gradBias = torch.zeros(1, outputsize)
        else:
            self.bias = self.gradBias = None

        # set partition:
        self.inputsize = inputsize
        self.outputsize = outputsize
        self.allcolumns = torch.arange(0, self.outputsize).long()
        self.resetPartition()
        self.addBuffer = None
        self.buffer = None

    def setPartition(self, indices):
        self.partition = indices.type(self.allcolumns.type())
        return self

    def resetPartition(self):
        self.partition = self.allcolumns
        return self

    def parameters(self):
        # should return only the relevant partition?
        return [self.network.get(0).get(1).weight, self.bias], \
            [self.network.get(0).get(1).gradWeight, self.gradBias]

    def updateOutput(self, input):
        self.output.set_(self.network.forward([input, self.partition]))
        if self.bias is not None:
            self.output.add_(torch.index_select(self.bias, 1, self.partition).expand_as(self.output))
            # addBuffer is a vector of ones (one per example) used later to
            # accumulate the bias gradient in accGradParameters
            if self.addBuffer is None:
                self.addBuffer = input.new()
            if self.addBuffer.nelement() != input.size(0):
                self.addBuffer.resize_(input.size(0)).fill_(1)
        return self.output

    def updateGradInput(self, input, gradOutput):
        if self.gradInput is not None:
            self.network.updateGradInput([input, self.partition], gradOutput)
            self.gradInput.set_(self.network.gradInput[0])
        return self.gradInput

    def accGradParameters(self, input, gradOutput, scale=1):
        self.network.accGradParameters([input, self.partition], gradOutput, scale)
        if self.bias is not None:
            if self.buffer is None:
                self.buffer = input.new()
            self.buffer.resize_(gradOutput.size(1))
            # gradBias[partition] += scale * (column sums of gradOutput)
            torch.mv(gradOutput.t(), self.addBuffer, out=self.buffer).mul_(scale)
            self.gradBias.index_add_(
                1, self.partition, self.buffer.view(1, self.buffer.nelement())
            )

    def accUpdateGradParameters(self, input, gradOutput, lr):
        # temporarily point gradWeight/gradBias at the parameters themselves so
        # that accGradParameters(..., -lr) applies the update in place, then restore
        gradWeight = self.network.get(0).get(1).gradWeight
        gradBias = self.gradBias
        self.network.get(0).get(1).gradWeight = self.network.get(0).get(1).weight
        self.gradBias = self.bias
        self.accGradParameters(input, gradOutput, -lr)
        self.network.get(0).get(1).gradWeight = gradWeight
        self.gradBias = gradBias

    def zeroGradParameters(self):
        self.network.zeroGradParameters()
        if self.gradBias is not None:
            self.gradBias.zero_()

    def updateParameters(self, learningRate):
        self.network.updateParameters(learningRate)
        if self.bias is not None:
            self.bias.add_(-learningRate, self.gradBias)

    def type(self, type=None, tensorCache=None):
        result = super(PartialLinear, self).type(type, tensorCache)
        self.partition = self.partition.long()
        self.allcolumns = self.allcolumns.long()
        if type == 'torch.cuda.FloatTensor':
            self.allcolumns = self.allcolumns.cuda()
            self.partition = self.partition.cuda()
        return result

    def __repr__(self):
        return super(PartialLinear, self).__repr__() + \
            '({} -> {})'.format(self.inputsize, self.outputsize) + \
            (' without bias' if self.bias is None else '')
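
# End-to-end sketch of the pattern described in the class docstring (train on a
# column subset, test on the full layer). The sizes, partition, and gradOut are
# illustrative assumptions, not part of this module:
#
#   module = PartialLinear(100, 10000)
#   # training step: only touch the sampled output columns
#   module.setPartition(torch.LongTensor([3, 17, 256]))
#   out = module.forward(input)                 # input: batch x 100, out: batch x 3
#   module.zeroGradParameters()
#   module.backward(input, gradOut)             # gradOut: batch x 3
#   module.updateParameters(0.01)
#   # test time: restore the full 10000-way output
#   module.resetPartition()
#   scores = module.forward(input)              # scores: batch x 10000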