blob: ca2cd8cd59b8878c61e183bfb641694df294790e [file] [log] [blame]
#include "THCUNN.h"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
// Number of lanes in a warp; warpSum() and reduce() below are written for 32.
const int WARP_SIZE = 32;
// 3-d view indexed [batch][plane][n] and 1-d per-plane view, as built by
// devicetensor<3>/devicetensor<1> and consumed by the kernels in this file.
typedef THCDeviceTensor<float, 3> DeviceTensor3;
typedef THCDeviceTensor<float, 1> DeviceTensor1;
// The maximum number of threads in a block (also sizes the pre-sm_30 shared
// scratch buffer in warpSum()).
const int MAX_BLOCK_SIZE = 512;
// Choose the block width for a plane holding `nElem` elements: the smallest
// of 32, 64, 128, 256 that covers nElem, and MAX_BLOCK_SIZE for anything larger.
// The result is always a multiple of the warp width.
static int getNumThreads(int nElem) {
  for (int candidate = 32; candidate <= 256; candidate <<= 1) {
    if (candidate >= nElem) {
      return candidate;
    }
  }
  return MAX_BLOCK_SIZE;
}
// Zero-based position (counted from bit 0) of the highest set bit in `val`.
// __clz(0) is 32, so getMSB(0) evaluates to -1.
__device__ __forceinline__ int getMSB(int val) {
  const int leadingZeros = __clz(val);
  return 31 - leadingZeros;
}
// A pair of floats that reduce() can accumulate in a single pass. Only
// construction and += are provided, which is all reduce()/warpSum() use.
struct Float2 {
  float v1, v2;
  // Leaves both components uninitialized.
  __device__ Float2() {}
  __device__ Float2(float first, float second) : v1(first), v2(second) {}
  // Broadcast constructor; lets generic code write (T)0 with T = Float2.
  __device__ Float2(float v) : v1(v), v2(v) {}
  __device__ Float2& operator+=(const Float2& other) {
    this->v1 = this->v1 + other.v1;
    this->v2 = this->v2 + other.v2;
    return *this;
  }
};
// Pointwise functor for reduce(): yields the raw input element, so the
// reduction is a plain sum over (batch, n) for one plane.
struct SumOp {
  __device__ SumOp(const DeviceTensor3 src) : tensor(src) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    const float value = tensor[batch][plane][n];
    return value;
  }
  const DeviceTensor3 tensor;
};
// Pointwise functor for reduce(): squared deviation of each element from a
// fixed mean, so the reduction yields the sum of squared deviations.
struct VarOp {
  __device__ VarOp(float m, const DeviceTensor3 src) : mean(m), tensor(src) {}
  __device__ __forceinline__ float operator()(int batch, int plane, int n) {
    const float value = tensor[batch][plane][n];
    const float deviation = value - mean;
    return deviation * deviation;
  }
  const float mean;
  const DeviceTensor3 tensor;
};
// Pointwise functor for reduce() in the backward pass. It returns
// (gradOutput, gradOutput * (input - mean)), so a single reduction produces both
// sum(gradOutput) and the dot product of gradOutput with the centered input.
struct GradOp {
  __device__ GradOp(float m, const DeviceTensor3 in, const DeviceTensor3 gradOut)
      : mean(m), input(in), gradOutput(gradOut) {}
  __device__ __forceinline__ Float2 operator()(int batch, int plane, int n) {
    const float grad = gradOutput[batch][plane][n];
    const float x = input[batch][plane][n];
    const float centered = x - mean;
    return Float2(grad, grad * centered);
  }
  const float mean;
  const DeviceTensor3 input;
  const DeviceTensor3 gradOutput;
};
// Sum `val` across all lanes of the calling warp; every lane receives the
// warp-wide total.
//
// Must be called by every lane of a fully populated warp: the shuffle path
// uses the full 0xffffffff mask. In this file it is reached only from reduce(),
// whose block sizes come from getNumThreads() (always a multiple of 32).
static __device__ __forceinline__ float warpSum(float val) {
#if __CUDA_ARCH__ >= 300
  // Butterfly (xor) exchange: after log2(WARP_SIZE) steps every lane holds
  // the full sum.
  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
#if CUDART_VERSION >= 9000
    // CUDA 9+ does not provide the mask-less __shfl_xor for sm_70+; the _sync
    // variant with an explicit participant mask is required there.
    val += __shfl_xor_sync(0xffffffff, val, 1 << i, WARP_SIZE);
#else
    val += __shfl_xor(val, 1 << i, WARP_SIZE);
#endif
  }
#else
  // Pre-Kepler fallback (no shuffle): stage each lane's value in shared memory
  // and let every lane add its warp's other 31 entries. Relies on the
  // warp-synchronous execution of these older architectures; the buffer is
  // sized for MAX_BLOCK_SIZE threads.
  __shared__ float values[MAX_BLOCK_SIZE];
  values[threadIdx.x] = val;
  __threadfence_block();
  const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
  for (int i = 1; i < WARP_SIZE; i++) {
    val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
  }
#endif
  return val;
}
// Warp-wide sum of both components of a Float2; every lane receives the totals.
// The two components are reduced one after the other with the float overload.
static __device__ __forceinline__ Float2 warpSum(Float2 value) {
  Float2 total;
  total.v1 = warpSum(value.v1);
  total.v2 = warpSum(value.v2);
  return total;
}
// Sum across (batch, x/y/z) applying Op() pointwise, for a single plane.
//
// Each thread accumulates op(batch, plane, x) over every batch and over
// x = threadIdx.x, threadIdx.x + blockDim.x, ... The partial sums are reduced
// within each warp. Lane 0 of every warp stages its warp's result in shared
// memory, warp 0 reduces those, and the total is broadcast through shared[0],
// so every thread in the block returns the same value.
//
// Must be called by every thread of the block (it contains __syncthreads()).
// Assumes blockDim.x is a multiple of WARP_SIZE and at most 32 * WARP_SIZE, so
// there are at most 32 per-warp partials and warp 0 is fully populated.
// NOTE(review): `shared` is a static __shared__ array, so back-to-back calls
// with the same T reuse it. The __syncthreads() before the partials are written
// keeps a new call from overwriting shared[0] while threads are still reading
// the previous result.
template<typename T, typename Op>
__device__ T reduce(Op op, DeviceTensor3 tensor, int plane) {
T sum = (T)0;
for (int batch = 0; batch < tensor.getSize(0); ++batch) {
for (int x = threadIdx.x; x < tensor.getSize(2); x += blockDim.x) {
sum += op(batch, plane, x);
}
}
// Reduce the per-thread partials within each warp (every lane gets the warp total)
sum = warpSum(sum);
// 'Transpose': stage one partial per warp in shared memory, then reduce those
// within warp 0
__shared__ T shared[32];
__syncthreads();
if (threadIdx.x % WARP_SIZE == 0) {
shared[threadIdx.x / WARP_SIZE] = sum;
}
if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
// Zero the slots beyond the number of warps so warp 0 can sum all 32 entries
shared[threadIdx.x] = (T)0;
}
__syncthreads();
if (threadIdx.x / WARP_SIZE == 0) {
sum = warpSum(shared[threadIdx.x]);
if (threadIdx.x == 0) {
shared[0] = sum;
}
}
__syncthreads();
// Every thread reads the block-wide total broadcast through shared[0]
return shared[0];
}
// Wrap a THCudaTensor as a Dim-dimensional THCDeviceTensor<float> view.
//
// - NULL tensor -> default-constructed (empty) device tensor. The kernels in
//   this file test numElements() > 0 to detect such optional arguments.
// - Empty tensor (nDimension() == 0) -> also an empty device tensor.
// - Exactly Dim dimensions -> direct conversion via toDeviceTensor.
// - Otherwise the tensor must be contiguous. Missing trailing dimensions are
//   padded with size 1, and surplus trailing dimensions are folded into the
//   last one (e.g. sizes (a, b, c, d) become (a, b, c*d) for Dim == 3).
template <int Dim>
static THCDeviceTensor<float, Dim> devicetensor(THCState *state, THCudaTensor *t) {
  if (!t) {
    return THCDeviceTensor<float, Dim>();
  }
  int inDim = THCudaTensor_nDimension(state, t);
  // Assumes the legacy TH convention that nDimension() == 0 means "no elements".
  // Without this guard the padding loop below would fabricate a 1x...x1 view
  // over a possibly-NULL data pointer. The kernels' `numElements() > 0` checks
  // would then pass and they would read/write that pointer.
  if (inDim == 0) {
    return THCDeviceTensor<float, Dim>();
  }
  if (inDim == Dim) {
    return toDeviceTensor<float, Dim>(state, t);
  }
  // View in which the last dimensions are collapsed or expanded as needed
  THAssert(THCudaTensor_isContiguous(state, t));
  int size[Dim];
  for (int i = 0; i < Dim || i < inDim; ++i) {
    if (i < Dim && i < inDim) {
      size[i] = t->size[i];
    } else if (i < Dim) {
      size[i] = 1;
    } else {
      size[Dim - 1] *= t->size[i];
    }
  }
  return THCDeviceTensor<float, Dim>(THCudaTensor_data(state, t), size);
}
// Inference-mode forward pass: normalize with the running statistics.
//
// One block per plane (blockIdx.x selects the plane). The block's threads
// stride over dimension 2 of the 3-d view, for every batch entry.
// output = gamma * (input - runningMean) / sqrt(runningVar + epsilon) + beta,
// where gamma defaults to 1 and beta to 0 when weight/bias are empty.
__global__ void BatchNormalizationUpdateOutputInference_kernel(
    const DeviceTensor3 input,
    DeviceTensor3 output,
    DeviceTensor1 runningMean,
    DeviceTensor1 runningVar,
    const DeviceTensor1 weight,
    const DeviceTensor1 bias,
    float epsilon) {
  const int c = blockIdx.x;

  // Per-plane statistics and affine parameters, loaded via ldg().
  const float mu = runningMean[c].ldg();
  const float rstd = 1.0f / sqrt(runningVar[c].ldg() + epsilon);
  float scaleW = 1.0f;
  if (weight.numElements() > 0) {
    scaleW = weight[c].ldg();
  }
  float shiftB = 0.0f;
  if (bias.numElements() > 0) {
    shiftB = bias[c].ldg();
  }

  const int numBatches = input.getSize(0);
  const int planeLen = input.getSize(2);
  for (int b = 0; b < numBatches; ++b) {
    for (int i = threadIdx.x; i < planeLen; i += blockDim.x) {
      const float v = input[b][c][i].ldg();
      output[b][c][i] = scaleW * (v - mu) * rstd + shiftB;
    }
  }
}
// Training-mode forward pass of batch normalization, one plane per block.
//
// Launch: gridDim.x == input.getSize(1) (one block per plane). blockDim.x comes
// from getNumThreads(): a multiple of WARP_SIZE, up to MAX_BLOCK_SIZE. Every
// thread must reach both reduce() calls, because they contain __syncthreads().
//
// The kernel computes the batch mean and biased variance over (batch, dim 2)
// and writes the normalized (optionally affine-transformed) output. Thread 0
// stores saveMean / saveStd (saveStd holds the *inverse* std) and applies the
// momentum update to runningMean / runningVar, using the unbiased variance.
__global__ void BatchNormalizationUpdateOutput_kernel(
    const DeviceTensor3 input,
    DeviceTensor3 output,
    const DeviceTensor1 weight,
    const DeviceTensor1 bias,
    const float epsilon,
    const float momentum,
    DeviceTensor1 runningMean,
    DeviceTensor1 runningVar,
    DeviceTensor1 saveMean,
    DeviceTensor1 saveStd) {
  int plane = blockIdx.x;
  int N = input.getSize(0) * input.getSize(2);
  float norm = 1.0f / N;

  // Compute the mean and variance across (batch, x/y/z)
  float mean = reduce<float>(SumOp(input), input, plane) * norm;
  __syncthreads();
  // Sum of squared deviations from the mean, i.e. N * biased variance.
  float varN = reduce<float>(VarOp(mean, input), input, plane);
  float invStd = 0.0f;
  // With zero variance and zero epsilon, invStd stays 0 instead of dividing by 0.
  if (varN != 0.0f || epsilon != 0.0f) {
    invStd = 1.0f / sqrt(varN * norm + epsilon);
  }

  // Save the mean, inverse std, and moving averages
  if (threadIdx.x == 0) {
    // Bessel-corrected variance for the running estimate. With a single value
    // per plane (N == 1), varN / (N - 1) would be 0 / 0 = NaN and would poison
    // runningVar for good. A single value has zero variance, so use 0 there.
    float unbiasedVar = (N > 1) ? varN / (N - 1) : 0.0f;
    saveMean[plane] = mean;
    saveStd[plane] = invStd;
    runningMean[plane] = (1 - momentum) * runningMean[plane] + momentum * mean;
    runningVar[plane] = (1 - momentum) * runningVar[plane] + momentum * unbiasedVar;
  }

  // Write normalized and update the output; gamma/beta default to 1/0 when
  // weight/bias are empty.
  float gamma = weight.numElements() > 0 ? weight[plane] : 1.0f;
  float beta = bias.numElements() > 0 ? bias[plane] : 0.0f;
  for (int batch = 0; batch < input.getSize(0); ++batch) {
    for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
      float inp = input[batch][plane][x].ldg();
      output[batch][plane][x] = gamma * (inp - mean) * invStd + beta;
    }
  }
}
// Host entry point for the batch-normalization forward pass.
//
// Views input/output as 3-d [batch][plane][n] tensors and the per-plane
// parameters/statistics as 1-d tensors. It then launches one block per plane
// on the current stream:
// - train == false: the inference kernel, using runningMean/runningVar.
// - train == true: the training kernel, which also fills saveMean/saveStd and
//   updates runningMean/runningVar with `momentum`.
// weight_/bias_ may be NULL; the kernels then skip the affine transform.
// Launch errors are checked with cudaGetLastError().
void THNN_CudaBatchNormalization_updateOutput(
    THCState *state, THCudaTensor *input_, THCudaTensor *output_,
    THCudaTensor *weight_, THCudaTensor *bias_, THCudaTensor *runningMean_,
    THCudaTensor *runningVar_, THCudaTensor *saveMean_, THCudaTensor *saveStd_,
    bool train, double momentum, double eps) {
  THCUNN_assertSameGPU(state, 8, input_, output_, weight_, bias_, runningMean_,
                       runningVar_, saveMean_, saveStd_);
  DeviceTensor3 input = devicetensor<3>(state, input_);
  DeviceTensor3 output = devicetensor<3>(state, output_);
  DeviceTensor1 weight = devicetensor<1>(state, weight_);
  DeviceTensor1 bias = devicetensor<1>(state, bias_);
  DeviceTensor1 runningMean = devicetensor<1>(state, runningMean_);
  DeviceTensor1 runningVar = devicetensor<1>(state, runningVar_);
  DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_);
  DeviceTensor1 saveStd = devicetensor<1>(state, saveStd_);
  cudaStream_t s = THCState_getCurrentStream(state);

  // Both kernels share the launch shape: one block per plane, with the block
  // width sized to dimension 2 of the 3-d view.
  dim3 blocks(input.getSize(1));
  dim3 threads(getNumThreads(input.getSize(2)));
  if (!train) {
    BatchNormalizationUpdateOutputInference_kernel<<<blocks, threads, 0, s>>>(
        input, output, runningMean, runningVar, weight, bias,
        static_cast<float>(eps));
  } else {
    BatchNormalizationUpdateOutput_kernel<<<blocks, threads, 0, s>>>(
        input, output, weight, bias, static_cast<float>(eps),
        static_cast<float>(momentum), runningMean, runningVar,
        saveMean, saveStd);
  }
  THCudaCheck(cudaGetLastError());
}
// Backward pass of batch normalization, one plane per block.
//
// Launch: gridDim.x == gradOutput.getSize(1). blockDim.x comes from
// getNumThreads() (a multiple of WARP_SIZE). All threads must reach reduce(),
// which uses __syncthreads().
//
// In training mode the kernel uses the saved batch statistics (saveMean, and
// saveStd, which holds the inverse std). Otherwise it uses the running
// statistics, with the inverse std recomputed as 1 / sqrt(runningVar + eps).
// gradInput, gradWeight and gradBias are optional and are skipped when empty.
// gradWeight/gradBias are accumulated (+=), scaled by `scale`.
__global__ void BatchNormalizationBackward_kernel(
    const DeviceTensor3 input,
    const DeviceTensor3 gradOutput,
    DeviceTensor3 gradInput,
    DeviceTensor1 gradWeight,
    DeviceTensor1 gradBias,
    const DeviceTensor1 weight,
    const DeviceTensor1 runningMean,
    const DeviceTensor1 runningVar,
    const DeviceTensor1 saveMean,
    const DeviceTensor1 saveStd,
    bool train,
    float scale,
    double eps) {
  int plane = blockIdx.x;
  int N = gradOutput.getSize(0) * gradOutput.getSize(2);
  float mean, stdVal;
  if (train) {
    mean = saveMean[plane];
    stdVal = saveStd[plane];
  } else {
    mean = runningMean[plane];
    // Single-precision formula, identical to the inference forward kernel's
    // invstd. This keeps eval-mode forward and backward consistent and avoids
    // the fp64 sqrt that the `double eps` promotion used to trigger.
    float var = runningVar[plane];
    stdVal = 1.0f / sqrt(var + static_cast<float>(eps));
  }
  float weightVal = weight.numElements() > 0 ? weight[plane] : 1.0f;
  float norm = 1.0f / N;

  // Compute two values across (batch, x/y/z) in one pass:
  // 1. Sum(gradOutput)
  // 2. DotProduct(input - mean, gradOutput)
  Float2 res = reduce<Float2>(GradOp(mean, input, gradOutput), gradOutput, plane);
  float gradOutputSum = res.v1;
  float dotP = res.v2;

  float gradMean = gradOutputSum * norm;
  float projScale = dotP * norm * stdVal * stdVal;
  float gradScale = stdVal * weightVal;

  if (gradInput.numElements() > 0) {
    for (int batch = 0; batch < gradOutput.getSize(0); ++batch) {
      for (int x = threadIdx.x; x < gradOutput.getSize(2); x += blockDim.x) {
        float gradOut = gradOutput[batch][plane][x];
        if (train) {
          // The batch statistics depend on the input: subtract the mean of
          // gradOutput and its projection onto the centered input.
          float inp = input[batch][plane][x];
          float proj = (inp - mean) * projScale;
          gradInput[batch][plane][x] = (gradOut - proj - gradMean) * gradScale;
        } else {
          // Fixed running statistics: the transform is affine in the input.
          gradInput[batch][plane][x] = gradOut * gradScale;
        }
      }
    }
  }

  // One writer per plane for the parameter gradients.
  if (gradWeight.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradWeight[plane] += scale * dotP * stdVal;
    }
  }
  if (gradBias.numElements() > 0) {
    if (threadIdx.x == 0) {
      gradBias[plane] += scale * gradOutputSum;
    }
  }
}
// Host entry point for the batch-normalization backward pass.
//
// Builds 3-d [batch][plane][n] views of the activations/gradients and 1-d
// per-plane views of the parameters and statistics (NULL tensors become empty
// views, which the kernel skips). It launches one block per plane on the
// current stream and checks for launch errors.
void THNN_CudaBatchNormalization_backward(
    THCState *state, THCudaTensor *input_, THCudaTensor *gradOutput_,
    THCudaTensor *gradInput_, THCudaTensor *gradWeight_, THCudaTensor *gradBias_,
    THCudaTensor *weight_, THCudaTensor *runningMean_, THCudaTensor *runningVar_,
    THCudaTensor *saveMean_, THCudaTensor *saveStd_, bool train, float scale, double eps) {
  THCUNN_assertSameGPU(state, 10, input_, gradOutput_, gradInput_, gradWeight_,
                       gradBias_, weight_, runningMean_, runningVar_, saveMean_, saveStd_);

  // Activation-shaped tensors.
  DeviceTensor3 input = devicetensor<3>(state, input_);
  DeviceTensor3 gradOutput = devicetensor<3>(state, gradOutput_);
  DeviceTensor3 gradInput = devicetensor<3>(state, gradInput_);
  // Per-plane tensors.
  DeviceTensor1 gradWeight = devicetensor<1>(state, gradWeight_);
  DeviceTensor1 gradBias = devicetensor<1>(state, gradBias_);
  DeviceTensor1 weight = devicetensor<1>(state, weight_);
  DeviceTensor1 runningMean = devicetensor<1>(state, runningMean_);
  DeviceTensor1 runningVar = devicetensor<1>(state, runningVar_);
  DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_);
  DeviceTensor1 saveStd = devicetensor<1>(state, saveStd_);

  // One block per plane; the block width follows dimension 2 of the view.
  const int numPlanes = gradOutput.getSize(1);
  const int planeLen = gradOutput.getSize(2);
  const dim3 grid(numPlanes);
  const dim3 block(getNumThreads(planeLen));
  cudaStream_t stream = THCState_getCurrentStream(state);

  BatchNormalizationBackward_kernel<<<grid, block, 0, stream>>>(
      input, gradOutput, gradInput, gradWeight, gradBias, weight, runningMean, runningVar,
      saveMean, saveStd, train, scale, eps);
  THCudaCheck(cudaGetLastError());
}