blob: cbe99f3cac2e7dfc62701c71c2d8ad27e4b1210c [file] [log] [blame]
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/BatchNormalization.cu"
#else
#define DeviceTensor3 THCDeviceTensor<real, 3>
#define DeviceTensor1 THCDeviceTensor<real, 1>
template <int Dim>
static THCDeviceTensor<real, Dim> devicetensor(THCState *state, THCTensor *t) {
if (!t) {
return THCDeviceTensor<real, Dim>();
}
int inDim = THCTensor_(nDimension)(state, t);
if (inDim == Dim) {
return toDeviceTensor<real, Dim>(state, t);
}
// View in which the last dimensions are collapsed or expanded as needed
THAssert(THCTensor_(isContiguous)(state, t));
int size[Dim];
for (int i = 0; i < Dim || i < inDim; ++i) {
if (i < Dim && i < inDim) {
size[i] = t->size[i];
} else if (i < Dim) {
size[i] = 1;
} else {
size[Dim - 1] *= t->size[i];
}
}
return THCDeviceTensor<real, Dim>(THCTensor_(data)(state, t), size);
}
void THNN_(BatchNormalization_updateOutput)(
THCState *state, THCTensor *input_, THCTensor *output_,
THCTensor *weight_, THCTensor *bias_, THCTensor *runningMean_,
THCTensor *runningVar_, THCTensor *saveMean_, THCTensor *saveStd_,
bool train, double momentum, double eps) {
THCTensor_(resizeAs)(state, output_, input_);
DeviceTensor3 input = devicetensor<3>(state, input_);
DeviceTensor3 output = devicetensor<3>(state, output_);
DeviceTensor1 weight = devicetensor<1>(state, weight_);
DeviceTensor1 bias = devicetensor<1>(state, bias_);
DeviceTensor1 runningMean = devicetensor<1>(state, runningMean_);
DeviceTensor1 runningVar = devicetensor<1>(state, runningVar_);
DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_);
DeviceTensor1 saveStd = devicetensor<1>(state, saveStd_);
cudaStream_t s = THCState_getCurrentStream(state);
cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state);
if (!train) {
dim3 blocks(input.getSize(1));
dim3 threads(getNumThreads(input.getSize(2)));
BatchNormalizationUpdateOutputInference_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
input, output, runningMean, runningVar, weight, bias, eps);
} else {
dim3 blocks(input.getSize(1));
dim3 threads(getNumThreads(input.getSize(2)));
BatchNormalizationUpdateOutput_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
input, output, weight, bias, eps, momentum, runningMean, runningVar,
saveMean, saveStd);
}
THCudaCheck(cudaGetLastError());
}
void THNN_(BatchNormalization_backward)(
THCState *state, THCTensor *input_, THCTensor *gradOutput_,
THCTensor *gradInput_, THCTensor *gradWeight_, THCTensor *gradBias_,
THCTensor *weight_, THCTensor *runningMean_, THCTensor *runningVar_,
THCTensor *saveMean_, THCTensor *saveStd_, bool train, float scale, double eps) {
THCUNN_check_shape(state, input_, gradOutput_);
DeviceTensor3 input = devicetensor<3>(state, input_);
DeviceTensor3 gradOutput = devicetensor<3>(state, gradOutput_);
DeviceTensor3 gradInput = devicetensor<3>(state, gradInput_);
DeviceTensor1 gradWeight = devicetensor<1>(state, gradWeight_);
DeviceTensor1 gradBias = devicetensor<1>(state, gradBias_);
DeviceTensor1 weight = devicetensor<1>(state, weight_);
DeviceTensor1 runningMean = devicetensor<1>(state, runningMean_);
DeviceTensor1 runningVar = devicetensor<1>(state, runningVar_);
DeviceTensor1 saveMean = devicetensor<1>(state, saveMean_);
DeviceTensor1 saveStd = devicetensor<1>(state, saveStd_);
cudaStream_t s = THCState_getCurrentStream(state);
dim3 blocks(gradOutput.getSize(1));
dim3 threads(getNumThreads(gradOutput.getSize(2)));
BatchNormalizationBackward_kernel<real, accreal, DeviceTensor1, DeviceTensor3> <<<blocks, threads, 0, s>>>(
input, gradOutput, gradInput, gradWeight, gradBias, weight, runningMean, runningVar,
saveMean, saveStd, train, scale, eps);
THCudaCheck(cudaGetLastError());
}
#undef DeviceTensor3
#undef DeviceTensor1
#endif