put dropout states on the input device (#7515)

* put dropout states on the input device

* add assert to aten, add test, fix lint

* only assert device if states are defined
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index e171da2..064a4f46 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -1,4 +1,5 @@
 #include <ATen/ATen.h>
+#include <ATen/TensorUtils.h>
 #include <ATen/Config.h>
 #include <ATen/Error.h>
 #include <ATen/MatrixRef.h>
@@ -586,7 +587,11 @@
 
   auto input = input_r;
   auto weight_buf = weight_buf_r;
-
+  if (fn_dropout_state.defined()) {
+      auto input_arg = TensorArg(input, "input", 1);
+      auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15);
+      checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
+  }
   RNNParams fn;
   fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
   fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
diff --git a/test/test_nn.py b/test/test_nn.py
index fa398fc..91dd596 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3381,6 +3381,15 @@
 
             (hx + cx).sum().backward()
 
+    @unittest.skipIf(not (TEST_CUDNN and TEST_MULTIGPU), 'CUDNN or multi-gpu not available')
+    def test_cudnn_rnn_dropout_states_device(self):
+        rnn = nn.RNN(10, 20, num_layers=2, dropout=.5)  # nonzero dropout; presumably needed to create dropout state
+        device = 1  # GPU index 1; the skipIf above requires TEST_MULTIGPU
+        input = torch.randn(5, 4, 10).cuda(device)
+        rnn.cuda(device)
+        hx = torch.randn(2, 4, 20).cuda(device)
+        output = rnn(input, hx)  # no assertion follows: the test passes if this forward call does not raise
+
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     def test_cudnn_weight_format(self):
         rnns = [
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index fc6a82d..2fe44ee 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -271,8 +271,9 @@
             cx = None
 
         handle = cudnn.get_handle()
-        dropout_ts = cudnn.rnn.init_dropout_state(torch.uint8, torch.device('cuda'), dropout,
-                                                  train, dropout_seed, dropout_state)
+        with torch.cuda.device(input.get_device()):
+            dropout_ts = cudnn.rnn.init_dropout_state(torch.uint8, torch.device('cuda'), dropout,
+                                                      train, dropout_seed, dropout_state)
 
         weight_arr = list(itertools.chain.from_iterable(weight))
         weight_stride0 = len(weight[0])