make cudnn rnn respect current stream (#27026)
Summary:
Make cudnn rnn respect current stream. After this lands, tests running on a non-default stream can be re-enabled in https://github.com/pytorch/pytorch/issues/26791
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27026
Test Plan: default stream functionality is tested in existing tests, stream safety tests will be added in https://github.com/pytorch/pytorch/issues/26791
Differential Revision: D17656967
Pulled By: ngimel
fbshipit-source-id: 8b051aedd1df089b21f666ec553a5acefffdac88
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 4da4481..62a9a18 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -4,6 +4,7 @@
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/cudnn-wrapper.h>
+#include <ATen/cudnn/Utils.h>
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
@@ -190,6 +191,7 @@
AT_ASSERT(options.device().type() == kCUDA);
AT_ASSERT(options.dtype() == kByte);
state = at::empty({static_cast<int64_t>(state_size)}, options);
+ setCuDNNStreamToCurrent();
AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
}
@@ -200,6 +202,7 @@
void *state_ptr = state.data_ptr();
size_t state_size = state.size(0);
// NB: The seed doesn't actually matter, so we give a dummy value
+ setCuDNNStreamToCurrent();
AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
}
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index 6ddacca..1a8b890 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -778,6 +778,7 @@
&reserve_size
));
reserve = at::empty(reserve_size, input.options().dtype(kByte));
+ setCuDNNStreamToCurrent();
AT_CUDNN_CHECK(cudnnRNNForwardTraining(
handle,
descs.rnn_desc.desc(),
@@ -794,6 +795,7 @@
));
} else { // inference
reserve = at::empty({0}, input.options().dtype(kByte));
+ setCuDNNStreamToCurrent();
AT_CUDNN_CHECK(cudnnRNNForwardInference(
handle,
descs.rnn_desc.desc(),
@@ -912,7 +914,7 @@
));
// TODO: put this in the correct device???
Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
-
+ setCuDNNStreamToCurrent();
AT_CUDNN_CHECK(cudnnRNNBackwardData(
handle,
descs.rnn_desc.desc(),
@@ -1016,7 +1018,7 @@
&workspace_size
));
Tensor workspace = at::empty(workspace_size, input.options().dtype(kByte));
-
+ setCuDNNStreamToCurrent();
AT_CUDNN_CHECK(cudnnRNNBackwardWeights(
handle,
descs.rnn_desc.desc(),