Improve memory usage of cuDNN RNN modules (#2179)
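This change keeps each RNN module's cuDNN weights in one contiguous buffer: flatten_parameters() (called from __init__ and from _apply(), i.e. after .cuda() and friends) copies the per-layer weights into a single cuDNN weight buffer and re-points the Parameters at views of it, so the cuDNN bindings can reuse that buffer instead of allocating and filling a fresh one on every forward. The cuDNN workspace is likewise no longer cached on the autograd function; only its size is kept and a temporary is allocated per call. A minimal sketch of the user-facing contract, assuming a CUDA build with cuDNN available:

    import pickle
    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    rnn = nn.LSTM(10, 20, batch_first=True).cuda()  # .cuda() triggers flatten_parameters()
    x = Variable(torch.randn(5, 4, 10).cuda())
    out, _ = rnn(x)                                  # fast path: weights are one chunk

    rnn2 = pickle.loads(pickle.dumps(rnn))           # unpickling recreates the storages
    rnn2.flatten_parameters()                        # re-compact, otherwise a warning is raised
    out2, _ = rnn2(x)
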
diff --git a/test/test_nn.py b/test/test_nn.py
index 1cac377..df2b04b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1836,6 +1836,42 @@
(hx + cx).sum().backward()
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
+ def test_LSTM_cudnn_weight_format(self):
+ rnn = nn.LSTM(10, 20, batch_first=True).cuda()
+ input = Variable(torch.randn(5, 4, 10).cuda(), requires_grad=True)
+ hx = Variable(torch.randn(1, 5, 20).cuda(), requires_grad=True)
+ cx = Variable(torch.randn(1, 5, 20).cuda(), requires_grad=True)
+ all_vars = [input, hx, cx] + list(rnn.parameters())
+
+ output = rnn(input, (hx, cx))
+ output[0].sum().backward()
+ grads = [v.grad.data.clone() for v in all_vars]
+ for v in all_vars:
+ v.grad.data.zero_()
+
+ # Weights will no longer view onto the same chunk of memory
+ weight = all_vars[4]
+ weight_data = weight.data.clone()
+ weight.data.set_(weight_data)
+
+ for i in range(2):
+ with warnings.catch_warnings(record=True) as w:
+ output_noncontig = rnn(input, (hx, cx))
+ if i == 0:
+ self.assertEqual(len(w), 1)
+ self.assertIn('weights are not part of single contiguous chunk of memory', w[0].message.args[0])
+ output_noncontig[0].sum().backward()
+ grads_noncontig = [v.grad.data.clone() for v in all_vars]
+ for v in all_vars:
+ v.grad.data.zero_()
+ self.assertEqual(output, output_noncontig)
+ self.assertEqual(grads_noncontig, grads)
+
+ # Make sure these still share storage
+ weight_data[:] = 4
+ self.assertEqual(weight_data, all_vars[4].data)
+
+ @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
def test_cuda_rnn_fused(self):
def copy_rnn(rnn1, rnn2):
for x_layer, y_layer in zip(rnn1.all_weights, rnn2.all_weights):
@@ -2165,6 +2201,7 @@
rnn_pickle = pickle.dumps(rnn)
rnn2 = pickle.loads(rnn_pickle)
+ rnn2.flatten_parameters()
output3, hy3 = rnn2(input, hx)
if p == 0 or not train:
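
The flatten_parameters() call added to the pickle test is needed because unpickling rebuilds every parameter with its own storage, so the data pointers cached in _data_ptrs no longer match and forward() would fall back to compacting the weights on every call (emitting the warning tested above). A sketch of what changes, poking at the private _data_ptrs bookkeeping purely for illustration (CUDA and cuDNN assumed):

    import pickle
    import torch.nn as nn

    rnn = nn.LSTM(10, 20).cuda()
    rnn2 = pickle.loads(pickle.dumps(rnn))

    new_ptrs = [p.data.data_ptr() for p in rnn2.parameters()]
    assert new_ptrs != rnn2._data_ptrs      # _data_ptrs still holds the pre-pickle pointers

    rnn2.flatten_parameters()               # re-alias all parameters into one buffer
    assert [p.data.data_ptr() for p in rnn2.parameters()] == rnn2._data_ptrs
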
diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py
index 636e973..51da709 100644
--- a/torch/backends/cudnn/rnn.py
+++ b/torch/backends/cudnn/rnn.py
@@ -178,7 +178,10 @@
def _copyParams(params_from, params_to):
+ assert len(params_from) == len(params_to)
for layer_params_from, layer_params_to in zip(params_from, params_to):
+ # NOTE: these lists have all weights before all biases, so if the layer doesn't
+ # use biases, zip will terminate once layer_params_from ends and ignore them.
for param_from, param_to in zip(layer_params_from, layer_params_to):
assert param_from.type() == param_to.type()
param_to.copy_(param_from, broadcast=False)
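
The NOTE above leans on zip() stopping at the shorter sequence: when a layer was built with bias=False, its list holds only the two weight tensors, so the bias views of the cuDNN buffer are simply never written (they stay at the zeros the buffer was initialized with). A toy illustration, not tied to the real cuDNN layout:

    layer_params_from = ['w_ih', 'w_hh']                 # module side, bias=False
    layer_params_to = ['w_ih', 'w_hh', 'b_ih', 'b_hh']   # cuDNN side also exposes bias views
    list(zip(layer_params_from, layer_params_to))
    # -> [('w_ih', 'w_ih'), ('w_hh', 'w_hh')]  (the bias views are never touched)
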
@@ -242,17 +245,21 @@
fn.cy_desc = cudnn.descriptor(cx) if cx is not None else None
# create the weight buffer and copy the weights into it
- num_weights = get_num_weights(
- handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
- fn.weight_buf = x.new(num_weights)
- fn.w_desc = init_weight_descriptor(fn, fn.weight_buf)
- w = fn.weight_buf
- # this zero might not seem necessary, but it is in the case
- # where biases are disabled; then they won't be copied and must be zero'd.
- # Alternatively, _copyParams could be written more carefully.
- w.zero_()
- params = get_parameters(fn, handle, w)
- _copyParams(weight, params)
+ if fn.weight_buf is None:
+ num_weights = get_num_weights(
+ handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
+ fn.weight_buf = x.new(num_weights)
+ fn.w_desc = init_weight_descriptor(fn, fn.weight_buf)
+ w = fn.weight_buf
+ # this zero might not seem necessary, but it is in the case
+ # where biases are disabled; then they won't be copied and must be zero'd.
+ # Alternatively, _copyParams could be written more carefully.
+ w.zero_()
+ params = get_parameters(fn, handle, w)
+ _copyParams(weight, params)
+ else:
+ fn.w_desc = init_weight_descriptor(fn, fn.weight_buf)
+ w = fn.weight_buf
if tuple(hx.size()) != hidden_size:
raise RuntimeError('Expected hidden size {}, got {}'.format(
@@ -269,7 +276,9 @@
fn.x_descs,
ctypes.byref(workspace_size)
))
- fn.workspace = torch.cuda.ByteTensor(workspace_size.value)
+ fn.workspace_size = workspace_size.value
+ with torch.cuda.device_of(input):
+ workspace = torch.cuda.ByteTensor(fn.workspace_size)
if fn.requires_grad:
reserve_size = ctypes.c_long()
check_error(lib.cudnnGetRNNTrainingReserveSize(
@@ -292,7 +301,7 @@
fn.y_descs, ctypes.c_void_p(y.data_ptr()),
fn.hy_desc, ctypes.c_void_p(hy.data_ptr()),
fn.cy_desc, ctypes.c_void_p(cy.data_ptr()) if cx is not None else None,
- ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0),
+ ctypes.c_void_p(workspace.data_ptr()), workspace.size(0),
ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
))
else: # inference
@@ -307,7 +316,7 @@
fn.y_descs, ctypes.c_void_p(y.data_ptr()),
fn.hy_desc, ctypes.c_void_p(hy.data_ptr()),
fn.cy_desc, ctypes.c_void_p(cy.data_ptr()) if cx is not None else None,
- ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0)
+ ctypes.c_void_p(workspace.data_ptr()), workspace.size(0)
))
if fn.batch_first and not is_input_packed:
@@ -372,6 +381,8 @@
if not dhy.is_cuda or not dy.is_cuda or (dcy is not None and not dcy.is_cuda):
raise RuntimeError('Gradients aren\'t CUDA tensors')
+ with torch.cuda.device_of(input):
+ workspace = torch.cuda.ByteTensor(fn.workspace_size)
check_error(cudnn.lib.cudnnRNNBackwardData(
handle,
fn.rnn_desc,
@@ -386,7 +397,7 @@
fn.x_descs, ctypes.c_void_p(dx.data_ptr()),
fn.hx_desc, ctypes.c_void_p(dhx.data_ptr()),
fn.cx_desc, ctypes.c_void_p(dcx.data_ptr()) if cx is not None else None,
- ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0),
+ ctypes.c_void_p(workspace.data_ptr()), workspace.size(0),
ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
))
@@ -439,6 +450,8 @@
y = output
dw = fn.weight_buf.new().resize_as_(fn.weight_buf).zero_()
+ with torch.cuda.device_of(input):
+ workspace = torch.cuda.ByteTensor(fn.workspace_size)
check_error(cudnn.lib.cudnnRNNBackwardWeights(
handle,
fn.rnn_desc,
@@ -446,7 +459,7 @@
fn.x_descs, ctypes.c_void_p(x.data_ptr()),
fn.hx_desc, ctypes.c_void_p(hx.data_ptr()),
fn.y_descs, ctypes.c_void_p(y.data_ptr()),
- ctypes.c_void_p(fn.workspace.data_ptr()), fn.workspace.size(0),
+ ctypes.c_void_p(workspace.data_ptr()), workspace.size(0),
fn.w_desc, ctypes.c_void_p(dw.data_ptr()),
ctypes.c_void_p(fn.reserve.data_ptr()), fn.reserve.size(0)
))
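
Across these hunks the workspace stops being cached on fn: only fn.workspace_size survives between calls, and each forward/backward allocates a fresh ByteTensor on the input's device, so the memory is handed back to the caching allocator as soon as the call returns. A small sketch of the allocation pattern, using a hypothetical make_workspace helper and placeholder sizes:

    import torch

    def make_workspace(reference_tensor, workspace_size):
        # device_of is a context manager that temporarily switches the current
        # CUDA device to the one reference_tensor lives on.
        with torch.cuda.device_of(reference_tensor):
            return torch.cuda.ByteTensor(workspace_size)

    # x = torch.randn(16, 32).cuda(1)     # e.g. an input living on GPU 1
    # ws = make_workspace(x, 1 << 20)     # 1 MiB of scratch allocated on GPU 1
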
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 6881a2d..0fbbaf2 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -1,3 +1,4 @@
+import warnings
from torch.autograd import Function, NestedIOFunction, Variable
import torch.backends.cudnn as cudnn
from .. import functional as F
@@ -207,7 +208,7 @@
def AutogradRNN(mode, input_size, hidden_size, num_layers=1, batch_first=False,
dropout=0, train=True, bidirectional=False, batch_sizes=None,
- dropout_state=None):
+ dropout_state=None, flat_weight=None):
if mode == 'RNN_RELU':
cell = RNNReLUCell
@@ -254,7 +255,7 @@
def __init__(self, mode, input_size, hidden_size, num_layers=1,
batch_first=False, dropout=0, train=True, bidirectional=False,
- batch_sizes=None, dropout_state=None):
+ batch_sizes=None, dropout_state=None, flat_weight=None):
super(CudnnRNN, self).__init__()
if dropout_state is None:
dropout_state = {}
@@ -271,9 +272,16 @@
self.batch_sizes = batch_sizes
self.dropout_seed = torch.IntTensor(1).random_()[0]
self.dropout_state = dropout_state
+ self.weight_buf = flat_weight
+ if flat_weight is None:
+ warnings.warn("RNN module weights are not part of single contiguous "
+ "chunk of memory. This means they need to be compacted "
+ "at every call, possibly greately increasing memory usage. "
+ "To compact weights again call flatten_parameters().", stacklevel=5)
def forward_extended(self, input, weight, hx):
assert cudnn.is_acceptable(input)
+ # TODO: raise a warning if weight_data_ptr is None
output = input.new()
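
The warning fires from CudnnRNN.__init__ whenever the module could not hand down a flat weight buffer, which on the cuDNN path means every forward call whose weights are no longer one contiguous chunk. If you want such calls to fail loudly in a test suite (or to silence them), the ordinary warnings machinery applies; nothing here is a new API:

    import warnings

    # Turn the slow-path warning into an exception...
    warnings.filterwarnings(
        'error', message='RNN module weights are not part of single contiguous')

    # ...or suppress it instead:
    # warnings.filterwarnings(
    #     'ignore', message='RNN module weights are not part of single contiguous')
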
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 8971880..ae9f515 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -1,5 +1,6 @@
import math
import torch
+import warnings
from .module import Module
from ..parameter import Parameter
@@ -23,36 +24,98 @@
self.bidirectional = bidirectional
num_directions = 2 if bidirectional else 1
+ if mode == 'LSTM':
+ gate_size = 4 * hidden_size
+ elif mode == 'GRU':
+ gate_size = 3 * hidden_size
+ else:
+ gate_size = hidden_size
+
self._all_weights = []
+ self._param_buf_size = 0
for layer in range(num_layers):
for direction in range(num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * num_directions
- if mode == 'LSTM':
- gate_size = 4 * hidden_size
- elif mode == 'GRU':
- gate_size = 3 * hidden_size
- else:
- gate_size = hidden_size
w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
b_ih = Parameter(torch.Tensor(gate_size))
b_hh = Parameter(torch.Tensor(gate_size))
+ layer_params = (w_ih, w_hh, b_ih, b_hh)
suffix = '_reverse' if direction == 1 else ''
- weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
- weights = [x.format(layer, suffix) for x in weights]
- setattr(self, weights[0], w_ih)
- setattr(self, weights[1], w_hh)
+ param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
if bias:
- setattr(self, weights[2], b_ih)
- setattr(self, weights[3], b_hh)
- self._all_weights += [weights]
- else:
- self._all_weights += [weights[:2]]
+ param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
+ param_names = [x.format(layer, suffix) for x in param_names]
+ for name, param in zip(param_names, layer_params):
+ setattr(self, name, param)
+ self._all_weights.append(param_names)
+
+ self._param_buf_size += sum(p.numel() for p in layer_params)
+
+ self.flatten_parameters()
self.reset_parameters()
+ def flatten_parameters(self):
+ """Resets parameter data pointer so that they can use faster code paths.
+
+ Right now, this works only if the module is on the GPU and cuDNN is enabled.
+ Otherwise, it's a no-op.
+ """
+ any_param = next(self.parameters()).data
+ if not any_param.is_cuda or not torch.backends.cudnn.is_acceptable(any_param):
+ self._data_ptrs = []
+ return
+
+ # This is quite ugly, but it allows us to reuse the cuDNN code without larger
+ # modifications. It's really a low-level API that doesn't belong here, but
+ # let's make an exception.
+ from torch.backends.cudnn import rnn
+ from torch.backends import cudnn
+ from torch.nn._functions.rnn import CudnnRNN
+ handle = cudnn.get_handle()
+ with warnings.catch_warnings(record=True):
+ fn = CudnnRNN(
+ self.mode,
+ self.input_size,
+ self.hidden_size,
+ num_layers=self.num_layers,
+ batch_first=self.batch_first,
+ dropout=self.dropout,
+ train=self.training,
+ bidirectional=self.bidirectional,
+ dropout_state=self.dropout_state,
+ )
+
+ # Initialize descriptors
+ fn.datatype = cudnn._typemap[any_param.type()]
+ fn.x_descs = cudnn.descriptor(any_param.new(1, self.input_size), 1)
+ fn.rnn_desc = rnn.init_rnn_descriptor(fn, handle)
+
+ # Allocate buffer to hold the weights
+ num_weights = rnn.get_num_weights(handle, fn.rnn_desc, fn.x_descs[0], fn.datatype)
+ fn.weight_buf = any_param.new(num_weights).zero_()
+ fn.w_desc = rnn.init_weight_descriptor(fn, fn.weight_buf)
+
+ # Slice off views into weight_buf
+ params = rnn.get_parameters(fn, handle, fn.weight_buf)
+ all_weights = [[p.data for p in l] for l in self.all_weights]
+
+ # Copy weights and update their storage
+ rnn._copyParams(all_weights, params)
+ for orig_layer_param, new_layer_param in zip(all_weights, params):
+ for orig_param, new_param in zip(orig_layer_param, new_layer_param):
+ orig_param.set_(new_param.view_as(orig_param))
+
+ self._data_ptrs = list(p.data.data_ptr() for p in self.parameters())
+
+ def _apply(self, fn):
+ ret = super(RNNBase, self)._apply(fn)
+ self.flatten_parameters()
+ return ret
+
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
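
flatten_parameters() keeps the Parameter objects the module already exposes and only re-points their .data at views of the flat cuDNN buffer (the set_(view_as(...)) loop above), so optimizers, state_dict and existing references keep working while all weights share one storage. The same aliasing trick in isolation, with made-up sizes:

    import torch

    flat = torch.zeros(6)                            # one contiguous buffer
    w = torch.Tensor()                               # stand-ins for parameter .data
    b = torch.Tensor()
    w.set_(flat.storage(), 0, torch.Size([2, 2]))    # first four elements, viewed as 2x2
    b.set_(flat.storage(), 4, torch.Size([2]))       # last two elements

    flat.fill_(1)
    print(w)   # 2x2 tensor of ones -- w aliases flat
    print(b)   # vector of two ones -- so does b
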
@@ -76,6 +139,13 @@
if self.mode == 'LSTM':
hx = (hx, hx)
+ has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs
+ if has_flat_weights:
+ first_data = next(self.parameters()).data
+ assert first_data.storage().size() == self._param_buf_size
+ flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size]))
+ else:
+ flat_weight = None
func = self._backend.RNN(
self.mode,
self.input_size,
@@ -86,7 +156,8 @@
train=self.training,
bidirectional=self.bidirectional,
batch_sizes=batch_sizes,
- dropout_state=self.dropout_state
+ dropout_state=self.dropout_state,
+ flat_weight=flat_weight
)
output, hidden = func(input, self.all_weights, hx)
if is_packed:
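
One consequence of the pointer check in forward() above: ordinary training keeps the fast path, because optimizer steps update parameters in place and their data pointers never change. The flat layout is only lost when a parameter's storage is replaced outright, as the new test does with weight.data.set_(...). A sketch of the distinction, assuming a CUDA build with cuDNN (variable names are made up):

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(10, 20).cuda()
    w = rnn.weight_hh_l0

    w.data.add_(0.01)                        # in-place update: pointer unchanged, fast path kept
    assert w.data.data_ptr() in rnn._data_ptrs

    w.data.set_(w.data.clone())              # new storage: slow path + warning on next forward
    assert w.data.data_ptr() not in rnn._data_ptrs

    rnn.flatten_parameters()                 # re-compact to restore the fast path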