blob: bf3108bf11a470282498ac438420130c08bb0d5b [file] [log] [blame]
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/SpatialBatchNormalization.c"
#else
void THNN_(SpatialBatchNormalization_updateOutput)(THNNState *state, THTensor *input, THTensor *output, THTensor *weight, THTensor *bias, THTensor *running_mean, THTensor *running_var, THTensor *save_mean, THTensor *save_std, bool train, double momentum, double eps)
{
long nBatch = THTensor_(size)(input, 0);
long nFeature = THTensor_(size)(input, 1);
long iH = THTensor_(size)(input, 2);
long iW = THTensor_(size)(input, 3);
long n = nBatch * iH * iW;
#pragma parallel for
for (long f = 0; f < nFeature; ++f) {
THTensor *in = THTensor_(newSelect)(input, 1, f);
THTensor *out = THTensor_(newSelect)(output, 1, f);
real mean, invstd;
if (train) {
// compute mean per feature plane
accreal sum = 0;
TH_TENSOR_APPLY(real, in, sum += *in_data;);
mean = (real) sum / n;
THTensor_(set1d)(save_mean, f, (real) mean);
// compute variance per feature plane
sum = 0;
TH_TENSOR_APPLY(real, in,
sum += (*in_data - mean) * (*in_data - mean););
if (sum == 0 && eps == 0.0) {
invstd = 0;
} else {
invstd = (real) (1 / sqrt(sum/n + eps));
}
THTensor_(set1d)(save_std, f, (real) invstd);
// update running averages
THTensor_(set1d)(running_mean, f,
(real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));
accreal unbiased_var = sum / (n - 1);
THTensor_(set1d)(running_var, f,
(real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
} else {
mean = THTensor_(get1d)(running_mean, f);
invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
}
// compute output
real w = weight ? THTensor_(get1d)(weight, f) : 1;
real b = bias ? THTensor_(get1d)(bias, f) : 0;
TH_TENSOR_APPLY2(real, in, real, out,
*out_data = (real) (((*in_data - mean) * invstd) * w + b););
THTensor_(free)(out);
THTensor_(free)(in);
}
}
void THNN_(SpatialBatchNormalization_backward)(THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, THTensor *gradWeight, THTensor *gradBias, THTensor *weight, THTensor *save_mean, THTensor *save_std, double scale)
{
long nBatch = THTensor_(size)(input, 0);
long nFeature = THTensor_(size)(input, 1);
long iH = THTensor_(size)(input, 2);
long iW = THTensor_(size)(input, 3);
long n = nBatch * iH * iW;
// Q(X) = X - E[x] ; i.e. input centered to zero mean
// Y = Q(X) / σ ; i.e. BN output before weight and bias
// dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ
#pragma parallel for
for (long f = 0; f < nFeature; ++f) {
THTensor *in = THTensor_(newSelect)(input, 1, f);
THTensor *gradOut = THTensor_(newSelect)(gradOutput, 1, f);
real mean = THTensor_(get1d)(save_mean, f);
real invstd = THTensor_(get1d)(save_std, f);
real w = weight ? THTensor_(get1d)(weight, f) : 1;
// sum over all gradOutput in feature plane
accreal sum = 0;
TH_TENSOR_APPLY(real, gradOut, sum += *gradOut_data;);
// dot product of the Q(X) and gradOuput
accreal dotp = 0;
TH_TENSOR_APPLY2(real, in, real, gradOut,
dotp += (*in_data - mean) * (*gradOut_data););
if (gradInput) {
THTensor *gradIn = THTensor_(newSelect)(gradInput, 1, f);
// projection of gradOutput on to output scaled by std
real k = (real) dotp * invstd * invstd / n;
TH_TENSOR_APPLY2(real, gradIn, real, in,
*gradIn_data = (*in_data - mean) * k;);
accreal gradMean = sum / n;
TH_TENSOR_APPLY2(real, gradIn, real, gradOut,
*gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * invstd * w;);
THTensor_(free)(gradIn);
}
if (gradWeight) {
real val = THTensor_(get1d)(gradWeight, f);
THTensor_(set1d)(gradWeight, f, val + scale * dotp * invstd);
}
if (gradBias) {
real val = THTensor_(get1d)(gradBias, f);
THTensor_(set1d)(gradBias, f, val + scale * sum);
}
THTensor_(free)(gradOut);
THTensor_(free)(in);
}
}
#endif