Use torch.matmul in nn.Linear (#1935)

This takes advantage of the broadcasting behavior of torch.matmul so
that nn.Linear supports inputs with more than two dimensions. The extra
leading dimensions are treated as part of the batch, much like
nn.Bottle in Lua Torch.
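
A quick sketch (not part of the patch) mirroring the new test below:
the leading dimensions of a higher-dimensional input are folded into
the batch.

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    m = nn.Linear(5, 8)
    inp = Variable(torch.randn(2, 3, 5))

    # The leading (2, 3) dimensions act as batch dimensions, so the
    # result has shape (2, 3, 8) and matches the flattened path.
    out = m(inp)
    flat = m(inp.view(6, 5)).view(2, 3, 8)
    print(out.size())                     # (2, 3, 8)
    print((out - flat).data.abs().max())  # ~0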

This also includes a few small, related performance changes:

 * Addmm computes each matrix gradient in column-major layout when the
   corresponding input is column-major (see the sketch after this list)
 * Variable.mm calls Addmm in place with the desired output buffer,
   avoiding the _static_blas indirection
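
The column-major path relies on the transpose identity
(A @ B)^T = B^T @ A^T: torch.mm(matrix2, grad_output.t()).t() holds the
same values as torch.mm(grad_output, matrix2.t()), but is stored in
column-major order, matching the layout of a column-major matrix1. A
minimal sketch with illustrative (made-up) shapes:

    import torch

    # matrix1 is column-major (a transposed view of a contiguous tensor).
    matrix1 = torch.randn(4, 3).t()        # shape (3, 4), stride (1, 3)
    matrix2 = torch.randn(4, 5)
    grad_output = torch.randn(3, 5)        # gradient of matrix1 @ matrix2

    row_major = torch.mm(grad_output, matrix2.t())
    col_major = torch.mm(matrix2, grad_output.t()).t()

    print((row_major - col_major).abs().max())   # ~0: same values
    print(matrix1.stride(), col_major.stride())  # both (1, 3): column-major
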
diff --git a/test/test_nn.py b/test/test_nn.py
index 6b01286..19532ed 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2316,6 +2316,12 @@
         input = Variable(torch.randn(1, 1, 2, 2, 2), requires_grad=True)
         self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='trilinear'), (input,)))
 
+    def test_linear_broadcasting(self):
+        m = nn.Linear(5, 8)
+        inp = Variable(torch.randn(2, 3, 5))
+        expected = m(inp.view(6, 5)).view(2, 3, 8)
+        self.assertEqual(expected, m(inp))
+
     def test_bilinear(self):
         module = nn.Bilinear(10, 10, 8)
         module_legacy = legacy.Bilinear(10, 10, 8)
diff --git a/torch/autograd/_functions/blas.py b/torch/autograd/_functions/blas.py
index 20ccec1..16626f7 100644
--- a/torch/autograd/_functions/blas.py
+++ b/torch/autograd/_functions/blas.py
@@ -36,12 +36,20 @@
                 grad_add_matrix = grad_add_matrix.mul(ctx.alpha)
 
         if ctx.needs_input_grad[1]:
-            grad_matrix1 = torch.mm(grad_output, matrix2.t())
+            if matrix1.stride() == (1, matrix1.size(0)):
+                # column major gradient if input is column major
+                grad_matrix1 = torch.mm(matrix2, grad_output.t()).t()
+            else:
+                grad_matrix1 = torch.mm(grad_output, matrix2.t())
             if ctx.beta != 1:
                 grad_matrix1 *= ctx.beta
 
         if ctx.needs_input_grad[2]:
-            grad_matrix2 = torch.mm(matrix1.t(), grad_output)
+            if matrix2.stride() == (1, matrix2.size(0)):
+                # column major gradient if input is column major
+                grad_matrix2 = torch.mm(grad_output.t(), matrix1).t()
+            else:
+                grad_matrix2 = torch.mm(matrix1.t(), grad_output)
             if ctx.beta != 1:
                 grad_matrix2 *= ctx.beta
 
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index d4a452a..15f44bb 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -544,7 +544,7 @@
 
     def mm(self, matrix):
         output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
-        return self._static_blas(Addmm, (output, 0, 1, self, matrix), False)
+        return Addmm.apply(output, self, matrix, 0, 1, True)
 
     def bmm(self, batch):
         output = Variable(self.data.new(self.data.size(0), self.data.size(1),
diff --git a/torch/nn/_functions/linear.py b/torch/nn/_functions/linear.py
index 2877845..8e3fac2 100644
--- a/torch/nn/_functions/linear.py
+++ b/torch/nn/_functions/linear.py
@@ -1,35 +1,5 @@
 import torch
 from torch.autograd import Function
-from torch.autograd import Variable
-
-
-class Linear(Function):
-
-    @staticmethod
-    def forward(ctx, input, weight, bias=None):
-        ctx.save_for_backward(input, weight, bias)
-        output = input.new(input.size(0), weight.size(0))
-        output.addmm_(0, 1, input, weight.t())
-        if bias is not None:
-            output.add_(bias.expand_as(output))
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight, bias = ctx.saved_variables
-
-        grad_input = grad_weight = grad_bias = None
-        if ctx.needs_input_grad[0]:
-            grad_input = torch.mm(grad_output, weight)
-        if ctx.needs_input_grad[1]:
-            grad_weight = torch.mm(grad_output.t(), input)
-        if bias is not None and ctx.needs_input_grad[2]:
-            grad_bias = grad_output.sum(0, False)
-
-        if bias is not None:
-            return grad_input, grad_weight, grad_bias
-        else:
-            return grad_input, grad_weight
 
 
 class Bilinear(Function):
diff --git a/torch/nn/backends/thnn.py b/torch/nn/backends/thnn.py
index 57b55fa..5b0316f 100644
--- a/torch/nn/backends/thnn.py
+++ b/torch/nn/backends/thnn.py
@@ -20,7 +20,6 @@
 
 def _initialize_backend():
     from .._functions.thnn import _all_functions as _thnn_functions
-    from .._functions.linear import Linear
     from .._functions.conv import ConvNd
     from .._functions.rnn import RNN, \
         RNNTanhCell, RNNReLUCell, GRUCell, LSTMCell
@@ -29,7 +28,6 @@
     from .._functions.loss import CosineEmbeddingLoss, \
         HingeEmbeddingLoss, MarginRankingLoss
 
-    backend.register_function('Linear', Linear)
     backend.register_function('ConvNd', ConvNd)
     backend.register_function('RNN', RNN)
     backend.register_function('RNNTanhCell', RNNTanhCell)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 190a59c..e0e7953 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -7,6 +7,7 @@
 import torch
 from . import _functions
 from .modules import utils
+from ._functions.linear import Bilinear
 from ._functions.padding import ConstantPad2d
 from ..autograd import _functions as _autograd_functions
 from torch.autograd import Variable
@@ -532,18 +533,21 @@
 # etc.
 
 def linear(input, weight, bias=None):
-    if bias is None:
-        return _functions.linear.Linear.apply(input, weight)
-    else:
-        return _functions.linear.Linear.apply(input, weight, bias)
+    if input.dim() == 2 and bias is not None:
+        # fused op is marginally faster
+        return torch.addmm(bias, input, weight.t())
+
+    output = input.matmul(weight.t())
+    if bias is not None:
+        output += bias
+    return output
 
 
 def bilinear(input1, input2, weight, bias=None):
-    state = _functions.linear.Bilinear()
     if bias is None:
-        return state(input1, input2, weight)
+        return Bilinear()(input1, input2, weight)
     else:
-        return state(input1, input2, weight, bias)
+        return Bilinear()(input1, input2, weight, bias)
 
 
 def batch_norm(input, running_mean, running_var, weight=None, bias=None,
diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py
index 6cbc6a5..58abdfb 100644
--- a/torch/nn/modules/linear.py
+++ b/torch/nn/modules/linear.py
@@ -48,10 +48,7 @@
             self.bias.data.uniform_(-stdv, stdv)
 
     def forward(self, input):
-        if self.bias is None:
-            return self._backend.Linear.apply(input, self.weight)
-        else:
-            return self._backend.Linear.apply(input, self.weight, self.bias)
+        return F.linear(input, self.weight, self.bias)
 
     def __repr__(self):
         return self.__class__.__name__ + ' (' \