#include "THCUNN.h"
#include "common.h"
template <typename Dtype, bool COUNT_INCLUDE_PAD>
__global__ void AvePoolForward(const int nthreads,
    const Dtype* const bottom_data, const int num, const int channels,
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int kernel_h, const int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    Dtype* const top_data) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // decompose the flat output index into (n, c, ph, pw)
    const int pw = index % pooled_width;
    const int ph = (index / pooled_width) % pooled_height;
    const int c = (index / pooled_width / pooled_height) % channels;
    const int n = index / pooled_width / pooled_height / channels;
    // window boundaries in padded-image coordinates
    int hstart = ph * stride_h - pad_h;
    int wstart = pw * stride_w - pad_w;
    int hend = min(hstart + kernel_h, height + pad_h);
    int wend = min(wstart + kernel_w, width + pad_w);
    // pool size including padding, measured before clipping to the image
    const int pool_size = (hend - hstart) * (wend - wstart);
    // clip the window to the valid image area
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    hend = min(hend, height);
    wend = min(wend, width);
    Dtype aveval = 0;
    const Dtype* const bottom_slice =
        bottom_data + (n * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        aveval += bottom_slice[h * width + w];
      }
    }
    if (COUNT_INCLUDE_PAD)
      // divide by the full (padded) window size
      top_data[index] = aveval / pool_size;
    else
      // divide by the number of elements actually inside the image
      top_data[index] = aveval / ((hend - hstart) * (wend - wstart));
  }
}
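
// Host entry point for the forward pass. Output size follows the usual
// pooling formula, e.g. for a 6x6 input with kW = kH = 3, dW = dH = 2 and
// no padding: floor((6 - 3 + 0) / 2) + 1 = 2 per side, or
// ceil((6 - 3 + 0) / 2) + 1 = 3 per side when ceil_mode is set.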
void THNN_CudaSpatialAveragePooling_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output, int kW, int kH, int dW, int dH, int padW, int padH, bool ceil_mode, bool count_include_pad)
{
  THCUNN_assertSameGPU(state, 2, input, output);
  THArgCheck(input->nDimension == 3 || input->nDimension == 4, 2, "3D or 4D (batch) tensor expected");

  long nInputCols, nInputRows, nInputPlane, batchSize;
  long nOutputCols, nOutputRows;

  if (input->nDimension == 3) {
    nInputCols = input->size[2];
    nInputRows = input->size[1];
    nInputPlane = input->size[0];
    batchSize = 1;
  }
  else
  {
    nInputCols = input->size[3];
    nInputRows = input->size[2];
    nInputPlane = input->size[1];
    batchSize = input->size[0];
  }

  THArgCheck(nInputCols >= kW - 2*padW && nInputRows >= kH - 2*padH, 2, "input image smaller than kernel size");
  THArgCheck(kW/2 >= padW && kH/2 >= padH, 2, "pad should be smaller than half of kernel size");

  if (ceil_mode) {
    nOutputCols = ceil(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
    nOutputRows = ceil(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
  }
  else {
    nOutputCols = floor(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
    nOutputRows = floor(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
  }
  if (padW || padH)
  {
    // ensure that the last pooling starts inside the image
    // needed to avoid problems in ceil mode
    if ((nOutputRows - 1)*dH >= nInputRows + padH)
      --nOutputRows;
    if ((nOutputCols - 1)*dW >= nInputCols + padW)
      --nOutputCols;
  }

  input = THCudaTensor_newContiguous(state, input);
  float* input_data = THCudaTensor_data(state, input);

  THCudaTensor_resize4d(state, output, batchSize, nInputPlane, nOutputRows, nOutputCols);
  float* output_data = THCudaTensor_data(state, output);

  int count = THCudaTensor_nElement(state, output);

  if (count_include_pad)
    AvePoolForward<float, true>
      <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(
        count, input_data,
        batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
        kH, kW, dH, dW, padH, padW, output_data);
  else
    AvePoolForward<float, false>
      <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(
        count, input_data,
        batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
        kH, kW, dH, dW, padH, padW, output_data);
  THCudaCheck(cudaGetLastError());

  if (input->nDimension == 3)
    THCudaTensor_resize3d(state, output, nInputPlane, nOutputRows, nOutputCols);

  THCudaTensor_free(state, input);
}
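
// AvePoolBackward: one thread per input element. Each thread gathers the
// gradient from every output position whose pooling window covers it,
// dividing each contribution by that window's pool size (padded or
// clipped, matching the divisor used in the forward pass).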
template <typename Dtype, bool COUNT_INCLUDE_PAD>
__global__ void AvePoolBackward(const int nthreads, const Dtype* const top_diff,
    const int num, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const int kernel_h, const int kernel_w, const int stride_h,
    const int stride_w, const int pad_h, const int pad_w,
    Dtype* const bottom_diff) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // find the (n, c, h, w) coordinates of this input element,
    // with h and w expressed in padded-image coordinates
    const int w = index % width + pad_w;
    const int h = (index / width) % height + pad_h;
    const int c = (index / width / height) % channels;
    const int n = index / width / height / channels;
    // range of output positions whose pooling windows cover (h, w)
    const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
    const int phend = min(h / stride_h + 1, pooled_height);
    const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
    const int pwend = min(w / stride_w + 1, pooled_width);
    Dtype gradient = 0;
    const Dtype* const top_diff_slice =
        top_diff + (n * channels + c) * pooled_height * pooled_width;
    for (int ph = phstart; ph < phend; ++ph) {
      for (int pw = pwstart; pw < pwend; ++pw) {
        // figure out the pooling size of this output position
        int hstart = ph * stride_h - pad_h;
        int wstart = pw * stride_w - pad_w;
        int hend = min(hstart + kernel_h, height + pad_h);
        int wend = min(wstart + kernel_w, width + pad_w);
        int pool_size = (hend - hstart) * (wend - wstart);
        hstart = max(hstart, 0);
        wstart = max(wstart, 0);
        hend = min(hend, height);
        wend = min(wend, width);
        if (COUNT_INCLUDE_PAD)
          gradient += top_diff_slice[ph * pooled_width + pw] / pool_size;
        else
          gradient += top_diff_slice[ph * pooled_width + pw] / ((hend - hstart) * (wend - wstart));
      }
    }
    bottom_diff[index] = gradient;
  }
}
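
// Host entry point for the backward pass. gradInput is resized to match
// the input and every element is written by AvePoolBackward; the output
// sizes are recomputed with the same formula as in updateOutput.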
void THNN_CudaSpatialAveragePooling_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput, int kW, int kH, int dW, int dH, int padW, int padH, bool ceil_mode, bool count_include_pad)
{
  THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);

  input = THCudaTensor_newContiguous(state, input);
  gradOutput = THCudaTensor_newContiguous(state, gradOutput);

  long nInputCols, nInputRows, nInputPlane, batchSize;
  long nOutputCols, nOutputRows;

  if (input->nDimension == 3) {
    nInputCols = input->size[2];
    nInputRows = input->size[1];
    nInputPlane = input->size[0];
    batchSize = 1;
  }
  else
  {
    nInputCols = input->size[3];
    nInputRows = input->size[2];
    nInputPlane = input->size[1];
    batchSize = input->size[0];
  }

  if (ceil_mode) {
    nOutputCols = ceil(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
    nOutputRows = ceil(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
  }
  else {
    nOutputCols = floor(float(nInputCols - kW + 2*padW) / float(dW)) + 1;
    nOutputRows = floor(float(nInputRows - kH + 2*padH) / float(dH)) + 1;
  }
  if (padW || padH)
  {
    // ensure that the last pooling starts inside the image
    // needed to avoid problems in ceil mode
    if ((nOutputRows - 1)*dH >= nInputRows + padH)
      --nOutputRows;
    if ((nOutputCols - 1)*dW >= nInputCols + padW)
      --nOutputCols;
  }

  THCudaTensor_resizeAs(state, gradInput, input);

  int count = THCudaTensor_nElement(state, input);

  if (count_include_pad)
    AvePoolBackward<float, true>
      <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(
        count,
        THCudaTensor_data(state, gradOutput),
        batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
        kH, kW, dH, dW, padH, padW,
        THCudaTensor_data(state, gradInput));
  else
    AvePoolBackward<float, false>
      <<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(
        count,
        THCudaTensor_data(state, gradOutput),
        batchSize, nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
        kH, kW, dH, dW, padH, padW,
        THCudaTensor_data(state, gradInput));
  THCudaCheck(cudaGetLastError());

  // clean
  THCudaTensor_free(state, input);
  THCudaTensor_free(state, gradOutput);
}