import torch

from ..function import Function, InplaceFunction


# TODO: no need to save all args if the grad w.r.t. some of them is not needed
class _BlasBase(InplaceFunction):
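    """Shared base for the BLAS-style autograd functions below.

    Note the scalar roles in this file: ``alpha`` scales the added tensor
    and ``beta`` scales the matrix/vector product, which is what the
    backward passes below assume.
    """
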
def __init__(self, alpha=1, beta=1, inplace=False):
super(_BlasBase, self).__init__(inplace)
self.alpha = alpha
self.beta = beta

    def _get_output(self, arg):
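        # In-place mode reuses the input buffer (and must mark it dirty for
        # the autograd engine); otherwise a fresh tensor of the same shape
        # is allocated.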
if self.inplace:
self.mark_dirty(arg)
return arg
else:
return arg.new().resize_as_(arg)


class Addmm(_BlasBase):
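    """Computes alpha * add_matrix + beta * matrix1 @ matrix2."""
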
def forward(self, add_matrix, matrix1, matrix2):
self.save_for_backward(matrix1, matrix2)
output = self._get_output(add_matrix)
return torch.addmm(self.alpha, add_matrix, self.beta,
matrix1, matrix2, out=output)

    def backward(self, grad_output):
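        # For Y = alpha * A + beta * M1 @ M2:
        #   dL/dA  = alpha * dL/dY
        #   dL/dM1 = beta * dL/dY @ M2^T
        #   dL/dM2 = beta * M1^T @ dL/dY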
matrix1, matrix2 = self.saved_tensors
grad_add_matrix = grad_matrix1 = grad_matrix2 = None
if self.needs_input_grad[0]:
grad_add_matrix = grad_output
if self.alpha != 1:
grad_add_matrix = grad_add_matrix.mul(self.alpha)
if self.needs_input_grad[1]:
grad_matrix1 = torch.mm(grad_output, matrix2.t())
if self.beta != 1:
grad_matrix1 *= self.beta
if self.needs_input_grad[2]:
grad_matrix2 = torch.mm(matrix1.t(), grad_output)
if self.beta != 1:
grad_matrix2 *= self.beta
return grad_add_matrix, grad_matrix1, grad_matrix2


class Addbmm(_BlasBase):
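    """Computes alpha * add_matrix + beta * sum_i batch1[i] @ batch2[i]."""
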
def forward(self, add_matrix, batch1, batch2):
self.save_for_backward(batch1, batch2)
output = self._get_output(add_matrix)
return torch.addbmm(self.alpha, add_matrix, self.beta,
batch1, batch2, out=output)

    def backward(self, grad_output):
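        # addbmm sums all batch products into one 2D matrix, so the same 2D
        # grad_output is shared by every batch slice (expanded below).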
batch1, batch2 = self.saved_tensors
grad_add_matrix = grad_batch1 = grad_batch2 = None
if self.needs_input_grad[0]:
grad_add_matrix = grad_output
if self.alpha != 1:
grad_add_matrix = grad_add_matrix.mul(self.alpha)
if any(self.needs_input_grad[1:]):
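            # Broadcast grad_output across the batch dimension; expand()
            # creates a view, so no memory is copied.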
batch_grad_output = (grad_output
.unsqueeze(0)
.expand(batch1.size(0), batch1.size(1), batch2.size(2)))
if self.needs_input_grad[1]:
grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
if self.beta != 1:
grad_batch1 *= self.beta
if self.needs_input_grad[2]:
grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
if self.beta != 1:
grad_batch2 *= self.beta
return grad_add_matrix, grad_batch1, grad_batch2


class Baddbmm(_BlasBase):
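    """Computes alpha * add_batch + beta * batch1 @ batch2, per batch slice."""
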
def forward(self, add_batch, batch1, batch2):
self.save_for_backward(batch1, batch2)
output = self._get_output(add_batch)
return torch.baddbmm(self.alpha, add_batch, self.beta,
batch1, batch2, out=output)

    def backward(self, grad_output):
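        # Batched analogue of Addmm: per slice i,
        #   dL/dB1[i] = beta * dL/dY[i] @ B2[i]^T
        #   dL/dB2[i] = beta * B1[i]^T @ dL/dY[i]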
batch1, batch2 = self.saved_tensors
grad_add_batch = grad_batch1 = grad_batch2 = None
if self.needs_input_grad[0]:
grad_add_batch = grad_output
if self.alpha != 1:
grad_add_batch = grad_add_batch.mul(self.alpha)
if self.needs_input_grad[1]:
grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
if self.beta != 1:
grad_batch1 *= self.beta
if self.needs_input_grad[2]:
grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
if self.beta != 1:
grad_batch2 *= self.beta
return grad_add_batch, grad_batch1, grad_batch2


class Addmv(_BlasBase):
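    """Computes alpha * add_vector + beta * matrix @ vector."""
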
def forward(self, add_vector, matrix, vector):
self.save_for_backward(matrix, vector)
output = self._get_output(add_vector)
return torch.addmv(self.alpha, add_vector, self.beta,
matrix, vector, out=output)

    def backward(self, grad_output):
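        # For y = alpha * a + beta * M @ v:
        #   dL/dM = beta * outer(dL/dy, v)   (torch.ger)
        #   dL/dv = beta * M^T @ dL/dy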
matrix, vector = self.saved_tensors
grad_add_vector = grad_matrix = grad_vector = None
if self.needs_input_grad[0]:
grad_add_vector = grad_output
if self.alpha != 1:
grad_add_vector = grad_add_vector.mul(self.alpha)
if self.needs_input_grad[1]:
grad_matrix = torch.ger(grad_output, vector)
if self.beta != 1:
grad_matrix *= self.beta
if self.needs_input_grad[2]:
grad_vector = torch.mv(matrix.t(), grad_output)
if self.beta != 1:
grad_vector *= self.beta
return grad_add_vector, grad_matrix, grad_vector


class Addr(_BlasBase):
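    """Computes alpha * add_matrix + beta * outer(vector1, vector2)."""
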
def forward(self, add_matrix, vector1, vector2):
self.save_for_backward(vector1, vector2)
output = self._get_output(add_matrix)
return torch.addr(self.alpha, add_matrix, self.beta,
vector1, vector2, out=output)

    def backward(self, grad_output):
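        # For Y = alpha * A + beta * outer(v1, v2):
        #   dL/dv1 = beta * (dL/dY) @ v2
        #   dL/dv2 = beta * (dL/dY)^T @ v1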
vector1, vector2 = self.saved_tensors
grad_add_matrix = grad_vector1 = grad_vector2 = None
if self.needs_input_grad[0]:
grad_add_matrix = grad_output
if self.alpha != 1:
grad_add_matrix = grad_add_matrix.mul(self.alpha)
if self.needs_input_grad[1]:
grad_vector1 = torch.mv(grad_output, vector2)
if self.beta != 1:
grad_vector1 *= self.beta
if self.needs_input_grad[2]:
# TODO: maybe it's better to do transpose + mv + transpose
grad_vector2 = torch.mm(vector1.unsqueeze(0), grad_output).squeeze(0)
if self.beta != 1:
grad_vector2 *= self.beta
return grad_add_matrix, grad_vector1, grad_vector2


class Dot(Function):
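    """Computes vector1 . vector2, returned as a 1-element tensor."""
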
def forward(self, vector1, vector2):
self.save_for_backward(vector1, vector2)
self.sizes = (vector1.size(), vector2.size())
return vector1.new((vector1.dot(vector2),))

    def backward(self, grad_output):
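        # d(v1 . v2)/dv1 = v2 (and symmetrically for v2): scale the other
        # vector by the incoming scalar gradient and restore the saved shape.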
vector1, vector2 = self.saved_tensors
grad_vector1 = grad_vector2 = None
if self.needs_input_grad[0]:
grad_vector1 = vector2.mul(grad_output[0]).view(self.sizes[0])
if self.needs_input_grad[1]:
grad_vector2 = vector1.mul(grad_output[0]).view(self.sizes[1])
return grad_vector1, grad_vector2
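

if __name__ == "__main__":
    # Minimal sanity-check sketch (an addition, not part of the original
    # module): exercises Addmm through the legacy Variable API this file was
    # written against and compares the accumulated gradients with the
    # closed-form expressions documented above. Shapes and scalars are
    # arbitrary; on 0.2+ ``.grad`` is a Variable, hence ``.data`` below.
    from torch.autograd import Variable

    alpha, beta = 2.0, 3.0
    a = Variable(torch.randn(3, 5), requires_grad=True)
    m1 = Variable(torch.randn(3, 4), requires_grad=True)
    m2 = Variable(torch.randn(4, 5), requires_grad=True)

    out = Addmm(alpha=alpha, beta=beta)(a, m1, m2)
    grad_out = torch.randn(3, 5)
    out.backward(grad_out)

    # dL/dA = alpha * dL/dY; dL/dM1 = beta * dL/dY @ M2^T;
    # dL/dM2 = beta * M1^T @ dL/dY
    assert (a.grad.data - grad_out * alpha).abs().max() < 1e-6
    assert (m1.grad.data - torch.mm(grad_out, m2.data.t()) * beta).abs().max() < 1e-6
    assert (m2.grad.data - torch.mm(m1.data.t(), grad_out) * beta).abs().max() < 1e-6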