(#16825)

Summary:
Set the correct math type for cuDNN RNN, which cuDNN enforces starting with version 7.5.

1. Update the persistent RNN check to use the input data type instead of the RNN math type;
2. Update the RNN type promotion to set the correct math type for accumulation;
3. Switch the filter descriptor data type check from rnn.datatype to the input data type;
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16825

Differential Revision: D14071190

Pulled By: ezyang

fbshipit-source-id: 1c9a1531ccf510cb0619e830be444c20c5e72f3f
diff --git a/aten/src/ATen/cudnn/Descriptors.h b/aten/src/ATen/cudnn/Descriptors.h
index 2705afc..10d5ccb 100644
--- a/aten/src/ATen/cudnn/Descriptors.h
+++ b/aten/src/ATen/cudnn/Descriptors.h
@@ -223,7 +223,7 @@
   DropoutDescriptor dropout_desc_;
   void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc,
            cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
-           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnRNNAlgo_t algo) {
+           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo) {
     dropout_desc_ = std::move(dropout_desc);
     AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
           handle,
@@ -239,7 +239,7 @@
 #if CUDA_VERSION >= 9000
     cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
     if (prop->major >= 7) {
-      if (datatype == CUDNN_DATA_HALF) {
+      if (input_type == CUDNN_DATA_HALF) {
         cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
       } else {
         // Technically, as the default it's not necessary to explicitly
diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp
index 865d967..39e0e1c 100644
--- a/aten/src/ATen/native/cudnn/RNN.cpp
+++ b/aten/src/ATen/native/cudnn/RNN.cpp
@@ -99,6 +99,7 @@
     cudnnDirectionMode_t bidirectional;
     cudnnRNNMode_t mode;
     cudnnDataType_t datatype;
+    cudnnDataType_t input_datatype;
     cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
     cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
 
@@ -137,18 +138,19 @@
       this->algo = algo;
     }
 
-    void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype) {
+    void set(int64_t mode, int64_t hidden_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype, cudnnDataType_t input_datatype) {
       this->set_mode(mode);
       this->hidden_size = hidden_size;
       this->num_layers = num_layers;
       this->set_bidirectional(bidirectional);
       this->datatype = datatype;
+      this->input_datatype = input_datatype;
     }
 
 
     RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
       RNNDescriptor rnn_desc;
-      rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, algo);
+      rnn_desc.set(handle, hidden_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo);
       return rnn_desc;
     }
 
@@ -448,7 +450,7 @@
 
           AT_ASSERTM(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim  = ", min_dim);
           filter_dim_a = filter_dim_a.slice(0, 0, nb_dims);
-          auto elem_size = dataSize(rnn.datatype);
+          auto elem_size = dataSize(getCudnnDataType(weight_buf));
           auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr();
           AT_ASSERTM(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size);
           size_t offset = offset_bytes / elem_size;
@@ -575,14 +577,14 @@
     }
   }
 
-  cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors){
+  cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor input){
 #if CUDNN_VERSION < 7200 || CUDA_VERSION < 9010
       return CUDNN_RNN_ALGO_STANDARD;
 #else
       cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
       const int64_t bsize = tensors.mini_batch;
       //excluding Turing from using persistent rnn.
-      if (prop->major == 7 && prop->minor != 5 && rnn.datatype == CUDNN_DATA_HALF && !tensors.is_input_packed()) {
+      if (prop->major == 7 && prop->minor != 5 && getCudnnDataType(input) == CUDNN_DATA_HALF && !tensors.is_input_packed()) {
           if (rnn.num_layers == 1 && rnn.hidden_size <= 1024 && rnn.num_directions() == 1 &&
                   rnn.hidden_size % 128 == 0 && tensors.input_size % 128 == 0){
               //technically, batch size should be multiple of 8, but there are quite a few multiple-of-8 batchsizes that give bad perf,
@@ -600,6 +602,17 @@
 #endif
   }
 
+  cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
+#if CUDNN_VERSION != 7103
+// CUDNN 7.1.3 enforces RNN descriptor type to be identical to input/weight. This check throws an error for type
+// promotion. The check has since been removed.
+    if (dtype == CUDNN_DATA_HALF) {
+      return CUDNN_DATA_FLOAT;
+    }
+#endif
+    return dtype;
+  }
+
 } // anonymous namespace
 
 // NB: does inplace update into TensorList
@@ -618,9 +631,10 @@
            "_cudnn_rnn_flatten_weight_: cannot flatten empty weight list");
 
   auto any_param = weight_arr[0];
+  auto datatype = getCudnnDataType(any_param);
 
   RNNDescriptorParams rnn;
-  rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(any_param));
+  rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
 
   auto handle = getCudnnHandle();
   RNNDescriptor rnn_desc = rnn.descriptor(handle);
@@ -629,7 +643,7 @@
   TensorDescriptor x_desc;
   x_desc.set(getCudnnDataType(any_param), x_geom.sizes(), x_geom.strides(), 5);
 
-  auto num_weights = get_num_weights(handle, rnn_desc, x_desc, rnn.datatype);
+  auto num_weights = get_num_weights(handle, rnn_desc, x_desc, datatype);
   auto weight_buf = at::zeros(num_weights, any_param.options());
 
   FilterDescriptor w_desc;
@@ -679,7 +693,8 @@
       checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
   }
   RNNParams fn;
-  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
+  auto datatype = getCudnnDataType(input);
+  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
   fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
   fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
 
@@ -716,13 +731,13 @@
   auto y = output;
 
   auto handle = getCudnnHandle();
-  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors);
+  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
   fn.rnn.set_algo(algo);
   RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
   FilterDescriptor w_desc;
   if (!weight_buf.defined()) {
-    auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], fn.rnn.datatype);
+    auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
     weight_buf = at::empty(num_weights, x.options());
     w_desc.set(weight_buf, 3);
     weight_buf.zero_();
@@ -818,7 +833,8 @@
   auto output = output_r;
 
   RNNParams fn;
-  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
+  auto datatype = getCudnnDataType(input);
+  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
   fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
   fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
 
@@ -877,7 +893,7 @@
   AT_CHECK(dhy.is_cuda() && dy.is_cuda() && (!dcy.defined() || dcy.is_cuda()),
            "Gradients aren't CUDA tensors");
 
-  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors);
+  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
   fn.rnn.set_algo(algo);
   RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
@@ -941,7 +957,8 @@
   auto output = output_r;
 
   RNNParams fn;
-  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, getCudnnDataType(input));
+  auto datatype = getCudnnDataType(input);
+  fn.rnn.set(fn_mode, fn_hidden_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
   fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
   fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
 
@@ -981,7 +998,7 @@
   const auto& y = output;
   auto dw = at::zeros(weight_buf.sizes(), weight_buf.options());
 
-  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors);
+  cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
   fn.rnn.set_algo(algo);
   RNNDescriptors descs(fn, handle, x, y, hx, cx);
 
@@ -1162,7 +1179,7 @@
   auto datatype = getCudnnDataType(input);
 
   RNNDescriptorParams rnn;
-  rnn.set(mode, hidden_size, num_layers, bidirectional, datatype);
+  rnn.set(mode, hidden_size, num_layers, bidirectional, promote_rnn_math_type(datatype), datatype);
   RNNDescriptor rnn_desc = rnn.descriptor(handle);
 
   TensorGeometry x_geom ({1, input.size(-1)});