Optimize order_switch_ops on GPU (#11404)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11404

Optimize order_switch_ops on GPU
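
The CUDA kernels now iterate over the output tensor and gather from the input,
using FixedDivisor to avoid per-thread integer division/modulo and __ldg (on
sm_35 and newer) to route the strided reads through the read-only cache. A
cuDNN engine is also added that implements both order switches with
cudnnTransformTensor, and the Python tests now exercise both the default and
CUDNN engines.

A minimal sketch of the index math the NHWC2NCHW kernel performs (plain
division shown for clarity; the committed kernel uses FixedDivisor, and the
kernel name below is illustrative only):

    template <typename T>
    __global__ void NHWC2NCHWSketch(
        const int size, const int C, const int HxW, const T* X, T* Y) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
           i += blockDim.x * gridDim.x) {
        // i indexes the NCHW output; recover (n, c, hxw) and gather from
        // the NHWC input, keeping the writes to Y coalesced.
        const int hxw = i % HxW;
        const int c = (i / HxW) % C;
        const int n = i / (HxW * C);
        Y[i] = X[(n * HxW + hxw) * C + c];
      }
    }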

Reviewed By: houseroad

Differential Revision: D9728642

fbshipit-source-id: 74ff62268856fb1613fa61eb214bed6ec6716632
diff --git a/caffe2/operators/order_switch_ops.cu b/caffe2/operators/order_switch_ops.cu
index 27a71a6..c213b7c 100644
--- a/caffe2/operators/order_switch_ops.cu
+++ b/caffe2/operators/order_switch_ops.cu
@@ -1,91 +1,115 @@
 #include "caffe2/operators/order_switch_ops.h"
+
 #include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/fixed_divisor.h"
 
 namespace caffe2 {
 
-__global__ void NHWC2NCHWKernel(
-    const int N,
-    const int HW,
-    const int C,
-    const float* X,
-    float* Y) {
-  CUDA_1D_KERNEL_LOOP(i, N * HW * C) {
-    const int c = i % C;
-    const int hw = i / C % HW;
-    const int n = i / C / HW;
-    Y[(n * C + c) * HW + hw] = X[i];
+template <typename T>
+__global__ void NHWC2NCHWCUDAKernel(
+    const int size,
+    const FixedDivisor<int> C,
+    const FixedDivisor<int> HxW,
+    const T* X,
+    T* Y) {
+  CUDA_1D_KERNEL_LOOP(i, size) {
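+    // i indexes the NCHW output; split it into (n, c, hxw) with FixedDivisor
+    // and gather the matching element from the NHWC input.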
+    int n;
+    int c;
+    int hxw;
+    HxW.DivMod(i, &c, &hxw);
+    C.DivMod(c, &n, &c);
+#if __CUDA_ARCH__ >= 350
+    Y[i] = __ldg(X + (n * HxW.d() + hxw) * C.d() + c);
+#else
+    Y[i] = X[(n * HxW.d() + hxw) * C.d() + c];
+#endif
   }
 }
 
-__global__ void NCHW2NHWCKernel(
-    const int N,
-    const int C,
-    const int HW,
-    const float* X,
-    float* Y) {
-  CUDA_1D_KERNEL_LOOP(i, N * C * HW) {
-    const int hw = i % HW;
-    const int c = i / HW % C;
-    const int n = i / C / HW;
-    Y[(n * HW + hw) * C + c] = X[i];
+template <typename T>
+__global__ void NCHW2NHWCCUDAKernel(
+    const int size,
+    const FixedDivisor<int> C,
+    const FixedDivisor<int> HxW,
+    const T* X,
+    T* Y) {
+  CUDA_1D_KERNEL_LOOP(i, size) {
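+    // i indexes the NHWC output; split it into (n, c, hxw) with FixedDivisor
+    // and gather the matching element from the NCHW input.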
+    int n;
+    int c;
+    int hxw;
+    C.DivMod(i, &hxw, &c);
+    HxW.DivMod(hxw, &n, &hxw);
+#if __CUDA_ARCH__ >= 350
+    Y[i] = __ldg(X + (n * C.d() + c) * HxW.d() + hxw);
+#else
+    Y[i] = X[(n * C.d() + c) * HxW.d() + hxw];
+#endif
   }
 }
 
 template <>
 bool NHWC2NCHWOp<float, CUDAContext>::RunOnDevice() {
-  auto& X = Input(0);
+  const auto& X = Input(0);
   auto* Y = Output(0);
-
-  auto ndim = X.ndim();
-  DCHECK_GE(ndim, 3);
-  const int N = X.dim32(0), C = X.dim32(ndim - 1);
+  const int ndim = X.ndim();
+  CAFFE_ENFORCE_GE(ndim, 3);
+  const int N = X.dim32(0);
+  const int C = X.dim32(ndim - 1);
   vector<TIndex> Y_dims(ndim);
   Y_dims[0] = N;
   Y_dims[1] = C;
-  size_t image_size = 1;
-  for (auto i = 2; i < ndim; ++i) {
+  int HxW = 1;
+  for (int i = 2; i < ndim; ++i) {
     Y_dims[i] = X.dim32(i - 1);
-    image_size *= Y_dims[i];
+    HxW *= Y_dims[i];
   }
   Y->Resize(Y_dims);
-
-  NHWC2NCHWKernel<<<
-      CAFFE_GET_BLOCKS(X.size()),
-      CAFFE_CUDA_NUM_THREADS,
-      0,
-      context_.cuda_stream()>>>(
-      N, image_size, C, X.data<float>(), Y->template mutable_data<float>());
+  const int size = X.size();
+  NHWC2NCHWCUDAKernel<float>
+      <<<CAFFE_GET_BLOCKS(size),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context_.cuda_stream()>>>(
+          size,
+          FixedDivisor<int>(C),
+          FixedDivisor<int>(HxW),
+          X.data<float>(),
+          Y->template mutable_data<float>());
   return true;
 }
 
 template <>
 bool NCHW2NHWCOp<float, CUDAContext>::RunOnDevice() {
-  auto& X = Input(0);
+  const auto& X = Input(0);
   auto* Y = Output(0);
-
-  auto ndim = X.ndim();
-  DCHECK_GE(X.ndim(), 3);
-  const int N = X.dim32(0), C = X.dim32(1);
+  const int ndim = X.ndim();
+  CAFFE_ENFORCE_GE(ndim, 3);
+  const int N = X.dim32(0);
+  const int C = X.dim32(1);
   vector<TIndex> Y_dims(ndim);
   Y_dims[0] = N;
-  size_t image_size = 1;
+  int HxW = 1;
   for (auto i = 1; i < ndim - 1; ++i) {
     Y_dims[i] = X.dim32(i + 1);
-    image_size *= Y_dims[i];
+    HxW *= Y_dims[i];
   }
   Y_dims[ndim - 1] = C;
   Y->Resize(Y_dims);
-
-  NCHW2NHWCKernel<<<
-      CAFFE_GET_BLOCKS(X.size()),
-      CAFFE_CUDA_NUM_THREADS,
-      0,
-      context_.cuda_stream()>>>(
-      N, C, image_size, X.data<float>(), Y->template mutable_data<float>());
+  const int size = X.size();
+  NCHW2NHWCCUDAKernel<float>
+      <<<CAFFE_GET_BLOCKS(size),
+         CAFFE_CUDA_NUM_THREADS,
+         0,
+         context_.cuda_stream()>>>(
+          size,
+          FixedDivisor<int>(C),
+          FixedDivisor<int>(HxW),
+          X.data<float>(),
+          Y->template mutable_data<float>());
   return true;
 }
 
-
 REGISTER_CUDA_OPERATOR(NHWC2NCHW, NHWC2NCHWOp<float, CUDAContext>);
 REGISTER_CUDA_OPERATOR(NCHW2NHWC, NCHW2NHWCOp<float, CUDAContext>);
-}  // namespace caffe2
+
+} // namespace caffe2
diff --git a/caffe2/operators/order_switch_ops_cudnn.cc b/caffe2/operators/order_switch_ops_cudnn.cc
new file mode 100644
index 0000000..4cb0034
--- /dev/null
+++ b/caffe2/operators/order_switch_ops_cudnn.cc
@@ -0,0 +1,160 @@
+#include "caffe2/operators/order_switch_ops.h"
+
+#include <algorithm>
+#include <array>
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/core/cudnn_wrappers.h"
+#include "caffe2/core/types.h"
+
+namespace caffe2 {
+
+namespace {
+
+class CuDNNOrderSwitchOpBase : public Operator<CUDAContext> {
+ public:
+  USE_OPERATOR_FUNCTIONS(CUDAContext);
+
+  CuDNNOrderSwitchOpBase(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<CUDAContext>(operator_def, ws), cudnn_wrapper_(&context_) {
+    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&X_desc_));
+    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&Y_desc_));
+  }
+
+  virtual ~CuDNNOrderSwitchOpBase() {
+    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(X_desc_));
+    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(Y_desc_));
+  }
+
+ protected:
+  void SetTensorDescriptor(
+      const cudnnDataType_t data_type,
+      const StorageOrder order,
+      const std::vector<int>& data_dims,
+      cudnnTensorDescriptor_t data_desc) const {
+    const int ndim = data_dims.size();
+    const int N = data_dims[0];
+    const int C = order == StorageOrder::NCHW ? data_dims[1] : data_dims.back();
+    if (ndim == 3) {
+      const int H = 1;
+      const int W = order == StorageOrder::NCHW ? data_dims[2] : data_dims[1];
+      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
+          data_desc, GetCudnnTensorFormat(order), data_type, N, C, H, W));
+    } else if (ndim == 4) {
+      const int H = order == StorageOrder::NCHW ? data_dims[2] : data_dims[1];
+      const int W = order == StorageOrder::NCHW ? data_dims[3] : data_dims[2];
+      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
+          data_desc, GetCudnnTensorFormat(order), data_type, N, C, H, W));
+    } else {
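+      // Fold all trailing spatial dims into D and describe the tensor to
+      // cuDNN as 5-D (N, C, H, W, D) with explicit strides for this layout.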
+      const int H = order == StorageOrder::NCHW ? data_dims[2] : data_dims[1];
+      const int W = order == StorageOrder::NCHW ? data_dims[3] : data_dims[2];
+      const auto l_iter = order == StorageOrder::NCHW ? data_dims.cbegin() + 4
+                                                      : data_dims.cbegin() + 3;
+      const auto r_iter =
+          order == StorageOrder::NCHW ? data_dims.cend() : data_dims.cend() - 1;
+      const int D = std::accumulate(l_iter, r_iter, 1, std::multiplies<int>());
+      const std::array<int, 5> dims = {N, C, H, W, D};
+      const std::array<int, 5> strides = order == StorageOrder::NCHW
+          ? std::array<int, 5>{C * H * W * D, H * W * D, W * D, D, 1}
+          : std::array<int, 5>{C * H * W * D, 1, W * D * C, D * C, C};
+      CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
+          data_desc, data_type, 5, dims.data(), strides.data()));
+    }
+  }
+
+  CuDNNWrapper cudnn_wrapper_;
+  cudnnTensorDescriptor_t X_desc_;
+  cudnnTensorDescriptor_t Y_desc_;
+
+  std::vector<int> cached_X_dims_;
+};
+
+class CuDNNNHWC2NCHWOp final : public CuDNNOrderSwitchOpBase {
+ public:
+  CuDNNNHWC2NCHWOp(const OperatorDef& operator_def, Workspace* ws)
+      : CuDNNOrderSwitchOpBase(operator_def, ws) {}
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& X = Input(0);
+    auto* Y = Output(0);
+    const int ndim = X.ndim();
+    const int N = X.dim32(0);
+    const int C = X.dim32(ndim - 1);
+    const std::vector<int> X_dims(X.dims().cbegin(), X.dims().cend());
+    std::vector<int> Y_dims(ndim);
+    Y_dims[0] = N;
+    Y_dims[1] = C;
+    std::copy(X_dims.cbegin() + 1, X_dims.cend() - 1, Y_dims.begin() + 2);
+    Y->Resize(Y_dims);
+    if (cached_X_dims_ != X_dims) {
+      cached_X_dims_ = X_dims;
+      SetTensorDescriptor(
+          cudnnTypeWrapper<T>::type, StorageOrder::NHWC, X_dims, X_desc_);
+      SetTensorDescriptor(
+          cudnnTypeWrapper<T>::type, StorageOrder::NCHW, Y_dims, Y_desc_);
+    }
+    CUDNN_ENFORCE(cudnnTransformTensor(
+        cudnn_wrapper_.inline_cudnn_handle(),
+        cudnnTypeWrapper<T>::kOne(),
+        X_desc_,
+        X.template data<T>(),
+        cudnnTypeWrapper<T>::kZero(),
+        Y_desc_,
+        Y->template mutable_data<T>()));
+    return true;
+  }
+};
+
+class CuDNNNCHW2NHWCOp final : public CuDNNOrderSwitchOpBase {
+ public:
+  CuDNNNCHW2NHWCOp(const OperatorDef& operator_def, Workspace* ws)
+      : CuDNNOrderSwitchOpBase(operator_def, ws) {}
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& X = Input(0);
+    auto* Y = Output(0);
+    const int ndim = X.ndim();
+    const int N = X.dim32(0);
+    const int C = X.dim32(1);
+    const std::vector<int> X_dims(X.dims().cbegin(), X.dims().cend());
+    std::vector<int> Y_dims(ndim);
+    Y_dims[0] = N;
+    Y_dims[ndim - 1] = C;
+    std::copy(X_dims.cbegin() + 2, X_dims.cend(), Y_dims.begin() + 1);
+    Y->Resize(Y_dims);
+    if (cached_X_dims_ != X_dims) {
+      cached_X_dims_ = X_dims;
+      SetTensorDescriptor(
+          cudnnTypeWrapper<T>::type, StorageOrder::NCHW, X_dims, X_desc_);
+      SetTensorDescriptor(
+          cudnnTypeWrapper<T>::type, StorageOrder::NHWC, Y_dims, Y_desc_);
+    }
+    CUDNN_ENFORCE(cudnnTransformTensor(
+        cudnn_wrapper_.inline_cudnn_handle(),
+        cudnnTypeWrapper<T>::kOne(),
+        X_desc_,
+        X.template data<T>(),
+        cudnnTypeWrapper<T>::kZero(),
+        Y_desc_,
+        Y->template mutable_data<T>()));
+    return true;
+  }
+};
+
+} // namespace
+
+REGISTER_CUDNN_OPERATOR(NHWC2NCHW, CuDNNNHWC2NCHWOp);
+REGISTER_CUDNN_OPERATOR(NCHW2NHWC, CuDNNNCHW2NHWCOp);
+
+} // namespace caffe2
diff --git a/caffe2/python/operator_test/order_switch_test.py b/caffe2/python/operator_test/order_switch_test.py
index d54ac26..5d3fd0e 100644
--- a/caffe2/python/operator_test/order_switch_test.py
+++ b/caffe2/python/operator_test/order_switch_test.py
@@ -1,14 +1,17 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import caffe2.python.hypothesis_test_util as hu
+import hypothesis.strategies as st
+
 from caffe2.python import core
 from hypothesis import given
 
 
 class OrderSwitchOpsTest(hu.HypothesisTestCase):
-    @given(X=hu.tensor(min_dim=3, max_dim=5, min_value=1, max_value=5), **hu.gcs)
-    def test_nchw2nhwc(self, X, gc, dc):
-        op = core.CreateOperator("NCHW2NHWC", ["X"], ["Y"], device_option=gc)
+    @given(X=hu.tensor(min_dim=3, max_dim=5, min_value=1, max_value=5),
+           engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
+    def test_nchw2nhwc(self, X, engine, gc, dc):
+        op = core.CreateOperator("NCHW2NHWC", ["X"], ["Y"], engine=engine)
 
         def nchw2nhwc_ref(X):
             X_reshaped = X.transpose((0,) + tuple(range(2, X.ndim)) + (1,))
@@ -18,12 +21,14 @@
         self.assertGradientChecks(gc, op, [X], 0, [0])
         self.assertDeviceChecks(dc, op, [X], [0])
 
-    @given(X=hu.tensor(min_dim=3, max_dim=5, min_value=1, max_value=5), **hu.gcs)
-    def test_nhwc2nchw(self, X, gc, dc):
-        op = core.CreateOperator("NHWC2NCHW", ["X"], ["Y"], device_option=gc)
+    @given(X=hu.tensor(min_dim=3, max_dim=5, min_value=1, max_value=5),
+           engine=st.sampled_from(["", "CUDNN"]), **hu.gcs)
+    def test_nhwc2nchw(self, X, engine, gc, dc):
+        op = core.CreateOperator("NHWC2NCHW", ["X"], ["Y"], engine=engine)
 
         def nhwc2nchw_ref(X):
-            X_reshaped = X.transpose((0, X.ndim - 1) + tuple(range(1, X.ndim - 1)))
+            X_reshaped = X.transpose(
+                (0, X.ndim - 1) + tuple(range(1, X.ndim - 1)))
             return (X_reshaped,)
 
         self.assertReferenceChecks(gc, op, [X], nhwc2nchw_ref)