Disable cuDNN persistent RNN on sm_86 devices (#49534)
Summary:
Excludes sm_86 GPU devices from using cuDNN persistent RNN.
This is because cuDNN 8.0.5 hits some hard-to-detect edge cases on these devices (e.g. the NVIDIA A40 GPU) that throw exceptions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49534
Reviewed By: mruberry
Differential Revision: D25632378
Pulled By: mrshenli
fbshipit-source-id: cbe78236d85d4d0c2e4ca63a3fc2c4e2de662d9e
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index 1dc9d5c..8e1f254 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -722,6 +722,11 @@
(tensors.seq_length >=10 && bsize <=32));
}
} else if (prop->major >= 8) {
+ if (prop->minor == 6) {
+    // Excludes sm_86 GPU devices from using persistent RNN.
+    // This is because some edge cases will throw exceptions with cuDNN 8.0.5 on NVIDIA A40 GPUs.
+ return false;
+ }
// Based on tests by Vasily Volkov and xwang233. Vasily only tried bsize <= 128,
// so conservatively enable persistence for bsize <= 128 only.
// TODO: Run more tests for bsize > 128.