Optimize relu op on GPU (#18506)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18506
Optimize the Relu op on GPU: templatize the forward and gradient CUDA kernels (including the half and half2 specializations), replace the CAFFE_GET_BLOCKS / CUDA_1D_KERNEL_LOOP grid-stride launches with one-element-per-thread launches whose grid size comes from math::DivUp, and remove the separate CuDNN Relu/ReluGradient registration so the CUDA kernels are always used.
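For context, the core of the launch-configuration change is: each kernel now handles exactly one element per thread, guarded by `if (i < N)`, and the grid is sized by a ceiling division of N by the block size. Below is a minimal standalone sketch of that pattern, not the Caffe2 source; `kBlockSize`, `CeilDiv`, and `ReluKernel` are stand-ins for CAFFE_CUDA_NUM_THREADS, math::DivUp, and ReluCUDAKernel.

```cuda
// Standalone sketch of the one-element-per-thread launch pattern.
// kBlockSize and CeilDiv are illustrative stand-ins, not Caffe2 symbols.
#include <cuda_runtime.h>

constexpr int kBlockSize = 128;  // stand-in for CAFFE_CUDA_NUM_THREADS

// Ceiling division, the role math::DivUp plays in the patch.
inline int CeilDiv(int n, int d) { return (n + d - 1) / d; }

// One element per thread; the bounds check covers the partial last block.
template <typename T>
__global__ void ReluKernel(const int N, const T* X, T* Y) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) {
    Y[i] = X[i] > T(0) ? X[i] : T(0);
  }
}

int main() {
  const int N = 1 << 20;
  float* x = nullptr;
  float* y = nullptr;
  cudaMalloc(&x, N * sizeof(float));
  cudaMalloc(&y, N * sizeof(float));
  cudaMemset(x, 0, N * sizeof(float));
  // Launch just enough blocks to cover N elements, as the patch does with
  // <<<math::DivUp(N, CAFFE_CUDA_NUM_THREADS), CAFFE_CUDA_NUM_THREADS, ...>>>.
  ReluKernel<float><<<CeilDiv(N, kBlockSize), kBlockSize>>>(N, x, y);
  cudaDeviceSynchronize();
  cudaFree(x);
  cudaFree(y);
  return 0;
}
```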
Reviewed By: houseroad
Differential Revision: D14633171
fbshipit-source-id: bd3afa9a0bae1325d32ad4153736a0c7ecb0ec64
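The fp16 operator keeps its vectorized path: when the element count is even, the at::Half buffers are reinterpreted as half2 so each thread loads and stores two values. A hedged standalone sketch of that dispatch follows (again with `CeilDiv`, `kBlockSize`, and the kernel/function names as illustrative stand-ins, not Caffe2 symbols; it uses the portable float2 round trip, whereas the patch switches to __hgt2/__hmul2 when __CUDA_ARCH__ >= 530, and it assumes the buffers are at least 4-byte aligned, which cudaMalloc guarantees). It compiles as a standalone .cu translation unit.

```cuda
// Standalone sketch of the even-N half2 fast path (not the Caffe2 source).
#include <cuda_fp16.h>
#include <cuda_runtime.h>

constexpr int kBlockSize = 128;  // stand-in for CAFFE_CUDA_NUM_THREADS

inline int CeilDiv(int n, int d) { return (n + d - 1) / d; }

// Scalar fp16 relu: one __half per thread.
__global__ void ReluHalfKernel(const int N, const __half* X, __half* Y) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) {
    const float x = __half2float(X[i]);
    Y[i] = __float2half(x > 0.0f ? x : 0.0f);
  }
}

// Packed fp16 relu: one __half2 (two values) per thread.
__global__ void ReluHalf2Kernel(const int N, const __half2* X, __half2* Y) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < N) {
    const float2 x = __half22float2(X[i]);
    Y[i] = __floats2half2_rn(x.x > 0.0f ? x.x : 0.0f, x.y > 0.0f ? x.y : 0.0f);
  }
}

// Dispatch: take the packed path only when N is even, since each half2
// element covers two contiguous __half values.
void ReluHalf(const int N, const __half* X, __half* Y, cudaStream_t stream) {
  if (N % 2 == 0) {
    const int M = CeilDiv(N / 2, kBlockSize);
    ReluHalf2Kernel<<<M, kBlockSize, 0, stream>>>(
        N / 2,
        reinterpret_cast<const __half2*>(X),
        reinterpret_cast<__half2*>(Y));
  } else {
    const int M = CeilDiv(N, kBlockSize);
    ReluHalfKernel<<<M, kBlockSize, 0, stream>>>(N, X, Y);
  }
}
```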
diff --git a/caffe2/operators/relu_op.cu b/caffe2/operators/relu_op.cu
index 6ee8b15..613ed39 100644
--- a/caffe2/operators/relu_op.cu
+++ b/caffe2/operators/relu_op.cu
@@ -4,29 +4,35 @@
#include <functional>
#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math.h"
namespace caffe2 {
namespace {
#ifdef __HIPCC__
-typedef __half2 half2;
-#endif
+using half2 = __half2;
+#endif // __HIPCC__
template <typename T>
-__global__ void ReluCUDAKernel(const int N, const T* X, T* Y) {
- CUDA_1D_KERNEL_LOOP(i, N) {
-#if __CUDA_ARCH__ >= 350
- Y[i] = __ldg(X + i) > 0 ? __ldg(X + i) : T(0);
-#else
- Y[i] = X[i] > 0 ? X[i] : T(0);
-#endif
- }
-}
+__global__ void ReluCUDAKernel(const int N, const T* X, T* Y);
-__global__ void ReluHalfCUDAKernel(const int N, const half* X, half* Y) {
- const half kZero = __float2half(0.0f);
- CUDA_1D_KERNEL_LOOP(i, N) {
+#define DELEGATE_RELU_CUDA_KERNEL(T, MaxFunc) \
+ template <> \
+ __global__ void ReluCUDAKernel<T>(const int N, const T* X, T* Y) { \
+ const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x; \
+ if (i < N) { \
+ Y[i] = MaxFunc(X[i], T(0)); \
+ } \
+ }
+DELEGATE_RELU_CUDA_KERNEL(float, fmaxf)
+#undef DELEGATE_RELU_CUDA_KERNEL
+
+template <>
+__global__ void ReluCUDAKernel<half>(const int N, const half* X, half* Y) {
+ const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (i < N) {
+ const half kZero = __float2half(0.0f);
#if __CUDA_ARCH__ >= 530
Y[i] = __hgt(__ldg(X + i), kZero) ? __ldg(X + i) : kZero;
#else
@@ -35,14 +41,17 @@
}
}
-__global__ void ReluHalf2CUDAKernel(const int N, const half2* X, half2* Y) {
- const half2 kZero = __float2half2_rn(0.0f);
- CUDA_1D_KERNEL_LOOP(i, N) {
+template <>
+__global__ void ReluCUDAKernel<half2>(const int N, const half2* X, half2* Y) {
+ const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (i < N) {
+ const half2 kZero = __float2half2_rn(0.0f);
#if __CUDA_ARCH__ >= 530
Y[i] = __hmul2(__hgt2(__ldg(X + i), kZero), __ldg(X + i));
#else
const float2 xx = __half22float2(X[i]);
- Y[i] = __floats2half2_rn(xx.x > 0 ? xx.x : 0.f, xx.y > 0 ? xx.y : 0.f);
+ Y[i] =
+ __floats2half2_rn(xx.x > 0.0f ? xx.x : 0.0f, xx.y > 0.0f ? xx.y : 0.0f);
#endif
}
}
@@ -50,22 +59,25 @@
template <typename T>
__global__ void
ReluGradientCUDAKernel(const int N, const T* dY, const T* Y, T* dX) {
- CUDA_1D_KERNEL_LOOP(i, N) {
+ const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (i < N) {
#if __CUDA_ARCH__ >= 350
- dX[i] = __ldg(Y + i) > 0 ? __ldg(dY + i) : 0;
+ dX[i] = __ldg(Y + i) > T(0) ? __ldg(dY + i) : T(0);
#else
- dX[i] = Y[i] > 0 ? dY[i] : 0;
+ dX[i] = Y[i] > T(0) ? dY[i] : T(0);
#endif
}
}
-__global__ void ReluGradientHalfCUDAKernel(
+template <>
+__global__ void ReluGradientCUDAKernel<half>(
const int N,
const half* dY,
const half* Y,
half* dX) {
- const half kZero = __float2half(0.0f);
- CUDA_1D_KERNEL_LOOP(i, N) {
+ const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (i < N) {
+ const half kZero = __float2half(0.0f);
#if __CUDA_ARCH__ >= 530
dX[i] = __hgt(__ldg(Y + i), kZero) ? __ldg(dY + i) : kZero;
#else
@@ -74,19 +86,22 @@
}
}
-__global__ void ReluGradientHalf2CUDAKernel(
+template <>
+__global__ void ReluGradientCUDAKernel<half2>(
const int N,
const half2* dY,
const half2* Y,
half2* dX) {
- const half2 kZero = __float2half2_rn(0.0f);
- CUDA_1D_KERNEL_LOOP(i, N) {
+ const int i = blockIdx.x * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (i < N) {
+ const half2 kZero = __float2half2_rn(0.0f);
#if __CUDA_ARCH__ >= 530
dX[i] = __hmul2(__hgt2(__ldg(Y + i), kZero), __ldg(dY + i));
#else
const float2 dy = __half22float2(dY[i]);
const float2 yy = __half22float2(Y[i]);
- dX[i] = __floats2half2_rn(yy.x > 0 ? dy.x : 0.f, yy.y > 0 ? dy.y : 0.f);
+ dX[i] =
+ __floats2half2_rn(yy.x > 0.0f ? dy.x : 0.0f, yy.y > 0.0f ? dy.y : 0.0f);
#endif
}
}
@@ -97,11 +112,9 @@
template <typename T>
bool ReluFunctor<CUDAContext>::
operator()(const int N, const T* X, T* Y, CUDAContext* context) const {
+ const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
ReluCUDAKernel<T>
- <<<CAFFE_GET_BLOCKS(N),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(N, X, Y);
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, X, Y);
return true;
}
@@ -112,22 +125,18 @@
const at::Half* X,
at::Half* Y,
CUDAContext* context) const {
- if ((N & 1) == 0) {
- ReluHalf2CUDAKernel<<<
- CAFFE_GET_BLOCKS((N >> 1)),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(
- (N >> 1),
- reinterpret_cast<const half2*>(X),
- reinterpret_cast<half2*>(Y));
+ if (N % 2 == 0) {
+ const int M = math::DivUp(N / 2, CAFFE_CUDA_NUM_THREADS);
+ ReluCUDAKernel<half2>
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ N / 2,
+ reinterpret_cast<const half2*>(X),
+ reinterpret_cast<half2*>(Y));
} else {
- ReluHalfCUDAKernel<<<
- CAFFE_GET_BLOCKS(N),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(
- N, reinterpret_cast<const half*>(X), reinterpret_cast<half*>(Y));
+ const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+ ReluCUDAKernel<half>
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ N, reinterpret_cast<const half*>(X), reinterpret_cast<half*>(Y));
}
return true;
}
@@ -141,13 +150,11 @@
const T* dY,
T* dX,
CUDAContext* context) const {
- const int size = std::accumulate(
+ const int N = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
+ const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
ReluGradientCUDAKernel<T>
- <<<CAFFE_GET_BLOCKS(size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(size, dY, Y, dX);
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(N, dY, Y, dX);
return true;
}
@@ -160,28 +167,24 @@
const at::Half* dY,
at::Half* dX,
CUDAContext* context) const {
- const int size = std::accumulate(
+ const int N = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
- if ((size & 1) == 0) {
- ReluGradientHalf2CUDAKernel<<<
- CAFFE_GET_BLOCKS((size >> 1)),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(
- (size >> 1),
- reinterpret_cast<const half2*>(dY),
- reinterpret_cast<const half2*>(Y),
- reinterpret_cast<half2*>(dX));
+ if (N % 2 == 0) {
+ const int M = math::DivUp(N / 2, CAFFE_CUDA_NUM_THREADS);
+ ReluGradientCUDAKernel<half2>
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ N / 2,
+ reinterpret_cast<const half2*>(dY),
+ reinterpret_cast<const half2*>(Y),
+ reinterpret_cast<half2*>(dX));
} else {
- ReluGradientHalfCUDAKernel<<<
- CAFFE_GET_BLOCKS(size),
- CAFFE_CUDA_NUM_THREADS,
- 0,
- context->cuda_stream()>>>(
- size,
- reinterpret_cast<const half*>(dY),
- reinterpret_cast<const half*>(Y),
- reinterpret_cast<half*>(dX));
+ const int M = math::DivUp(N, CAFFE_CUDA_NUM_THREADS);
+ ReluGradientCUDAKernel<half>
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
+ N,
+ reinterpret_cast<const half*>(dY),
+ reinterpret_cast<const half*>(Y),
+ reinterpret_cast<half*>(dX));
}
return true;
}
diff --git a/caffe2/operators/relu_op_cudnn.cc b/caffe2/operators/relu_op_cudnn.cc
deleted file mode 100644
index 75dc6cd..0000000
--- a/caffe2/operators/relu_op_cudnn.cc
+++ /dev/null
@@ -1,12 +0,0 @@
-#include "caffe2/operators/relu_op.h"
-
-#include "caffe2/operators/activation_ops_cudnn.h"
-
-namespace caffe2 {
-
-REGISTER_CUDNN_OPERATOR(Relu, CuDNNActivationOp<CUDNN_ACTIVATION_RELU>);
-REGISTER_CUDNN_OPERATOR(
- ReluGradient,
- CuDNNActivationGradientOp<CUDNN_ACTIVATION_RELU>);
-
-} // namespace caffe2