make cudnn rnn respect current stream (#27026)

Summary:
Make the cuDNN RNN code respect the current stream. After this lands, the non-default test stream can be re-enabled in https://github.com/pytorch/pytorch/issues/26791
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27026

Test Plan: Default-stream functionality is covered by existing tests; stream-safety tests will be added in https://github.com/pytorch/pytorch/issues/26791

Differential Revision: D17656967

Pulled By: ngimel

fbshipit-source-id: 8b051aedd1df089b21f666ec553a5acefffdac88
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 4da4481..62a9a18 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -4,6 +4,7 @@
 #include <ATen/cuda/Exceptions.h>
 
 #include <ATen/cudnn/cudnn-wrapper.h>
+#include <ATen/cudnn/Utils.h>
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/cuda/ATenCUDAGeneral.h>
@@ -190,6 +191,7 @@
     AT_ASSERT(options.device().type() == kCUDA);
     AT_ASSERT(options.dtype() == kByte);
     state = at::empty({static_cast<int64_t>(state_size)}, options);
+    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
   }
 
@@ -200,6 +202,7 @@
     void *state_ptr = state.data_ptr();
     size_t state_size = state.size(0);
     // NB: The seed doesn't actually matter, so we give a dummy value
+    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
   }
 
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index 6ddacca..1a8b890 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -778,6 +778,7 @@
           &reserve_size
           ));
     reserve = at::empty(reserve_size, input.options().dtype(kByte));
+    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnRNNForwardTraining(
           handle,
           descs.rnn_desc.desc(),
@@ -794,6 +795,7 @@
           ));
   } else { // inference
     reserve = at::empty({0}, input.options().dtype(kByte));
+    setCuDNNStreamToCurrent();
     AT_CUDNN_CHECK(cudnnRNNForwardInference(
           handle,
           descs.rnn_desc.desc(),
@@ -912,7 +914,7 @@
         ));
   // TODO: put this in the correct device???
   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
-
+  setCuDNNStreamToCurrent();
   AT_CUDNN_CHECK(cudnnRNNBackwardData(
         handle,
         descs.rnn_desc.desc(),
@@ -1016,7 +1018,7 @@
         &workspace_size
         ));
   Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
-
+  setCuDNNStreamToCurrent();
   AT_CUDNN_CHECK(cudnnRNNBackwardWeights(
         handle,
         descs.rnn_desc.desc(),