Improve memory usage of cuDNN RNN modules (#2179)
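
Weights of cuDNN RNN modules are now kept in a single flattened buffer that
is allocated once by the new flatten_parameters() method, instead of being
re-packed into a fresh buffer on every forward call. The cuDNN workspace is
likewise allocated per call rather than cached on the autograd function. If
parameter storages get replaced (e.g. after unpickling, or after a set_()
on a weight), the module warns that the weights are no longer a single
contiguous chunk and falls back to compacting them on each call; calling
flatten_parameters() again restores the flat layout.

A minimal usage sketch (assumes a CUDA build with cuDNN available; the
sizes are illustrative), mirroring the new pickle test below:

    import pickle
    import torch
    from torch import nn
    from torch.autograd import Variable

    # Weights are flattened into one contiguous buffer at construction time
    rnn = nn.LSTM(10, 20, batch_first=True).cuda()

    # Unpickling restores the parameters as separate tensors, so the flat
    # buffer is lost; re-compact them into a single chunk afterwards.
    rnn2 = pickle.loads(pickle.dumps(rnn))
    rnn2.flatten_parameters()

    input = Variable(torch.randn(5, 4, 10).cuda())
    h0 = Variable(torch.randn(1, 5, 20).cuda())
    c0 = Variable(torch.randn(1, 5, 20).cuda())
    output, (hn, cn) = rnn2(input, (h0, c0))
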

diff --git a/test/test_nn.py b/test/test_nn.py
index 1cac377..df2b04b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1836,6 +1836,42 @@
             (hx + cx).sum().backward()
 
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+    def test_LSTM_cudnn_weight_format(self):
+        rnn = nn.LSTM(10, 20, batch_first=True).cuda()
+        input = Variable(torch.randn(5, 4, 10).cuda(), requires_grad=True)
+        hx = Variable(torch.randn(1, 5, 20).cuda(), requires_grad=True)
+        cx = Variable(torch.randn(1, 5, 20).cuda(), requires_grad=True)
+        all_vars = [input, hx, cx] + list(rnn.parameters())
+
+        output = rnn(input, (hx, cx))
+        output[0].sum().backward()
+        grads = [v.grad.data.clone() for v in all_vars]
+        for v in all_vars:
+            v.grad.data.zero_()
+
+        # Weights will no longer be views into the same chunk of memory
+        weight = all_vars[4]
+        weight_data = weight.data.clone()
+        weight.data.set_(weight_data)
+
+        for i in range(2):
+            with warnings.catch_warnings(record=True) as w:
+                output_noncontig = rnn(input, (hx, cx))
+            if i == 0:
+                self.assertEqual(len(w), 1)
+                self.assertIn('weights are not part of a single contiguous chunk of memory', w[0].message.args[0])
+            output_noncontig[0].sum().backward()
+            grads_noncontig = [v.grad.data.clone() for v in all_vars]
+            for v in all_vars:
+                v.grad.data.zero_()
+            self.assertEqual(output, output_noncontig)
+            self.assertEqual(grads_noncontig, grads)
+
+        # Make sure these still share storage
+        weight_data[:] = 4
+        self.assertEqual(weight_data, all_vars[4].data)
+
+    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_cuda_rnn_fused(self):
         def copy_rnn(rnn1, rnn2):
             for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
@@ -2165,6 +2201,7 @@
 
                     rnn_pickle = pickle.dumps(rnn)
                     rnn2 = pickle.loads(rnn_pickle)
+                    rnn2.flatten_parameters()
                     output3, hy3 = rnn2(input, hx)
 
                     if p == 0 or not train:
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py
index 636e973..51da709 100644
--- a/torch/backends/cudnn/rnn.py
+++ b/torch/backends/cudnn/rnn.py
@@ -178,7 +178,10 @@
 
 
 def _copyParams(params_from, params_to):
+    assert len(params_from) == len(params_to)
     for layer_params_from, layer_params_to in zip(params_from, params_to):
+        # NOTE: these lists have all weights before all biases, so if a layer doesn't use
+        # biases, zip stops once layer_params_from ends, leaving the bias slots in params_to untouched.
         for param_from, param_to in zip(layer_params_from, layer_params_to):
             assert param_from.type() == param_to.type()
             param_to.copy_(param_from, broadcast=False)
@@ -242,17 +245,21 @@
         fn.cy_desc = cudnn.descriptor(cx) if cx is not None else None
 
         # create the weight buffer and copy the weights into it
-        num_weights = get_num_weights(
-            handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
-        fn.weight_buf = x.new(num_weights)
-        fn.w_desc = init_weight_descriptor(fn, fn.weight_buf)
-        w = fn.weight_buf
-        # this zero might not seem necessary, but it is in the case
-        # where biases are disabled; then they won't be copied and must be zero'd.
-        # Alternatively, _copyParams could be written more carefully.
-        w.zero_()
-        params = get_parameters(fn, handle, w)
-        _copyParams(weight, params)
+        if fn.weight_buf is None:
+            num_weights = get_num_weights(
+                handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
+            fn.weight_buf = x.new(num_weights)
+            fn.w_desc = init_weight_descriptor(fn, fn.weight_buf)
+            w = fn.weight_buf
+            # this zeroing might not seem necessary, but it is needed when biases are
+            # disabled: they won't be copied by _copyParams and must therefore be zeroed.
+            # Alternatively, _copyParams could be written more carefully.
+            w.zero_()
+            params = get_parameters(fn, handle, w)
+            _copyParams(weight, params)
+        else:
+            fn.w_desc = init_weight_descriptor(fn, fn.weight_buf)
+            w = fn.weight_buf
 
         if tuple(hx.size()) != hidden_size:
             raise RuntimeError('Expected hidden size {}, got {}'.format(
@@ -269,7 +276,9 @@
             fn.x_descs,
             ctypes.byref(workspace_size)
         ))
-        fn.workspace = torch.cuda.ByteTensor(workspace_size.value)
+        fn.workspace_size = workspace_size.value
+        with torch.cuda.device_of(input):
+            workspace = torch.cuda.ByteTensor(fn.workspace_size)
         if fn.requires_grad:
             reserve_size = ctypes.c_long()
             check_error(lib.cudnnGetRNNTrainingReserveSize(
@@ -292,7 +301,7 @@
                 fn.y_descs, ctypes.c_void_p(y.data_ptr()),
                 fn.hy_desc, ctypes.c_void_p(hy.data_ptr()),
                 fn.cy_desc, ctypes.c_void_p(cy.data_ptr()) if cx is not None else None,
-                ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0),
+                ctypes.c_void_p(workspace.data_ptr()), workspace.size(0),
                 ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
             ))
         else:  # inference
@@ -307,7 +316,7 @@
                 fn.y_descs, ctypes.c_void_p(y.data_ptr()),
                 fn.hy_desc, ctypes.c_void_p(hy.data_ptr()),
                 fn.cy_desc, ctypes.c_void_p(cy.data_ptr()) if cx is not None else None,
-                ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0)
+                ctypes.c_void_p(workspace.data_ptr()), workspace.size(0)
             ))
 
         if fn.batch_first and not is_input_packed:
@@ -372,6 +381,8 @@
         if not dhy.is_cuda or not dy.is_cuda or (dcy is not None and not dcy.is_cuda):
             raise RuntimeError('Gradients aren\'t CUDA tensors')
 
+        with torch.cuda.device_of(input):
+            workspace = torch.cuda.ByteTensor(fn.workspace_size)
         check_error(cudnn.lib.cudnnRNNBackwardData(
             handle,
             fn.rnn_desc,
@@ -386,7 +397,7 @@
             fn.x_descs, ctypes.c_void_p(dx.data_ptr()),
             fn.hx_desc, ctypes.c_void_p(dhx.data_ptr()),
             fn.cx_desc, ctypes.c_void_p(dcx.data_ptr()) if cx is not None else None,
-            ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0),
+            ctypes.c_void_p(workspace.data_ptr()), workspace.size(0),
             ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
         ))
 
@@ -439,6 +450,8 @@
         y = output
         dw = fn.weight_buf.new().resize_as_(fn.weight_buf).zero_()
 
+        with torch.cuda.device_of(input):
+            workspace = torch.cuda.ByteTensor(fn.workspace_size)
         check_error(cudnn.lib.cudnnRNNBackwardWeights(
             handle,
             fn.rnn_desc,
@@ -446,7 +459,7 @@
             fn.x_descs, ctypes.c_void_p(x.data_ptr()),
             fn.hx_desc, ctypes.c_void_p(hx.data_ptr()),
             fn.y_descs, ctypes.c_void_p(y.data_ptr()),
-            ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0),
+            ctypes.c_void_p(workspace.data_ptr()), workspace.size(0),
             fn.w_desc, ctypes.c_void_p(dw.data_ptr()),
             ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
         ))
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 6881a2d..0fbbaf2 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -1,3 +1,4 @@
+import warnings
 from torch.autograd import Function, NestedIOFunction, Variable
 import torch.backends.cudnn as cudnn
 from .. import functional as F
@@ -207,7 +208,7 @@
 
 def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
                 dropout=0, train=True, bidirectional=False, batch_sizes=None,
-                dropout_state=None):
+                dropout_state=None, flat_weight=None):
 
     if mode == 'RNN_RELU':
         cell = RNNReLUCell
@@ -254,7 +255,7 @@
 
     def __init__(self, mode, input_size, hidden_size, num_layers=1,
                  batch_first=False, dropout=0, train=True, bidirectional=False,
-                 batch_sizes=None, dropout_state=None):
+                 batch_sizes=None, dropout_state=None, flat_weight=None):
         super(CudnnRNN, self).__init__()
         if dropout_state is None:
             dropout_state = {}
@@ -271,9 +272,16 @@
         self.batch_sizes = batch_sizes
         self.dropout_seed = torch.IntTensor(1).random_()[0]
         self.dropout_state = dropout_state
+        self.weight_buf = flat_weight
+        if flat_weight is None:
+            warnings.warn("RNN module weights are not part of a single contiguous "
+                          "chunk of memory. This means they need to be compacted "
+                          "at every call, possibly greatly increasing memory usage. "
+                          "To compact weights again, call flatten_parameters().", stacklevel=5)
 
     def forward_extended(self, input, weight, hx):
         assert cudnn.is_acceptable(input)
+        # TODO: raise a warning here if weight_buf is None
 
         output = input.new()
 
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 8971880..ae9f515 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -1,5 +1,6 @@
 import math
 import torch
+import warnings
 
 from .module import Module
 from ..parameter import Parameter
@@ -23,36 +24,98 @@
         self.bidirectional = bidirectional
         num_directions = 2 if bidirectional else 1
 
+        if mode == 'LSTM':
+            gate_size = 4 * hidden_size
+        elif mode == 'GRU':
+            gate_size = 3 * hidden_size
+        else:
+            gate_size = hidden_size
+
         self._all_weights = []
+        self._param_buf_size = 0
         for layer in range(num_layers):
             for direction in range(num_directions):
                 layer_input_size = input_size if layer == 0 else hidden_size * num_directions
-                if mode == 'LSTM':
-                    gate_size = 4 * hidden_size
-                elif mode == 'GRU':
-                    gate_size = 3 * hidden_size
-                else:
-                    gate_size = hidden_size
 
                 w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                 w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                 b_ih = Parameter(torch.Tensor(gate_size))
                 b_hh = Parameter(torch.Tensor(gate_size))
+                layer_params = (w_ih, w_hh, b_ih, b_hh)
 
                 suffix = '_reverse' if direction == 1 else ''
-                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
-                weights = [x.format(layer, suffix) for x in weights]
-                setattr(self, weights[0], w_ih)
-                setattr(self, weights[1], w_hh)
+                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                 if bias:
-                    setattr(self, weights[2], b_ih)
-                    setattr(self, weights[3], b_hh)
-                    self._all_weights += [weights]
-                else:
-                    self._all_weights += [weights[:2]]
+                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
+                param_names = [x.format(layer, suffix) for x in param_names]
 
+                for name, param in zip(param_names, layer_params):
+                    setattr(self, name, param)
+                self._all_weights.append(param_names)
+
+                self._param_buf_size += sum(p.numel() for p in layer_params)
+
+        self.flatten_parameters()
         self.reset_parameters()
 
+    def flatten_parameters(self):
+        """Resets parameter data pointer so that they can use faster code paths.
+
+        Right now, this works only if the module is on the GPU and cuDNN is enabled.
+        Otherwise, it's a no-op.
+        """
+        any_param = next(self.parameters()).data
+        if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
+            self._data_ptrs = []
+            return
+
+        # This is quite ugly, but it allows us to reuse the cuDNN code without major
+        # modifications. It's really a low-level API that doesn't belong here, but
+        # let's make an exception.
+        from torch.backends.cudnn import rnn
+        from torch.backends import cudnn
+        from torch.nn._functions.rnn import CudnnRNN
+        handle = cudnn.get_handle()
+        with warnings.catch_warnings(record=True):
+            fn = CudnnRNN(
+                self.mode,
+                self.input_size,
+                self.hidden_size,
+                num_layers=self.num_layers,
+                batch_first=self.batch_first,
+                dropout=self.dropout,
+                train=self.training,
+                bidirectional=self.bidirectional,
+                dropout_state=self.dropout_state,
+            )
+
+        # Initialize descriptors
+        fn.datatype = cudnn._typemap[any_param.type()]
+        fn.x_descs = cudnn.descriptor(any_param.new(1, self.input_size), 1)
+        fn.rnn_desc = rnn.init_rnn_descriptor(fn, handle)
+
+        # Allocate buffer to hold the weights
+        num_weights = rnn.get_num_weights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
+        fn.weight_buf = any_param.new(num_weights).zero_()
+        fn.w_desc = rnn.init_weight_descriptor(fn, fn.weight_buf)
+
+        # Slice off views into weight_buf
+        params = rnn.get_parameters(fn, handle, fn.weight_buf)
+        all_weights = [[p.data for p in l] for l in self.all_weights]
+
+        # Copy weights and update their storage
+        rnn._copyParams(all_weights, params)
+        for orig_layer_param, new_layer_param in zip(all_weights, params):
+            for orig_param, new_param in zip(orig_layer_param, new_layer_param):
+                orig_param.set_(new_param.view_as(orig_param))
+
+        self._data_ptrs = list(p.data.data_ptr() for p in self.parameters())
+
+    def _apply(self, fn):
+        ret = super(RNNBase, self)._apply(fn)
+        self.flatten_parameters()
+        return ret
+
     def reset_parameters(self):
         stdv = 1.0 / math.sqrt(self.hidden_size)
         for weight in self.parameters():
@@ -76,6 +139,13 @@
             if self.mode == 'LSTM':
                 hx = (hx, hx)
 
+        has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
+        if has_flat_weights:
+            first_data = next(self.parameters()).data
+            assert first_data.storage().size() == self._param_buf_size
+            flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
+        else:
+            flat_weight = None
         func = self._backend.RNN(
             self.mode,
             self.input_size,
@@ -86,7 +156,8 @@
             train=self.training,
             bidirectional=self.bidirectional,
             batch_sizes=batch_sizes,
-            dropout_state=self.dropout_state
+            dropout_state=self.dropout_state,
+            flat_weight=flat_weight
         )
         output, hidden = func(input, self.all_weights, hx)
         if is_packed: