import torch

from ..function import Function, InplaceFunction
from .utils import maybe_unexpand


# TODO: no need to save all args if the grad w.r.t. some of them is not needed
def _get_output(ctx, arg, inplace=False):
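    """Return the tensor the BLAS op should write into.

    In-place mode reuses ``arg`` (marking it dirty for autograd); otherwise a
    fresh, uninitialized tensor with the same shape as ``arg`` is allocated.
    """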
    if inplace:
        ctx.mark_dirty(arg)
        return arg
    else:
        return arg.new().resize_as_(arg)


class Addmm(InplaceFunction):
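    """Computes ``alpha * add_matrix + beta * matrix1 @ matrix2``.

    Note that in this module ``alpha`` scales the additive input and ``beta``
    scales the matrix product, following the positional ``torch.addmm`` call
    in ``forward`` and the matching scaling in ``backward``.
    """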

    @staticmethod
    def symbolic(g, add_matrix, matrix1, matrix2, alpha=1, beta=1, inplace=False):
        # In this module alpha scales add_matrix and beta scales the
        # matrix1 @ matrix2 product (see forward/backward), so the exported
        # Scale ops follow the same convention.
        if beta != 1:
            matrix1 = g.op("Scale", matrix1, scale_f=beta)
        if alpha != 1:
            add_matrix = g.op("Scale", add_matrix, scale_f=alpha)
        # TODO: Talk to ONNX about why their FC involves a transpose
        matrix2_t = g.op("Transpose", matrix2)
        return g.op("FC", matrix1, matrix2_t, add_matrix)

    @staticmethod
    def forward(ctx, add_matrix, matrix1, matrix2, alpha=1, beta=1, inplace=False):
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.add_matrix_size = add_matrix.size()
        ctx.save_for_backward(matrix1, matrix2)
        output = _get_output(ctx, add_matrix, inplace=inplace)
        return torch.addmm(alpha, add_matrix, beta,
                           matrix1, matrix2, out=output)

    @staticmethod
    def backward(ctx, grad_output):
        matrix1, matrix2 = ctx.saved_variables
        grad_add_matrix = grad_matrix1 = grad_matrix2 = None

        if ctx.needs_input_grad[0]:
            grad_add_matrix = maybe_unexpand(grad_output, ctx.add_matrix_size)
            if ctx.alpha != 1:
                grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

        if ctx.needs_input_grad[1]:
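            # d(output)/d(matrix1) is grad_output @ matrix2^T (scaled by beta)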
            if matrix1.stride() == (1, matrix1.size(0)):
                # column major gradient if input is column major
                grad_matrix1 = torch.mm(matrix2, grad_output.t()).t()
            else:
                grad_matrix1 = torch.mm(grad_output, matrix2.t())
            if ctx.beta != 1:
                grad_matrix1 *= ctx.beta

        if ctx.needs_input_grad[2]:
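            # d(output)/d(matrix2) is matrix1^T @ grad_output (scaled by beta)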
            if matrix2.stride() == (1, matrix2.size(0)):
                # column major gradient if input is column major
                grad_matrix2 = torch.mm(grad_output.t(), matrix1).t()
            else:
                grad_matrix2 = torch.mm(matrix1.t(), grad_output)
            if ctx.beta != 1:
                grad_matrix2 *= ctx.beta

        return grad_add_matrix, grad_matrix1, grad_matrix2, None, None, None


class Addbmm(InplaceFunction):
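    """Computes ``alpha * add_matrix + beta * sum_i(batch1[i] @ batch2[i])``.

    The batched matrix products are accumulated into a single 2-D result;
    the same alpha/beta convention as ``Addmm`` applies.
    """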

    @staticmethod
    def forward(ctx, add_matrix, batch1, batch2, alpha=1, beta=1, inplace=False):
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.add_matrix_size = add_matrix.size()
        ctx.save_for_backward(batch1, batch2)
        output = _get_output(ctx, add_matrix, inplace=inplace)
        return torch.addbmm(alpha, add_matrix, beta,
                            batch1, batch2, out=output)

    @staticmethod
    def backward(ctx, grad_output):
        batch1, batch2 = ctx.saved_variables
        grad_add_matrix = grad_batch1 = grad_batch2 = None

        if ctx.needs_input_grad[0]:
            grad_add_matrix = maybe_unexpand(grad_output, ctx.add_matrix_size)
            if ctx.alpha != 1:
                grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

        if any(ctx.needs_input_grad[1:]):
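            # addbmm sums the per-batch products, so every batch element
            # receives the same (expanded) grad_output.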
            batch_grad_output = (grad_output
                                 .unsqueeze(0)
                                 .expand(batch1.size(0), batch1.size(1), batch2.size(2)))

        if ctx.needs_input_grad[1]:
            grad_batch1 = torch.bmm(batch_grad_output, batch2.transpose(1, 2))
            if ctx.beta != 1:
                grad_batch1 *= ctx.beta

        if ctx.needs_input_grad[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), batch_grad_output)
            if ctx.beta != 1:
                grad_batch2 *= ctx.beta

        return grad_add_matrix, grad_batch1, grad_batch2, None, None, None


class Baddbmm(InplaceFunction):
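    """Computes ``alpha * add_batch + beta * batch1 @ batch2`` batch-wise.

    Unlike ``Addbmm`` there is no reduction over the batch dimension; the
    same alpha/beta convention as ``Addmm`` applies.
    """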

    @staticmethod
    def forward(ctx, add_batch, batch1, batch2, alpha=1, beta=1, inplace=False):
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.add_batch_size = add_batch.size()
        ctx.save_for_backward(batch1, batch2)
        output = _get_output(ctx, add_batch, inplace=inplace)
        return torch.baddbmm(alpha, add_batch, beta,
                             batch1, batch2, out=output)

    @staticmethod
    def backward(ctx, grad_output):
        batch1, batch2 = ctx.saved_variables
        grad_add_batch = grad_batch1 = grad_batch2 = None

        if ctx.needs_input_grad[0]:
            grad_add_batch = maybe_unexpand(grad_output, ctx.add_batch_size)
            if ctx.alpha != 1:
                grad_add_batch = grad_add_batch.mul(ctx.alpha)

        if ctx.needs_input_grad[1]:
            grad_batch1 = torch.bmm(grad_output, batch2.transpose(1, 2))
            if ctx.beta != 1:
                grad_batch1 *= ctx.beta

        if ctx.needs_input_grad[2]:
            grad_batch2 = torch.bmm(batch1.transpose(1, 2), grad_output)
            if ctx.beta != 1:
                grad_batch2 *= ctx.beta

        return grad_add_batch, grad_batch1, grad_batch2, None, None, None


class Addmv(InplaceFunction):
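    """Computes ``alpha * add_vector + beta * matrix @ vector``.

    Same alpha/beta convention as ``Addmm``.
    """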

    @staticmethod
    def forward(ctx, add_vector, matrix, vector, alpha=1, beta=1, inplace=False):
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.add_vector_size = add_vector.size()
        ctx.save_for_backward(matrix, vector)
        output = _get_output(ctx, add_vector, inplace=inplace)
        return torch.addmv(alpha, add_vector, beta,
                           matrix, vector, out=output)

    @staticmethod
    def backward(ctx, grad_output):
        matrix, vector = ctx.saved_variables
        grad_add_vector = grad_matrix = grad_vector = None

        if ctx.needs_input_grad[0]:
            grad_add_vector = maybe_unexpand(grad_output, ctx.add_vector_size)
            if ctx.alpha != 1:
                grad_add_vector = grad_add_vector.mul(ctx.alpha)

        if ctx.needs_input_grad[1]:
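            # d(output)/d(matrix) is the outer product of grad_output and vector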
            grad_matrix = torch.ger(grad_output, vector)
            if ctx.beta != 1:
                grad_matrix *= ctx.beta

        if ctx.needs_input_grad[2]:
            grad_vector = torch.mv(matrix.t(), grad_output)
            if ctx.beta != 1:
                grad_vector *= ctx.beta

        return grad_add_vector, grad_matrix, grad_vector, None, None, None


class Addr(InplaceFunction):
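    """Computes ``alpha * add_matrix + beta * outer(vector1, vector2)``.

    Same alpha/beta convention as ``Addmm``.
    """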

    @staticmethod
    def forward(ctx, add_matrix, vector1, vector2, alpha=1, beta=1, inplace=False):
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.add_matrix_size = add_matrix.size()
        ctx.save_for_backward(vector1, vector2)
        output = _get_output(ctx, add_matrix, inplace=inplace)
        return torch.addr(alpha, add_matrix, beta,
                          vector1, vector2, out=output)

    @staticmethod
    def backward(ctx, grad_output):
        vector1, vector2 = ctx.saved_variables
        grad_add_matrix = grad_vector1 = grad_vector2 = None

        if ctx.needs_input_grad[0]:
            grad_add_matrix = maybe_unexpand(grad_output, ctx.add_matrix_size)
            if ctx.alpha != 1:
                grad_add_matrix = grad_add_matrix.mul(ctx.alpha)

        if ctx.needs_input_grad[1]:
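            # d(output)/d(vector1) is grad_output @ vector2 (scaled by beta)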
            grad_vector1 = torch.mv(grad_output, vector2)
            if ctx.beta != 1:
                grad_vector1 *= ctx.beta

        if ctx.needs_input_grad[2]:
            # TODO: maybe it's better to do transpose + mv + transpose
            grad_vector2 = torch.mm(vector1.unsqueeze(0), grad_output).squeeze(0)
            if ctx.beta != 1:
                grad_vector2 *= ctx.beta

        return grad_add_matrix, grad_vector1, grad_vector2, None, None, None


class Dot(Function):
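    """Dot product of two 1-D tensors, returned as a 1-element tensor."""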

    @staticmethod
    def forward(ctx, vector1, vector2):
        ctx.save_for_backward(vector1, vector2)
        ctx.sizes = (vector1.size(), vector2.size())
        return vector1.new((vector1.dot(vector2),))

    @staticmethod
    def backward(ctx, grad_output):
        vector1, vector2 = ctx.saved_variables
        grad_vector1 = grad_vector2 = None

        if ctx.needs_input_grad[0]:
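            # d(v1 . v2)/d(v1) = v2, scaled by the 1-element grad_output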
            grad_vector1 = vector2.mul(grad_output.expand(ctx.sizes[1])).view(ctx.sizes[0])

        if ctx.needs_input_grad[1]:
            grad_vector2 = vector1.mul(grad_output.expand(ctx.sizes[0])).view(ctx.sizes[1])

        return grad_vector1, grad_vector2