Initial support for building on Ampere GPU, CUDA 11, cuDNN 8 (#39277)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39277
This PR contains initial changes that make PyTorch build with Ampere GPU, CUDA 11, and cuDNN 8.
TF32 related features will not be included in this PR.
Test Plan: Imported from OSS
Differential Revision: D21832814
Pulled By: malfet
fbshipit-source-id: 37f9c6827e0c26ae3e303580f666584230832d06
diff --git a/caffe2/operators/rnn/recurrent_op_cudnn.cc b/caffe2/operators/rnn/recurrent_op_cudnn.cc
index 8f69944..3679c9d 100644
--- a/caffe2/operators/rnn/recurrent_op_cudnn.cc
+++ b/caffe2/operators/rnn/recurrent_op_cudnn.cc
@@ -99,7 +99,7 @@
// RNN setup
{
#if CUDNN_VERSION_MIN(7, 0, 0)
- CUDNN_ENFORCE(cudnnSetRNNDescriptor(
+ CUDNN_ENFORCE(cudnnSetRNNDescriptor_v6(
cudnn_wrapper_.inline_cudnn_handle(),
rnnDesc_,
hiddenSize,
diff --git a/caffe2/utils/GpuDefs.cuh b/caffe2/utils/GpuDefs.cuh
index 75897e8..46d8058 100644
--- a/caffe2/utils/GpuDefs.cuh
+++ b/caffe2/utils/GpuDefs.cuh
@@ -8,7 +8,7 @@
// Static definition of GPU warp size for unrolling and code generation
#ifdef __CUDA_ARCH__
-#if __CUDA_ARCH__ <= 750
+#if __CUDA_ARCH__ <= 800
constexpr int kWarpSize = 32;
#else
#error Unknown __CUDA_ARCH__; please define parameters for compute capability
diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake
index 272b6fa..873dd58 100644
--- a/cmake/public/cuda.cmake
+++ b/cmake/public/cuda.cmake
@@ -145,7 +145,11 @@
# ---[ Extract versions
if(CAFFE2_USE_CUDNN)
# Get cuDNN version
- file(READ ${CUDNN_INCLUDE_PATH}/cudnn.h CUDNN_HEADER_CONTENTS)
+ if(EXISTS ${CUDNN_INCLUDE_PATH}/cudnn_version.h)
+ file(READ ${CUDNN_INCLUDE_PATH}/cudnn_version.h CUDNN_HEADER_CONTENTS)
+ else()
+ file(READ ${CUDNN_INCLUDE_PATH}/cudnn.h CUDNN_HEADER_CONTENTS)
+ endif()
string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)"
CUDNN_VERSION_MAJOR "${CUDNN_HEADER_CONTENTS}")
string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1"
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 622f722..9f2fcf1 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -1372,10 +1372,11 @@
('Pascal', '6.0;6.1+PTX'),
('Volta', '7.0+PTX'),
('Turing', '7.5+PTX'),
+ ('Ampere', '8.0+PTX'),
])
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
- '7.0', '7.2', '7.5']
+ '7.0', '7.2', '7.5', '8.0']
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
# The default is sm_30 for CUDA 9.x and 10.x