Use macro for reduce on 2d blocks (#16344)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16344
Factor the repeated 2D-block-size dispatch for reduction kernels (picking a 1x128, 2x64, 4x32, or 8x16 thread block from the reduction size) into a DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK macro, and move the shared BlockReduce/BlockReduce2D aliases into a new caffe2/utils/math/reduce.cuh header. Switch the group norm, spatial batch norm, and moments CUDA code over to the macro, rename the per-kernel grid-split counts from K/W to M so they no longer collide with the K/W tensor dimensions, and add the hipify mapping for the new header.
i-am-not-moving-c2-to-c10
Reviewed By: houseroad
Differential Revision: D13808988
fbshipit-source-id: b68c0fb6079c1b6e203a072083aba7a95c202bc2
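For context, a minimal sketch (not part of this diff) of how a 2D-block reduction kernel is written against BlockReduce2D and launched through the new macro. RowwiseSumCUDAKernel and the RowwiseSum wrapper below are made-up names for illustration only; BlockReduce2D, DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK, and CUDAContext::cuda_stream() come from the headers touched by this change.

#include <cub/block/block_reduce.cuh>

#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math/reduce.cuh"

namespace caffe2 {
namespace {

// Hypothetical kernel: each block reduces one row of a rows x cols matrix
// using a kBlockDimX x kBlockDimY thread block.
template <typename T, int kBlockDimX, int kBlockDimY>
__global__ void RowwiseSumCUDAKernel(const int cols, const T* X, T* Y) {
  __shared__
      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
  const int r = blockIdx.x;
  T sum = T(0);
  // Flat 2D-strided loop: each thread visits a disjoint set of columns.
  for (int c = threadIdx.y * kBlockDimX + threadIdx.x; c < cols;
       c += kBlockDimX * kBlockDimY) {
    sum += X[r * cols + c];
  }
  sum = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(sum);
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    Y[r] = sum;
  }
}

} // namespace

// Call site: the macro instantiates <T, 1, 128>, <T, 2, 64>, <T, 4, 32>, or
// <T, 8, 16> from the reduction size and launches on the given stream, which
// is exactly the if/else ladder this change deletes at each call site.
void RowwiseSum(
    const int rows,
    const int cols,
    const float* X,
    float* Y,
    CUDAContext* context) {
  DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
      cols,
      RowwiseSumCUDAKernel,
      float,
      rows,
      context->cuda_stream(),
      cols,
      X,
      Y);
}

} // namespace caffe2

The wider-X shapes are chosen for small reduction sizes so short rows still keep full warps busy, while the 1x128 shape favors coalesced strides when the reduced dimension is large; BLOCK_REDUCE_WARP_REDUCTIONS handles the cross-warp combine inside BlockReduce2D.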
diff --git a/caffe2/operators/group_norm_op.cu b/caffe2/operators/group_norm_op.cu
index 62436d2..ad35e12 100644
--- a/caffe2/operators/group_norm_op.cu
+++ b/caffe2/operators/group_norm_op.cu
@@ -9,22 +9,17 @@
#include "caffe2/operators/group_norm_op.h"
#include <cub/block/block_reduce.cuh>
+#include <cub/cub.cuh>
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"
+#include "caffe2/utils/math/reduce.cuh"
namespace caffe2 {
namespace {
template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
- BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
__global__ void ComputeFusedParamsCUDAKernel(
const int G,
const int K,
@@ -54,7 +49,7 @@
template <typename T>
__global__ void GroupNormForwardNCHWCUDAKernel(
- const int K,
+ const int M,
const int HxW,
const T* X,
const T* scale,
@@ -63,14 +58,14 @@
template <>
__global__ void GroupNormForwardNCHWCUDAKernel<float>(
- const int W,
+ const int M,
const int HxW,
const float* X,
const float* scale,
const float* bias,
float* Y) {
- const int nc = blockIdx.x / W;
- const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ const int nc = blockIdx.x / M;
+ const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (hw < HxW) {
const int index = nc * HxW + hw;
#if __CUDA_ARCH__ >= 350
@@ -99,7 +94,8 @@
const float* bias,
float* Y) {
const int n = blockIdx.x / HxW;
- for (int c = threadIdx.x; c < C; c += blockDim.x) {
+ const int c = blockIdx.y * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (c < C) {
const int index = blockIdx.x * C + c;
const int nc = n * C + c;
#if __CUDA_ARCH__ >= 350
@@ -206,7 +202,7 @@
__global__ void GroupNormBackwardNCHWCUDAKernel(
const int G,
const int K,
- const int W,
+ const int M,
const int HxW,
const T* dY,
const T* X,
@@ -218,12 +214,12 @@
T* dX) {
const int C = G * K;
const T denom = T(1) / static_cast<T>(K * HxW);
- const int nc = blockIdx.x / W;
+ const int nc = blockIdx.x / M;
const int n = nc / C;
const int c = nc % C;
const int g = c / K;
const int ng = n * G + g;
- const int hw = blockIdx.x % W * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ const int hw = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
const int index = nc * HxW + hw;
if (hw < HxW) {
#if __CUDA_ARCH__ >= 350
@@ -261,7 +257,8 @@
const int g = blockIdx.y;
const int n = x / HxW;
const int ng = n * G + g;
- for (int i = threadIdx.x; i < K; i += blockDim.x) {
+ const int i = blockIdx.z * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ if (i < K) {
const int c = g * K + i;
const int index = x * C + c;
#if __CUDA_ARCH__ >= 350
@@ -393,10 +390,10 @@
const float* scale,
const float* bias,
float* Y) {
- const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
GroupNormForwardNCHWCUDAKernel<float>
- <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
- W, HxW, X, scale, bias, Y);
+ <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ M, HxW, X, scale, bias, Y);
}
template <>
@@ -408,8 +405,9 @@
const float* scale,
const float* bias,
float* Y) {
+ const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
GroupNormForwardNHWCCUDAKernel<float>
- <<<N * HxW, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ <<<dim3(N * HxW, M), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
C, HxW, X, scale, bias, Y);
}
@@ -440,31 +438,28 @@
// Computes dL/ds and dL/db.
// dL/ds = Sum(dL/dY * gamma * X)
// dL/db = Sum(dL/dY * gamma)
- if (HxW >= 128) {
- ComputeInternalGradientsNCHWCUDAKernel<float, 1, 128>
- <<<dim3(N, G), dim3(1, 128), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- } else if (HxW >= 64) {
- ComputeInternalGradientsNCHWCUDAKernel<float, 2, 64>
- <<<dim3(N, G), dim3(2, 64), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- } else if (HxW >= 32) {
- ComputeInternalGradientsNCHWCUDAKernel<float, 4, 32>
- <<<dim3(N, G), dim3(4, 32), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- } else {
- ComputeInternalGradientsNCHWCUDAKernel<float, 8, 16>
- <<<dim3(N, G), dim3(8, 16), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- }
+ DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+ HxW,
+ ComputeInternalGradientsNCHWCUDAKernel,
+ float,
+ dim3(N, G),
+ context_.cuda_stream(),
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ gamma_data,
+ ds_data,
+ db_data);
// Computes dL/dX.
- const int W = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
GroupNormBackwardNCHWCUDAKernel<float>
- <<<N * C * W, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
G,
K,
- W,
+ M,
HxW,
dY_data,
X_data,
@@ -476,84 +471,45 @@
dX_data);
// Computes dL/dgamma and dL/dbeta.
- if (HxW >= 128) {
- GammaBetaBackwardNCHWCUDAKernel<float, 1, 128>
- <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- } else if (HxW >= 64) {
- GammaBetaBackwardNCHWCUDAKernel<float, 2, 64>
- <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- } else if (HxW >= 32) {
- GammaBetaBackwardNCHWCUDAKernel<float, 4, 32>
- <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- } else {
- GammaBetaBackwardNCHWCUDAKernel<float, 8, 16>
- <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- }
+ DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+ HxW,
+ GammaBetaBackwardNCHWCUDAKernel,
+ float,
+ C,
+ context_.cuda_stream(),
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
} else {
// Computes dL/ds and dL/db.
// dL/ds = Sum(dL/dY * gamma * X)
// dL/db = Sum(dL/dY * gamma)
- if (K >= 128) {
- ComputeInternalGradientsNHWCCUDAKernel<float, 1, 128>
- <<<dim3(N, G), dim3(1, 128), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- } else if (K >= 64) {
- ComputeInternalGradientsNHWCCUDAKernel<float, 2, 64>
- <<<dim3(N, G), dim3(2, 64), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- } else if (K >= 32) {
- ComputeInternalGradientsNHWCCUDAKernel<float, 4, 32>
- <<<dim3(N, G), dim3(4, 32), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- } else {
- ComputeInternalGradientsNHWCCUDAKernel<float, 8, 16>
- <<<dim3(N, G), dim3(8, 16), 0, context_.cuda_stream()>>>(
- G, K, HxW, dY_data, X_data, gamma_data, ds_data, db_data);
- }
+ DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+ K,
+ ComputeInternalGradientsNHWCCUDAKernel,
+ float,
+ dim3(N, G),
+ context_.cuda_stream(),
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ gamma_data,
+ ds_data,
+ db_data);
// Computes dL/dX.
+ const int M = math::DivUp(K, CAFFE_CUDA_NUM_THREADS);
GroupNormBackwardNHWCCUDAKernel<float>
- <<<dim3(N * HxW, G),
+ <<<dim3(N * HxW, G, M),
CAFFE_CUDA_NUM_THREADS,
0,
context_.cuda_stream()>>>(
@@ -570,59 +526,22 @@
dX_data);
// Computes dL/dgamma and dL/dbeta.
- if (HxW >= 128) {
- GammaBetaBackwardNHWCCUDAKernel<float, 1, 128>
- <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- } else if (HxW >= 64) {
- GammaBetaBackwardNHWCCUDAKernel<float, 2, 64>
- <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- } else if (HxW >= 32) {
- GammaBetaBackwardNHWCCUDAKernel<float, 4, 32>
- <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- } else {
- GammaBetaBackwardNHWCCUDAKernel<float, 8, 16>
- <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
- N,
- G,
- K,
- HxW,
- dY_data,
- X_data,
- mu_data,
- rsig_data,
- dgamma_data,
- dbeta_data);
- }
+ DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+ HxW,
+ GammaBetaBackwardNHWCCUDAKernel,
+ float,
+ C,
+ context_.cuda_stream(),
+ N,
+ G,
+ K,
+ HxW,
+ dY_data,
+ X_data,
+ mu_data,
+ rsig_data,
+ dgamma_data,
+ dbeta_data);
}
return true;
}
diff --git a/caffe2/operators/spatial_batch_norm_op_impl.cuh b/caffe2/operators/spatial_batch_norm_op_impl.cuh
index 6be58d2..94a4697 100644
--- a/caffe2/operators/spatial_batch_norm_op_impl.cuh
+++ b/caffe2/operators/spatial_batch_norm_op_impl.cuh
@@ -8,19 +8,13 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"
+#include "caffe2/utils/math/reduce.cuh"
namespace caffe2 {
namespace {
template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
- BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
__global__ void ComputeFusedParamCUDAKernel(
const int C,
const T epsilon,
@@ -316,7 +310,7 @@
template <typename T>
__global__ void ComputeXGradientNCHWCUDAKernel(
const int C,
- const int K,
+ const int M,
const int HxW,
const T* dY,
const T* X,
@@ -328,7 +322,7 @@
template <>
__global__ void ComputeXGradientNCHWCUDAKernel<float>(
const int C,
- const int K,
+ const int M,
const int HxW,
const float* dY,
const float* X,
@@ -336,9 +330,9 @@
const float* beta,
const float* gamma,
float* dX) {
- const int nc = blockIdx.x / K;
+ const int nc = blockIdx.x / M;
const int c = nc % C;
- const int x = blockIdx.x % K * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ const int x = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (x < HxW) {
const int index = nc * HxW + x;
#if __CUDA_ARCH__ >= 350
@@ -399,9 +393,9 @@
const T* var,
T* alpha,
T* beta) {
- const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
ComputeFusedParamCUDAKernel<T>
- <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
C, static_cast<T>(epsilon_), scale, bias, mean, var, alpha, beta);
}
@@ -415,10 +409,10 @@
const T* batch_var_sum,
T* mean,
T* var) {
- const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
const T scale = T(1) / static_cast<T>(num_batches_ * N * HxW);
ComputeBatchMomentsCUDAKernel<T>
- <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
C, scale, batch_mean_sum, batch_var_sum, mean, var);
}
@@ -435,9 +429,9 @@
T* rstd,
T* alpha,
T* beta) {
- const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
ComputeRunningMomentsAndFusedParamCUDAKernel<T>
- <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
C,
static_cast<T>(momentum_),
static_cast<T>(epsilon_),
@@ -469,11 +463,11 @@
T* alpha,
T* beta,
T* gamma) {
- const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
const T batch_scale = T(1) / static_cast<T>(num_batches_);
const T mean_scale = T(1) / static_cast<T>(N * HxW);
ComputeMultiBatchScaleBiasGradientsAndFusedParamsCUDAKernel<T>
- <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
C,
batch_scale,
mean_scale,
@@ -507,71 +501,25 @@
T* gamma,
T* scratch) {
if (order_ == StorageOrder::NCHW) {
- if (HxW >= 128) {
- ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 1, 128>
- <<<C, dim3(1, 128), 0, context_.cuda_stream()>>>(
- N,
- C,
- HxW,
- dY,
- X,
- scale,
- mean,
- rstd,
- dscale,
- dbias,
- alpha,
- beta,
- gamma);
- } else if (HxW >= 64) {
- ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 2, 64>
- <<<C, dim3(2, 64), 0, context_.cuda_stream()>>>(
- N,
- C,
- HxW,
- dY,
- X,
- scale,
- mean,
- rstd,
- dscale,
- dbias,
- alpha,
- beta,
- gamma);
- } else if (HxW >= 32) {
- ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 4, 32>
- <<<C, dim3(4, 32), 0, context_.cuda_stream()>>>(
- N,
- C,
- HxW,
- dY,
- X,
- scale,
- mean,
- rstd,
- dscale,
- dbias,
- alpha,
- beta,
- gamma);
- } else {
- ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel<T, 8, 16>
- <<<C, dim3(8, 16), 0, context_.cuda_stream()>>>(
- N,
- C,
- HxW,
- dY,
- X,
- scale,
- mean,
- rstd,
- dscale,
- dbias,
- alpha,
- beta,
- gamma);
- }
+ DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+ HxW,
+ ComputeScaleBiasGradientsAndFusedParamsNCHWCUDAKernel,
+ T,
+ C,
+ context_.cuda_stream(),
+ N,
+ C,
+ HxW,
+ dY,
+ X,
+ scale,
+ mean,
+ rstd,
+ dscale,
+ dbias,
+ alpha,
+ beta,
+ gamma);
} else {
ReinitializeTensor(&ones_, N * HxW, at::dtype<T>().device(CUDA));
math::Set<T, CUDAContext>(
@@ -602,9 +550,9 @@
0.0f,
dbias,
&context_);
- const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
ComputeScaleGradientAndFusedParamsNHWCCUDAKernel<T>
- <<<K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ <<<M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
C,
T(1) / static_cast<T>(N * HxW),
dscale,
@@ -632,15 +580,17 @@
const T* gamma,
T* dX) {
if (order_ == StorageOrder::NCHW) {
- const int K = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(HxW, CAFFE_CUDA_NUM_THREADS);
ComputeXGradientNCHWCUDAKernel<T>
- <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
- C, K, HxW, dY, X, alpha, beta, gamma, dX);
+ <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+ C, M, HxW, dY, X, alpha, beta, gamma, dX);
} else {
- const int K = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
+ const int M = math::DivUp(C, CAFFE_CUDA_NUM_THREADS);
ComputeXGradientNHWCCUDAKernel<T>
- <<<dim3(N * HxW, K), CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
- C, HxW, dY, X, alpha, beta, gamma, dX);
+ <<<dim3(N * HxW, M),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(C, HxW, dY, X, alpha, beta, gamma, dX);
}
}
diff --git a/caffe2/utils/math/elementwise.cu b/caffe2/utils/math/elementwise.cu
index 0798b6f..cc7613c 100644
--- a/caffe2/utils/math/elementwise.cu
+++ b/caffe2/utils/math/elementwise.cu
@@ -23,8 +23,8 @@
template <typename T>
__global__ void AffineChannelNCHWCUDAKernel(
const int C,
+ const int M,
const int HxW,
- const int K,
const T* X,
const T* scale,
const T* bias,
@@ -33,15 +33,15 @@
template <>
__global__ void AffineChannelNCHWCUDAKernel<float>(
const int C,
+ const int M,
const int HxW,
- const int K,
const float* X,
const float* scale,
const float* bias,
float* Y) {
- const int nc = blockIdx.x / K;
+ const int nc = blockIdx.x / M;
const int c = nc % C;
- const int w = blockIdx.x % K * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
+ const int w = blockIdx.x % M * CAFFE_CUDA_NUM_THREADS + threadIdx.x;
if (w < HxW) {
const int index = nc * HxW + w;
#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
@@ -180,10 +180,10 @@
const T* bias, \
T* Y, \
CUDAContext* context) { \
- const int K = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \
+ const int M = DivUp(HxW, CAFFE_CUDA_NUM_THREADS); \
AffineChannelNCHWCUDAKernel<T> \
- <<<N * C * K, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \
- C, HxW, K, X, scale, bias, Y); \
+ <<<N * C * M, CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>( \
+ C, M, HxW, X, scale, bias, Y); \
} \
template <> \
CAFFE2_CUDA_EXPORT void AffineChannel<T, CUDAContext, StorageOrder::NHWC>( \
@@ -195,9 +195,9 @@
const T* bias, \
T* Y, \
CUDAContext* context) { \
- const int K = DivUp(C, CAFFE_CUDA_NUM_THREADS); \
+ const int M = DivUp(C, CAFFE_CUDA_NUM_THREADS); \
AffineChannelNHWCCUDAKernel<T> \
- <<<dim3(N* HxW, K), \
+ <<<dim3(N* HxW, M), \
CAFFE_CUDA_NUM_THREADS, \
0, \
context->cuda_stream()>>>(C, X, scale, bias, Y); \
diff --git a/caffe2/utils/math/reduce.cu b/caffe2/utils/math/reduce.cu
index f597ec7..31a6539 100644
--- a/caffe2/utils/math/reduce.cu
+++ b/caffe2/utils/math/reduce.cu
@@ -10,6 +10,7 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/fixed_divisor.h"
+#include "caffe2/utils/math/reduce.cuh"
#include "caffe2/utils/math_utils.h"
namespace caffe2 {
@@ -18,13 +19,6 @@
namespace {
template <typename T>
-using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
-
-template <typename T, int kBlockDimX, int kBlockDimY>
-using BlockReduce2D = cub::
- BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
-
-template <typename T>
__global__ void
RowwiseMomentsCUDAKernel(const int cols, const T* X, T* mean, T* var) {
__shared__ typename BlockReduce<T>::TempStorage m_storage;
@@ -229,23 +223,18 @@
int N;
int K;
if (utils::IsBothEndsReduce(ndim, X_dims, Y_dims, &M, &N, &K)) {
- if (K >= 128) {
- BothEndsMomentsCUDAKernel<T, 1, 128>
- <<<N, dim3(1, 128), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- } else if (K >= 64) {
- BothEndsMomentsCUDAKernel<T, 2, 64>
- <<<N, dim3(2, 64), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- } else if (K >= 32) {
- BothEndsMomentsCUDAKernel<T, 4, 32>
- <<<N, dim3(4, 32), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- } else {
- BothEndsMomentsCUDAKernel<T, 8, 16>
- <<<N, dim3(8, 16), 0, context->cuda_stream()>>>(
- M, N, K, X, mean, var);
- }
+ DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK(
+ K,
+ BothEndsMomentsCUDAKernel,
+ T,
+ N,
+ context->cuda_stream(),
+ M,
+ N,
+ K,
+ X,
+ mean,
+ var);
return;
}
std::vector<int> axes(ndim);
diff --git a/caffe2/utils/math/reduce.cuh b/caffe2/utils/math/reduce.cuh
new file mode 100644
index 0000000..d191cbc
--- /dev/null
+++ b/caffe2/utils/math/reduce.cuh
@@ -0,0 +1,35 @@
+#ifndef CAFFE2_UTILS_MATH_REDUCE_CUH_
+#define CAFFE2_UTILS_MATH_REDUCE_CUH_
+
+#include <cub/block/block_reduce.cuh>
+#include <cub/cub.cuh>
+
+#include "caffe2/core/common_gpu.h"
+
+namespace caffe2 {
+
+template <typename T>
+using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
+
+template <typename T, int kBlockDimX, int kBlockDimY>
+using BlockReduce2D = cub::
+ BlockReduce<T, kBlockDimX, cub::BLOCK_REDUCE_WARP_REDUCTIONS, kBlockDimY>;
+
+#define DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK( \
+ size, Func, T, grid_dim, cuda_stream, ...) \
+ do { \
+ if (size >= 128) { \
+ Func<T, 1, 128> \
+ <<<grid_dim, dim3(1, 128), 0, cuda_stream>>>(__VA_ARGS__); \
+ } else if (size >= 64) { \
+ Func<T, 2, 64><<<grid_dim, dim3(2, 64), 0, cuda_stream>>>(__VA_ARGS__); \
+ } else if (size >= 32) { \
+ Func<T, 4, 32><<<grid_dim, dim3(4, 32), 0, cuda_stream>>>(__VA_ARGS__); \
+ } else { \
+ Func<T, 8, 16><<<grid_dim, dim3(8, 16), 0, cuda_stream>>>(__VA_ARGS__); \
+ } \
+ } while (false)
+
+} // namespace caffe2
+
+#endif // CAFFE2_UTILS_MATH_REDUCE_CUH_
diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
index bfd4f7f..926a797 100644
--- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
+++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py
@@ -2258,6 +2258,7 @@
("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)),
("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)),
("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)),
+ ("/math/reduce.cuh", ("/math/hip/reduce.cuh", API_CAFFE2)),
("/gather_op.cuh", ("/hip/gather_op.cuh", API_CAFFE2)),
("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)),
("REGISTER_CUDA_OPERATOR" , ("REGISTER_HIP_OPERATOR", API_CAFFE2)),