import math
import torch
from .Module import Module
from .utils import clear
class Bilinear(Module):
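    """Applies a bilinear transformation to a pair of 2D inputs.

    Given a batch of inputs x1 (N x inputSize1) and x2 (N x inputSize2), the
    module computes, for each output unit k:

        output[n][k] = x1[n] . weight[k] . x2[n] + bias[k]

    i.e. one inputSize1 x inputSize2 weight matrix per output unit. This is
    the Python port of the Lua Torch nn.Bilinear module; inputs are passed
    as a two-element list [x1, x2].
    """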
def _assertInput(self, input):
if len(input) != 2 or not isinstance(input[0], torch.Tensor) or not isinstance(input[1], torch.Tensor):
raise RuntimeError('input should be a table containing two data Tensors')
if input[0].ndimension() != 2 or input[1].ndimension() != 2:
raise RuntimeError('input Tensors should be two-dimensional')
if input[0].size(0) != input[1].size(0):
raise RuntimeError('input Tensors should have the same number of rows')
        if input[0].size(1) != self.weight.size(1):
            raise RuntimeError('dimensionality of first input does not match weight.size(1)')
        if input[1].size(1) != self.weight.size(2):
            raise RuntimeError('dimensionality of second input does not match weight.size(2)')
def _assertInputGradOutput(self, input, gradOutput):
if input[0].size(0) != gradOutput.size(0):
            raise RuntimeError('number of rows in gradOutput does not match input')
if gradOutput.size(1) != self.weight.size(0):
raise RuntimeError('number of columns in gradOutput does not match layer\'s output size')
def __init__(self, inputSize1, inputSize2, outputSize, bias=True):
# set up model:
super(Bilinear, self).__init__()
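        # one inputSize1 x inputSize2 weight matrix per output unit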
self.weight = torch.Tensor(outputSize, inputSize1, inputSize2)
self.gradWeight = torch.Tensor(outputSize, inputSize1, inputSize2)
if bias:
self.bias = torch.Tensor(outputSize)
self.gradBias = torch.Tensor(outputSize)
else:
self.bias = None
self.gradBias = None
self.buff1 = None
self.buff2 = None
self.gradInput = [torch.Tensor(), torch.Tensor()]
self.reset()
def reset(self, stdv=None):
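        # default: Torch's fan-in scaling, stdv = 1 / sqrt(inputSize1); an
        # explicit stdv is scaled by sqrt(3) so the uniform distribution on
        # [-stdv*sqrt(3), stdv*sqrt(3)] has standard deviation stdv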
if stdv is not None:
stdv = stdv * math.sqrt(3)
else:
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.uniform_(-stdv, stdv)
return self
def updateOutput(self, input):
self._assertInput(input)
# set up buffer:
if self.buff2 is None:
self.buff2 = input[0].new()
self.buff2.resize_as_(input[1])
# compute output scores:
self.output.resize_(input[0].size(0), self.weight.size(0))
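        # for each slice k: output[:, k] = sum((x1 @ weight[k]) * x2, dim=1)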
for k in range(self.weight.size(0)):
torch.mm(input[0], self.weight[k], out=self.buff2)
self.buff2.mul_(input[1])
torch.sum(self.buff2, 1, True, out=self.output.narrow(1, k, 1))
if self.bias is not None:
self.output.add_(self.bias.view(1, self.bias.nelement()).expand_as(self.output))
return self.output
def updateGradInput(self, input, gradOutput):
if self.gradInput is None:
return
self._assertInputGradOutput(input, gradOutput)
# compute d output / d input:
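        # for y[:, k] = sum((x1 @ W[k]) * x2, 1), the chain rule gives
        #   dL/dx1 += (x2 @ W[k].t()) * gradOutput[:, k:k+1]
        #   dL/dx2 += (x1 @ W[k])     * gradOutput[:, k:k+1]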
self.gradInput[0].resize_as_(input[0]).fill_(0)
self.gradInput[1].resize_as_(input[1]).fill_(0)
        # first slice of weight tensor (k = 0)
self.gradInput[0].addmm_(input[1], self.weight[0].t())
self.gradInput[0].mul_(gradOutput.narrow(1, 0, 1).expand(self.gradInput[0].size(0),
self.gradInput[0].size(1)))
self.gradInput[1].addmm_(input[0], self.weight[0])
self.gradInput[1].mul_(gradOutput.narrow(1, 0, 1).expand(self.gradInput[1].size(0),
self.gradInput[1].size(1)))
        # remaining slices of weight tensor
if self.weight.size(0) > 1:
if self.buff1 is None:
self.buff1 = input[0].new()
self.buff1.resize_as_(input[0])
for k in range(1, self.weight.size(0)):
torch.mm(input[1], self.weight[k].t(), out=self.buff1)
self.buff1.mul_(gradOutput.narrow(1, k, 1).expand(self.gradInput[0].size(0),
self.gradInput[0].size(1)))
self.gradInput[0].add_(self.buff1)
torch.mm(input[0], self.weight[k], out=self.buff2)
self.buff2.mul_(gradOutput.narrow(1, k, 1).expand(self.gradInput[1].size(0),
self.gradInput[1].size(1)))
self.gradInput[1].add_(self.buff2)
return self.gradInput
def accGradParameters(self, input, gradOutput, scale=1):
self._assertInputGradOutput(input, gradOutput)
# make sure we have buffer:
if self.buff1 is None:
self.buff1 = input[0].new()
self.buff1.resize_as_(input[0])
# accumulate parameter gradients:
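        # dL/dW[k] = (x1 * gradOutput[:, k:k+1]).t() @ x2, and
        # dL/db    = gradOutput summed over the batch dimension;
        # note that `scale` is only applied to the bias gradient here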
for k in range(self.weight.size(0)):
torch.mul(input[0], gradOutput.narrow(1, k, 1).expand_as(input[0]), out=self.buff1)
self.gradWeight[k].addmm_(self.buff1.t(), input[1])
if self.bias is not None:
            self.gradBias.add_(gradOutput.sum(0, keepdim=False), alpha=scale)
def __repr__(self):
return str(type(self)) + \
'({}x{} -> {}) {}'.format(
self.weight.size(1), self.weight.size(2), self.weight.size(0),
(' without bias' if self.bias is None else '')
)
def clearState(self):
clear(self, 'buff1', 'buff2')
return super(Bilinear, self).clearState()
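# A minimal usage sketch (assumes this module is imported from its package,
# e.g. the legacy nn namespace, since the relative imports above prevent
# running this file directly):
#
#   m = Bilinear(5, 4, 3)
#   x1, x2 = torch.randn(8, 5), torch.randn(8, 4)
#   out = m.updateOutput([x1, x2])                        # shape: (8, 3)
#   gradIn = m.updateGradInput([x1, x2], torch.ones(8, 3))
#   m.accGradParameters([x1, x2], torch.ones(8, 3))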