import torch
from torch.autograd import Function, Variable


class Bilinear(Function):
    """Bilinear map with a hand-written backward:
    output[b, k] = input1[b] . weight[k] . input2[b] + bias[k]

    Shapes: input1 (batch, in1), input2 (batch, in2),
    weight (out, in1, in2), bias (out,), output (batch, out).
    """

    @staticmethod
    def forward(ctx, input1, input2, weight, bias=None):
        ctx.save_for_backward(input1, input2, weight, bias)
        output = input1.new(input1.size(0), weight.size(0))
        buff = input1.new()

        # compute output scores: output[:, k] is the row-wise sum of
        # (input1 @ weight[k]) * input2
        for k, w in enumerate(weight):
            torch.mm(input1, w, out=buff)
            buff.mul_(input2)
            torch.sum(buff, 1, keepdim=True, out=output.narrow(1, k, 1))

        if bias is not None:
            output.add_(bias.expand_as(output))

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2, weight, bias = ctx.saved_variables
        grad_input1 = grad_input2 = grad_weight = grad_bias = None

        buff = Variable(input1.data.new())

        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            # input gradients, initialized from the k = 0 output feature ...
            grad_input1 = torch.mm(input2, weight[0].t())
            grad_input1 = grad_input1.mul(grad_output.narrow(1, 0, 1).expand(grad_input1.size()))
            grad_input2 = torch.mm(input1, weight[0])
            grad_input2 = grad_input2.mul(grad_output.narrow(1, 0, 1).expand(grad_input2.size()))

            # ... then accumulated over the remaining output features
            for k in range(1, weight.size(0)):
                buff = input2.mm(weight[k].t())
                buff = buff.mul(grad_output.narrow(1, k, 1).expand(grad_input1.size()))
                grad_input1.add_(buff)

                buff = input1.mm(weight[k])
                buff = buff.mul(grad_output.narrow(1, k, 1).expand(grad_input2.size()))
                grad_input2.add_(buff)

        if ctx.needs_input_grad[2]:
            # accumulate parameter gradients:
            # grad_weight[k] = input1^T * diag(grad_output[:, k]) * input2
            grad_weight = Variable(weight.data.new(weight.size()))
            for k in range(weight.size(0)):
                buff = input1.mul(grad_output.narrow(1, k, 1).expand_as(input1))
                grad_weight[k] = torch.mm(buff.t(), input2)

        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0, keepdim=False)

        return grad_input1, grad_input2, grad_weight, grad_bias
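

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hedged example of exercising the Function above. The shapes
# (batch, in1, in2, out) are assumptions chosen to match the convention the
# code implies (weight: out x in1 x in2), and the hand-written backward is
# checked against numerical gradients with torch.autograd.gradcheck, using
# the legacy Variable API this Function targets.
if __name__ == '__main__':
    from torch.autograd import gradcheck

    batch, in1, in2, out = 4, 3, 5, 2
    input1 = Variable(torch.randn(batch, in1).double(), requires_grad=True)
    input2 = Variable(torch.randn(batch, in2).double(), requires_grad=True)
    weight = Variable(torch.randn(out, in1, in2).double(), requires_grad=True)
    bias = Variable(torch.randn(out).double(), requires_grad=True)

    # gradcheck returns True when the analytical and numerical gradients agree.
    print(gradcheck(Bilinear.apply, (input1, input2, weight, bias),
                    eps=1e-6, atol=1e-4))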