put dropout states on the input device (#7515)
* put dropout states on the input device
* add assert to ATen, add test, fix lint
* only assert device if states are defined
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index e171da2..064a4f46 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -1,4 +1,5 @@
#include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
#include <ATen/Config.h>
#include <ATen/Error.h>
#include <ATen/MatrixRef.h>
@@ -586,7 +587,11 @@
auto input = input_r;
auto weight_buf = weight_buf_r;
-
+ if (fn_dropout_state.defined()) {
+ auto input_arg = TensorArg(input, "input", 1);
+ auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15);
+ checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
+ }
RNNParams fn;
fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
diff --git a/test/test_nn.py b/test/test_nn.py
index fa398fc..91dd596 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3381,6 +3381,15 @@
(hx + cx).sum().backward()
+ @unittest.skipIf(not (TEST_CUDNN and TEST_MULTIGPU), 'CUDNN or multi-gpu not available')
+ def test_cudnn_rnn_dropout_states_device(self):
+ rnn = nn.RNN(10, 20, num_layers=2, dropout=.5)
+ device = 1
+ input = torch.randn(5, 4, 10).cuda(device)
+ rnn.cuda(device)
+ hx = torch.randn(2, 4, 20).cuda(device)
+ output = rnn(input, hx)
+
@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
def test_cudnn_weight_format(self):
rnns = [
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index fc6a82d..2fe44ee 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -271,8 +271,9 @@
cx = None
handle = cudnn.get_handle()
- dropout_ts = cudnn.rnn.init_dropout_state(torch.uint8, torch.device('cuda'), dropout,
- train, dropout_seed, dropout_state)
+ with torch.cuda.device(input.get_device()):
+ dropout_ts = cudnn.rnn.init_dropout_state(torch.uint8, torch.device('cuda'), dropout,
+ train, dropout_seed, dropout_state)
weight_arr = list(itertools.chain.from_iterable(weight))
weight_stride0 = len(weight[0])