# blob: 4085a8bd6a798f0f420c671e6f49231d539c537a
import torch
from .Module import Module
class MM(Module):
    """Matrix (or batched matrix) product of a pair of input tensors.

    The input is a two-element sequence ``[a, b]``; the output is
    ``op(a) @ op(b)``, where ``op`` transposes its argument's matrix
    dimensions when the corresponding ``transA`` / ``transB`` flag is set.

    * 2-D inputs are multiplied with ``torch.mm`` (matrix dims 0 and 1).
    * 3-D inputs are multiplied with ``torch.bmm``; dim 0 is the batch
      dimension and the matrix dims are 1 and 2.

    Args:
        transA: if true, transpose ``a``'s matrix dims before multiplying.
        transB: if true, transpose ``b``'s matrix dims before multiplying.
    """

    def __init__(self, transA=False, transB=False):
        super(MM, self).__init__()
        self.transA = transA
        self.transB = transB
        # One gradient buffer per input. While a buffer is still empty,
        # updateGradInput swaps it for a fresh tensor of the input's type.
        self.gradInput = [torch.Tensor(), torch.Tensor()]

    def updateOutput(self, input):
        """Compute ``op(a) @ op(b)`` into ``self.output`` and return it.

        Args:
            input: two-element sequence ``[a, b]`` of tensors, both 2-D or
                both 3-D.

        Returns:
            ``self.output``, resized to the product's shape.
        """
        assert len(input) == 2
        a, b = input
        assert a.dim() == 2 or a.dim() == 3
        assert a.dim() == b.dim()

        if a.dim() == 2:
            if self.transA:
                a = a.t()
            if self.transB:
                b = b.t()
            self.output.resize_(a.size(0), b.size(1))
            # Result tensor is passed first, as elsewhere in this module.
            torch.mm(self.output, a, b)
        else:
            # Batched case: dim 0 is the batch, so the matrix dims are 1 and 2.
            # (A 3-D tensor has no dim 3; transposing (2, 3) always failed.)
            if self.transA:
                a = a.transpose(1, 2)
            if self.transB:
                b = b.transpose(1, 2)
            self.output.resize_(a.size(0), a.size(1), b.size(2))
            torch.bmm(self.output, a, b)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute the gradients w.r.t. both inputs into ``self.gradInput``.

        With ``G = gradOutput`` and ``^T`` meaning a transpose of the matrix
        dims (per batch element in the 3-D case), the forward product
        ``C = op(A) op(B)`` gives:

            transA  transB   forward      dA          dB
            False   False    A   B        G B^T       A^T G
            False   True     A   B^T      G B         G^T A
            True    False    A^T B        B G^T       A G
            True    True     A^T B^T      B^T G^T     G^T A^T

        Args:
            input: the same ``[a, b]`` pair passed to ``updateOutput``.
            gradOutput: gradient w.r.t. the output; same rank as ``a``/``b``.

        Returns:
            ``self.gradInput``, a two-element list ``[dA, dB]``.
        """
        assert len(input) == 2
        a, b = input
        assert gradOutput.dim() == 2 or gradOutput.dim() == 3
        assert a.dim() == b.dim() == gradOutput.dim()

        # Allocate a buffer of the input's own type only while the current
        # one is missing or empty. The previous ``buf or x.new()`` idiom
        # truth-tested the tensor. That is ambiguous (raises) once the buffer
        # has been filled, so a second backward pass crashed.
        for i, x in enumerate((a, b)):
            buf = self.gradInput[i]
            if buf is None or buf.numel() == 0:
                self.gradInput[i] = x.new()

        self.gradInput[0].resizeAs_(a)
        self.gradInput[1].resizeAs_(b)

        if gradOutput.dim() == 2:
            h_dim, w_dim = 0, 1
            matmul = torch.mm
        else:
            # Dim 0 is the batch dimension; transpose only the matrix dims.
            h_dim, w_dim = 1, 2
            matmul = torch.bmm

        # Per the table above: when the flags agree, each gradient uses the
        # transposed other operand; when they differ, operands are used as-is.
        if self.transA == self.transB:
            a = a.transpose(h_dim, w_dim)
            b = b.transpose(h_dim, w_dim)

        if self.transA:
            matmul(self.gradInput[0], b, gradOutput.transpose(h_dim, w_dim))
        else:
            matmul(self.gradInput[0], gradOutput, b)

        if self.transB:
            matmul(self.gradInput[1], gradOutput.transpose(h_dim, w_dim), a)
        else:
            matmul(self.gradInput[1], a, gradOutput)

        return self.gradInput