blob: cb08b0a9f6cb6078a9d6527eeaf9de60dcc37276 [file] [log] [blame]
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/SpatialMaxUnpooling.cu"
#else
/*
 * Forward pass of spatial max unpooling (host-side launcher).
 *
 *   state    - THC state; also supplies the CUDA stream used for the launch
 *   input    - 3D tensor, or 4D tensor in batch mode (enforced below)
 *   output   - resized here to (batch, plane, oheight, owidth), zero-filled,
 *              then handed to the MaxUnpoolForward kernel; collapsed back to
 *              3D when the input was 3D
 *   indices  - index tensor, shape-checked against input
 *   owidth   - output width
 *   oheight  - output height
 */
void THNN_(SpatialMaxUnpooling_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           THCIndexTensor *indices,
           int owidth, int oheight)
{
  THCUNN_assertSameGPU(state, 3, input, output, indices);
  THCUNN_argCheck(state, input->nDimension == 3 || input->nDimension == 4, 2, input,
                  "3D or 4D (batch mode) tensor expected for input, but got: %s");
  THCUNN_check_shape_indices(state, indices, input);

  /* After the check above the rank is either 3 or 4; a rank other than 3
   * means a leading batch dimension is present and shifts the plane/row/col
   * dimensions by one. */
  const int isBatched = (input->nDimension != 3);
  const int shift = isBatched ? 1 : 0;

  const int64_t batchSize   = isBatched ? input->size[0] : 1;
  const int64_t nInputPlane = input->size[shift];
  const int64_t nInputRows  = input->size[shift + 1];
  const int64_t nInputCols  = input->size[shift + 2];

  /* Work on contiguous handles; they are released at the end. */
  input = THCTensor_(newContiguous)(state, input);
  indices = THCIndexTensor_(newContiguous)(state, indices);

  /* The output is always shaped as 4D for the kernel and cleared first. */
  THCTensor_(resize4d)(state, output, batchSize, nInputPlane, oheight, owidth);
  THCTensor_(zero)(state, output);

  /* The launch is sized by the number of input elements. */
  int count = THCTensor_(nElement)(state, input);

  MaxUnpoolForward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
      (count,
       THCTensor_(data)(state, input),
       THCIndexTensor_(data)(state, indices),
       batchSize, nInputPlane, nInputRows, nInputCols,
       oheight, owidth,
       THCTensor_(data)(state, output));
  THCudaCheck(cudaGetLastError());

  /* Non-batch input: drop the leading singleton batch dimension again. */
  if (input->nDimension == 3) {
    THCTensor_(resize3d)(state, output, nInputPlane, oheight, owidth);
  }

  /* Release the handles obtained from newContiguous. */
  THCTensor_(free)(state, input);
  THCIndexTensor_(free)(state, indices);
}
/*
 * Backward pass of spatial max unpooling (host-side launcher).
 *
 *   state      - THC state; also supplies the CUDA stream used for the launch
 *   input      - 3D tensor, or 4D tensor in batch mode (enforced below)
 *   gradOutput - gradient w.r.t. the output; must have the same number of
 *                dimensions as input and spatial size oheight x owidth
 *   gradInput  - resized here to the shape of input and passed to the
 *                MaxUnpoolBackward kernel as its destination buffer
 *   indices    - index tensor, shape-checked against input
 *   owidth     - expected gradOutput width
 *   oheight    - expected gradOutput height
 *
 * Raises a TH error when the input rank is not 3 or 4, when gradOutput's rank
 * differs from input's, or when gradOutput's spatial size does not match
 * (oheight, owidth).
 */
void THNN_(SpatialMaxUnpooling_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           THCIndexTensor *indices,
           int owidth, int oheight)
{
  THCUNN_assertSameGPU(state, 4, input, gradOutput, indices, gradInput);
  /* Same rank requirement as the forward pass; without it the size[] lookups
   * below could read past the end of the size array. */
  THCUNN_argCheck(state, input->nDimension == 3 || input->nDimension == 4, 2, input,
                  "3D or 4D (batch mode) tensor expected for input, but got: %s");
  /* gradOutput is indexed with the same dimw/dimh as input, so it needs the
   * same rank. */
  THCUNN_argCheck(state, gradOutput->nDimension == input->nDimension, 3, gradOutput,
                  "gradOutput must have the same number of dimensions as input, but got: %s");
  THCUNN_check_shape_indices(state, indices, input);

  int64_t nInputCols, nInputRows, nInputPlane, batchSize;
  int dimw = 2;
  int dimh = 1;

  if (input->nDimension == 3) {
    nInputPlane = input->size[0];
    batchSize = 1;
  }
  else
  {
    /* Batch mode: every dimension index shifts by one. */
    ++dimw;
    ++dimh;
    nInputPlane = input->size[1];
    batchSize = input->size[0];
  }

  nInputCols = input->size[dimw];
  nInputRows = input->size[dimh];

  if (owidth != gradOutput->size[dimw] || oheight != gradOutput->size[dimh]) {
    /* Tensor sizes are 64-bit; cast explicitly so the varargs match the
     * conversion specifiers (passing them to %d is undefined behaviour). */
    THError("Inconsistent gradOutput size. oheight= %d, owidth= %d, gradOutput: %lldx%lld",
            oheight, owidth,
            (long long)gradOutput->size[dimh], (long long)gradOutput->size[dimw]);
  }

  /* Work on contiguous handles; they are released at the end. */
  input = THCTensor_(newContiguous)(state, input);
  indices = THCIndexTensor_(newContiguous)(state, indices);
  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
  THCTensor_(resizeAs)(state, gradInput, input);

  /* The launch is sized by the number of input (== gradInput) elements. */
  int count = THCTensor_(nElement)(state, input);

  MaxUnpoolBackward <<< GET_BLOCKS(count), CUDA_NUM_THREADS, 0, THCState_getCurrentStream(state) >>>
      (count, THCTensor_(data)(state, gradOutput), THCIndexTensor_(data)(state, indices),
       batchSize, nInputPlane, nInputRows, nInputCols, oheight, owidth, THCTensor_(data)(state, gradInput));
  THCudaCheck(cudaGetLastError());

  /* Release the handles obtained from newContiguous. */
  THCTensor_(free)(state, input);
  THCIndexTensor_(free)(state, indices);
  THCTensor_(free)(state, gradOutput);
}
#endif