# blob: 43bae458d580b00148fe54a6bec60f73a09639a9 [file] [log] [blame]
import torch
from torch.autograd import Function
class Linear(Function):
    """Autograd function for ``input.mm(weight.t())`` with an optional bias.

    Written against the instance-method Function API: ``forward`` stores
    its arguments on ``self`` and ``backward`` reads them back.  The number
    of gradients returned by ``backward`` is two when no bias was given and
    three otherwise.
    """

    def forward(self, input, weight, bias=None):
        """Return ``input`` times ``weight`` transposed, plus ``bias`` if given.

        The result is allocated with ``input.new`` and has
        ``input.size(0)`` rows and ``weight.size(0)`` columns.
        """
        self.save_for_backward(input, weight, bias)

        num_rows = input.size(0)
        num_cols = weight.size(0)
        output = input.new(num_rows, num_cols)
        # Positional (0, 1) are the scaling factors of the legacy addmm_
        # overload; the freshly allocated buffer is filled with the product.
        output.addmm_(0, 1, input, weight.t())

        if bias is None:
            return output

        # cuBLAS doesn't support 0 strides in sger, so we can't use expand;
        # build an explicit vector of ones and keep it for backward().
        ones = input.new(1).resize_(num_rows).fill_(1)
        self.add_buffer = ones
        output.addr_(ones, bias)
        return output

    def backward(self, grad_output):
        """Return gradients for the arguments that ``forward`` received.

        A slot is ``None`` when ``self.needs_input_grad`` marks it as not
        required.  Yields ``(grad_input, grad_weight)`` when no bias was
        saved, else ``(grad_input, grad_weight, grad_bias)``.
        """
        saved_input, weight, bias = self.saved_tensors
        wanted = self.needs_input_grad

        grad_input = torch.mm(grad_output, weight) if wanted[0] else None
        grad_weight = (
            torch.mm(grad_output.t(), saved_input) if wanted[1] else None
        )

        # Without a bias there are only two inputs, hence two gradients.
        if bias is None:
            return grad_input, grad_weight

        grad_bias = None
        if wanted[2]:
            # Multiplying by the stored ones vector sums grad_output over rows.
            grad_bias = torch.mv(grad_output.t(), self.add_buffer)
        return grad_input, grad_weight, grad_bias