Set correct math type for cuDNN RNN (#16825)
Summary:
Set the correct math type for the cuDNN RNN descriptor, which is enforced starting from cuDNN 7.5+:
1. Update the persistent-RNN check to use the input data type instead of the RNN math type;
2. Update RNN type promotion to set the correct math type for accumulation (see the sketch below);
3. Replace the data type check for the filter descriptor, using the input's data type instead of rnn.datatype.
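For context, a minimal sketch of the promotion rule this change applies (assuming cuDNN headers are available; the standalone main() and its printout are illustrative only and not part of the patch): half-precision inputs get a float math/accumulation type on the RNN descriptor, while the input data type itself stays half so the tensor-op and persistent-RNN checks key off the actual input.

    #include <cudnn.h>
    #include <cstdio>

    // Mirrors the helper added in RNN.cpp: accumulate in float for half inputs,
    // pass every other data type through unchanged.
    static cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
      if (dtype == CUDNN_DATA_HALF) {
        return CUDNN_DATA_FLOAT;
      }
      return dtype;
    }

    int main() {
      cudnnDataType_t input_type = CUDNN_DATA_HALF;                   // dtype of the input tensor
      cudnnDataType_t math_type  = promote_rnn_math_type(input_type); // dtype used for accumulation
      // The RNN descriptor is then set with both values: math_type as the descriptor's
      // data type, input_type to decide CUDNN_TENSOR_OP_MATH and persistent RNN use.
      std::printf("input=%d math=%d\n", static_cast<int>(input_type), static_cast<int>(math_type));
      return 0;
    }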
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16825
Differential Revision: D14071190
Pulled By: ezyang
fbshipit-source-id: 1c9a1531ccf510cb0619e830be444c20c5e72f3f
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 2705afc..10d5ccb 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -223,7 +223,7 @@
DropoutDescriptor dropout_desc_;
void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc,
cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
- cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnRNNAlgo_t algo) {
+ cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo) {
dropout_desc_ = std::move(dropout_desc);
AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
handle,
@@ -239,7 +239,7 @@
#if CUDA_VERSION >= 9000
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7) {
- if (datatype == CUDNN_DATA_HALF) {
+ if (input_type == CUDNN_DATA_HALF) {
cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
} else {
// Technically, as the default it's not necessary to explicitly
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index 865d967..39e0e1c 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -99,6 +99,7 @@
cudnnDirectionMode_t bidirectional;
cudnnRNNMode_t mode;
cudnnDataType_t datatype;
+ cudnnDataType_t input_datatype;
cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
@@ -137,18 +138,19 @@
this->algo = algo;
}
- void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype) {
+ void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype, cudnnDataType_t input_datatype) {
this->set_mode(mode);
this->hidden_size = hidden_size;
this->num_layers = num_layers;
this->set_bidirectional(bidirectional);
this->datatype = datatype;
+ this->input_datatype = input_datatype;
}
RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
RNNDescriptor rnn_desc;
- rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, algo);
+ rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo);
return rnn_desc;
}
@@ -448,7 +450,7 @@
AT_ASSERTM(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim);
filter_dim_a = filter_dim_a.slice(0, 0, nb_dims);
- auto elem_size = dataSize(rnn.datatype);
+ auto elem_size = dataSize(getCudnnDataType(weight_buf));
auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr();
AT_ASSERTM(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size);
size_t offset = offset_bytes / elem_size;
@@ -575,14 +577,14 @@
}
}
- cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors){
+ cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input){
#if CUDNN_VERSION < 7200 || CUDA_VERSION < 9010
return CUDNN_RNN_ALGO_STANDARD;
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
const int64_t bsize = tensors.mini_batch;
//excluding Turing from using persistent rnn.
- if (prop->major == 7 && prop->minor != 5 && rnn.datatype == CUDNN_DATA_HALF && !tensors.is_input_packed()) {
+ if (prop->major == 7 && prop->minor != 5 && getCudnnDataType(input) == CUDNN_DATA_HALF && !tensors.is_input_packed()) {
if (rnn.num_layers == 1 && rnn.hidden_size <= 1024 && rnn.num_directions() == 1 &&
rnn.hidden_size % 128 == 0 && tensors.input_size % 128 == 0){
//technically, batch size should be multiple of 8, but there are quite a few multiple-of-8 batchsizes that give bad perf,
@@ -600,6 +602,17 @@
#endif
}
+ cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
+#if CUDNN_VERSION != 7103
+// CUDNN 7.1.3 enforces RNN descriptor type to be identical to input/weight. This check throws an error for type
+// promotion. The check has since been removed.
+ if (dtype == CUDNN_DATA_HALF) {
+ return CUDNN_DATA_FLOAT;
+ }
+#endif
+ return dtype;
+ }
+
} // anonymous namespace
// NB: does inplace update into TensorList
@@ -618,9 +631,10 @@
"_cudnn_rnn_flatten_weight_: cannot flatten empty weight list");
auto any_param = weight_arr[0];
+ auto datatype = getCudnnDataType(any_param);
RNNDescriptorParams rnn;
- rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(any_param));
+ rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
auto handle = getCudnnHandle();
RNNDescriptor rnn_desc = rnn.descriptor(handle);
@@ -629,7 +643,7 @@
TensorDescriptor x_desc;
x_desc.set(getCudnnDataType(any_param), x_geom.sizes(), x_geom.strides(), 5);
- auto num_weights = get_num_weights(handle, rnn_desc, x_desc, rnn.datatype);
+ auto num_weights = get_num_weights(handle, rnn_desc, x_desc, datatype);
auto weight_buf = at::zeros(num_weights, any_param.options());
FilterDescriptor w_desc;
@@ -679,7 +693,8 @@
checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
}
RNNParams fn;
- fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
+ auto datatype = getCudnnDataType(input);
+ fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
@@ -716,13 +731,13 @@
auto y = output;
auto handle = getCudnnHandle();
- cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors);
+ cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
fn.rnn.set_algo(algo);
RNNDescriptors descs(fn, handle, x, y, hx, cx);
FilterDescriptor w_desc;
if (!weight_buf.defined()) {
- auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], fn.rnn.datatype);
+ auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
weight_buf = at::empty(num_weights, x.options());
w_desc.set(weight_buf, 3);
weight_buf.zero_();
@@ -818,7 +833,8 @@
auto output = output_r;
RNNParams fn;
- fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
+ auto datatype = getCudnnDataType(input);
+ fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
@@ -877,7 +893,7 @@
AT_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
"Gradients aren't CUDA tensors");
- cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors);
+ cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
fn.rnn.set_algo(algo);
RNNDescriptors descs(fn, handle, x, y, hx, cx);
@@ -941,7 +957,8 @@
auto output = output_r;
RNNParams fn;
- fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
+ auto datatype = getCudnnDataType(input);
+ fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
@@ -981,7 +998,7 @@
const auto& y = output;
auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());
- cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors);
+ cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
fn.rnn.set_algo(algo);
RNNDescriptors descs(fn, handle, x, y, hx, cx);
@@ -1162,7 +1179,7 @@
auto datatype = getCudnnDataType(input);
RNNDescriptorParams rnn;
- rnn.set(mode, hidden_size, num_layers, bidirectional, datatype);
+ rnn.set(mode, hidden_size, num_layers, bidirectional, promote_rnn_math_type(datatype), datatype);
RNNDescriptor rnn_desc = rnn.descriptor(handle);
TensorGeometry x_geom ({1, input.size(-1)});