blob: 8a4ff9f6f1b406c10a2f6542c59ca6ac8b9c0892 [file] [log] [blame]
import torch
from torch.autograd import Function
class Linear(Function):
    """Old-style (non-static forward/backward) autograd Function for ``input @ weight.t() + bias``.

    ``input`` is an ``(N, in_features)`` matrix and ``weight`` an
    ``(out_features, in_features)`` matrix (both go through ``torch.mm``).
    The optional ``bias`` is a vector of length ``out_features`` that is
    added to every row of the result.
    """

    def forward(self, input, weight, bias=None):
        """Return ``input.mm(weight.t())`` with ``bias`` (if given) added to each row.

        Saves ``(input, weight, bias)`` for backward when a bias is given and
        ``(input, weight)`` otherwise. ``backward`` tells the two cases apart
        by the number of saved tensors.
        """
        # Compare against None explicitly: the truth value of a multi-element
        # tensor is ambiguous (raises), and a one-element zero bias would be
        # falsy and wrongly treated as "no bias".
        if bias is not None:
            self.save_for_backward(input, weight, bias)
        else:
            self.save_for_backward(input, weight)
        # Fresh product rather than addmm_ into uninitialized storage via the
        # deprecated positional (beta, alpha) overload.
        output = torch.mm(input, weight.t())
        if bias is not None:
            # cuBLAS doesn't support 0 strides in sger, so we can't use expand
            # outer(ones(N), bias) adds bias to every row; backward reuses this
            # buffer to sum grad_output over the rows.
            self.add_buffer = input.new(1).resize_(input.size(0)).fill_(1)
            output.addr_(self.add_buffer, bias)
        return output

    def backward(self, grad_output):
        """Return the gradients for the tensors ``forward`` received.

        ``grad_output`` has the shape of forward's output,
        ``(N, out_features)``. The result is ``(grad_input, grad_weight,
        grad_bias)`` when a bias was saved and ``(grad_input, grad_weight)``
        otherwise. Any entry whose input does not need a gradient is ``None``.
        """
        saved = self.saved_tensors
        if len(saved) == 2:
            input, weight = saved
            bias = None
        else:
            input, weight, bias = saved

        grad_input = grad_weight = grad_bias = None
        if self.needs_input_grad[0]:
            grad_input = torch.mm(grad_output, weight)
        if self.needs_input_grad[1]:
            grad_weight = torch.mm(grad_output.t(), input)
        if bias is None:
            # Forward received only (input, weight): return exactly two
            # gradients, one per input.
            return grad_input, grad_weight
        if self.needs_input_grad[2]:
            # grad_output.t() @ ones(N) == grad_output summed over the rows.
            grad_bias = torch.mv(grad_output.t(), self.add_buffer)
        return grad_input, grad_weight, grad_bias