#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include <ATen/Parallel.h>
#include <tuple>
#include <vector>
namespace at {
namespace native {
namespace {
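/* Fractional max pooling uses one random sample per plane and spatial
 * dimension to derive a pseudorandom sequence of pooling-window start
 * offsets: interval i starts at
 *   floor((i + sample) * alpha) - floor(sample * alpha),
 * with alpha = (inputSize - poolSize) / (outputSize - 1), and the last
 * interval is pinned to inputSize - poolSize so the final window ends exactly
 * at the input boundary. */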
template <typename scalar_t>
static std::vector<int> fractional_max_pool2d_generate_intervals(
scalar_t sample,
int inputSize,
int outputSize,
int poolSize) {
scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
static_cast<scalar_t>(outputSize - 1);
std::vector<int> sequence(outputSize);
for (int i = 0; i < outputSize - 1; ++i) {
sequence[i] =
static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
}
sequence[outputSize - 1] = inputSize - poolSize;
return sequence;
}
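
/* Forward kernel for a single batch element: parallelizes over planes. Each
 * plane consumes two random samples (one for width, one for height), builds
 * its window-start sequences, takes the max over each poolSizeH x poolSizeW
 * window, and stores both the max value and the flattened input index of the
 * max element for use in the backward pass. */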
template <typename scalar_t>
static void fractional_max_pool2d_out_single_batch_frame(
scalar_t* input,
scalar_t* output,
int64_t* indices,
scalar_t* randomSamples,
int numPlanes,
int inputW, int inputH,
int outputW, int outputH,
int poolSizeW, int poolSizeH) {
at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
for (auto plane = start; plane < end; ++plane) {
/* each plane contains 2 random samples, one for W and one for H */
scalar_t* randomSamplesForPlane = randomSamples + plane * 2;
/* Generate interval sequence */
auto sequenceW = fractional_max_pool2d_generate_intervals<scalar_t>(
randomSamplesForPlane[0], inputW, outputW, poolSizeW);
auto sequenceH = fractional_max_pool2d_generate_intervals<scalar_t>(
randomSamplesForPlane[1], inputH, outputH, poolSizeH);
/* loop over output */
      scalar_t* inputForPlane = input + plane * inputW * inputH;
      scalar_t* outputForPlane = output + plane * outputW * outputH;
      int64_t* indicesForPlane = indices + plane * outputW * outputH;
      for (int h = 0; h < outputH; ++h) {
        int inputHStart = sequenceH[h];
        for (int w = 0; w < outputW; ++w) {
          int inputWStart = sequenceW[w];
          scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
          int64_t maxIndex = -1;
          for (int h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) {
            for (int w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) {
AT_ASSERT(h2 >= 0 && h2 < inputH);
AT_ASSERT(w2 >= 0 && w2 < inputW);
int planeIndex = h2 * inputW + w2;
scalar_t val = inputForPlane[planeIndex];
if (val > maxVal) {
maxVal = val;
maxIndex = planeIndex;
}
}
}
AT_ASSERT(maxVal != -std::numeric_limits<scalar_t>::infinity());
AT_ASSERT(maxIndex != -1);
outputForPlane[h * outputW + w] = maxVal;
indicesForPlane[h * outputW + w] = maxIndex;
}
}
}
});
}
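
/* Batch-level dispatcher for the forward kernel: with a single batch element
 * the plane-parallel kernel is called directly, so parallelism is spread over
 * planes; otherwise the work is split across the batch dimension and the
 * single-batch kernel is invoked per element. */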
template <typename scalar_t>
static void fractional_max_pool2d_out_frame(
scalar_t* input,
scalar_t* output,
int64_t* indices,
scalar_t* randomSamples,
int numBatch, int numPlanes,
int inputW, int inputH,
int outputW, int outputH,
int poolSizeW, int poolSizeH) {
  if (numBatch == 1) {
fractional_max_pool2d_out_single_batch_frame<scalar_t>(
input,
output,
indices,
randomSamples,
numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH
);
return;
}
at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
for (auto batch = start; batch < end; ++batch) {
fractional_max_pool2d_out_single_batch_frame<scalar_t>(
input + batch * numPlanes * inputH * inputW,
output + batch * numPlanes * outputH * outputW,
indices + batch * numPlanes * outputH * outputW,
randomSamples + batch * numPlanes * 2,
numPlanes, inputW, inputH, outputW, outputH, poolSizeW, poolSizeH);
}
});
}
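
/* Shape checking and output allocation for the forward pass: accepts a
 * non-empty 3D (planes, H, W) or 4D (batch, planes, H, W) input, checks that
 * the pooling windows fit (outputSize + poolSize - 1 <= inputSize along each
 * spatial dimension), resizes output and indices accordingly, and dispatches
 * on the input's floating point type. */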
void fractional_max_pool2d_out_cpu_template(
const at::Tensor& input_,
at::Tensor& output,
IntArrayRef output_size,
IntArrayRef pool_size,
at::Tensor& indices,
const at::Tensor& randomSamples) {
int numBatch = 1;
int planeDim = 0;
int heightDim = 1;
int widthDim = 2;
int outputH = output_size[0];
int outputW = output_size[1];
int poolSizeH = pool_size[0];
int poolSizeW = pool_size[1];
/* get contiguous input */
auto input = input_.contiguous();
int ndims = input.ndimension();
  TORCH_CHECK(input.numel() > 0 && (ndims == 3 || ndims == 4),
    "fractional_max_pool2d(): non-empty 3D or 4D (batch mode) tensor expected ",
    "for input, but got a tensor with ", ndims, " dimensions");
if (ndims == 4) {
numBatch = input.size(0);
planeDim++;
heightDim++;
widthDim++;
}
/* sizes */
int numPlanes = input.size(planeDim);
int inputH = input.size(heightDim);
int inputW = input.size(widthDim);
TORCH_CHECK(outputH + poolSizeH - 1 <= inputH,
"fractional_max_pool2d(): pool height ", poolSizeH,
" too large relative to input height ", inputH);
TORCH_CHECK(outputW + poolSizeW - 1 <= inputW,
"fractional_max_pool2d(): pool width ", poolSizeW,
" too large relative to input width ", inputW);
if (ndims == 3) {
/* resize output */
output.resize_({numPlanes, outputH, outputW});
/* indices will contain the locations for each output point */
indices.resize_({numPlanes, outputH, outputW});
} else {
output.resize_({numBatch, numPlanes, outputH, outputW});
/* indices will contain the locations for each output point */
indices.resize_({numBatch, numPlanes, outputH, outputW});
}
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(),
"fractional_max_pool2d_out_frame", [&] {
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
auto indices_data = indices.data_ptr<int64_t>();
auto randomSamples_data = randomSamples.data_ptr<scalar_t>();
fractional_max_pool2d_out_frame<scalar_t>(
input_data,
output_data,
indices_data,
randomSamples_data,
numBatch, numPlanes,
inputW, inputH,
outputW, outputH,
poolSizeW, poolSizeH);
}
);
}
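
/* Backward kernel for a single batch element: parallelizes over planes and,
 * for every output location, adds the incoming gradient to gradInput at the
 * input position recorded in indices during the forward pass. */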
template <typename scalar_t>
static void fractional_max_pool2d_backward_out_single_batch_frame(
scalar_t* gradInput,
scalar_t* gradOutput,
int64_t* indices,
int numPlanes,
int inputW, int inputH,
int outputW, int outputH) {
at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) {
for (auto plane = start; plane < end; plane++) {
scalar_t* gradInputForPlane = gradInput + plane * inputW * inputH;
scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
int64_t* indicesForPlane = indices + plane * outputW * outputH;
      for (int h = 0; h < outputH; ++h) {
        for (int w = 0; w < outputW; ++w) {
int outputIndex = h * outputW + w;
int64_t index = indicesForPlane[outputIndex];
AT_ASSERT(index >= 0 && index < inputW * inputH);
gradInputForPlane[index] += gradOutputForPlane[outputIndex];
}
}
}
});
}
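
/* Batch-level dispatcher for the backward kernel, mirroring the forward path:
 * a single batch element goes straight to the plane-parallel kernel, while
 * multiple batch elements are processed in parallel across the batch
 * dimension. */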
template <typename scalar_t>
static void fractional_max_pool2d_backward_out_frame(
scalar_t* gradInput,
scalar_t* gradOutput,
int64_t* indices,
int numBatch, int numPlanes,
int inputW, int inputH,
int outputW, int outputH) {
  if (numBatch == 1) {
fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
gradInput, gradOutput, indices,
numPlanes,
inputW, inputH, outputW, outputH
);
return;
}
at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) {
for (auto batch = start; batch < end; ++batch) {
fractional_max_pool2d_backward_out_single_batch_frame<scalar_t>(
gradInput + batch * numPlanes * inputH * inputW,
gradOutput + batch * numPlanes * outputH * outputW,
indices + batch * numPlanes * outputH * outputW,
numPlanes, inputW, inputH, outputW, outputH);
}
});
}
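
/* Validates gradOutput against the expected output size, resizes gradInput to
 * match the input and zeroes it, then dispatches the scatter-style gradient
 * accumulation on the floating point type. */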
Tensor& fractional_max_pool2d_backward_out_cpu_template(
const at::Tensor& input,
const at::Tensor& gradOutput_,
at::Tensor& gradInput,
IntArrayRef output_size,
IntArrayRef pool_size /* unused */,
const at::Tensor& indices) {
int numBatch = 1;
int planeDim = 0;
int heightDim = 1;
int widthDim = 2;
int outputH = output_size[0];
int outputW = output_size[1];
int ndims = input.ndimension();
if (ndims == 4) {
numBatch = input.size(0);
planeDim = 1;
heightDim++;
widthDim++;
}
/* sizes */
int numPlanes = input.size(planeDim);
int inputH = input.size(heightDim);
int inputW = input.size(widthDim);
/* get contiguous gradOutput */
auto gradOutput = gradOutput_.contiguous();
TORCH_CHECK(outputW == gradOutput.size(widthDim),
"fractional_max_pool2d_backward(): gradOutput width unexpected");
TORCH_CHECK(outputH == gradOutput.size(heightDim),
"fractional_max_pool2d_backward(): gradOutput height unexpected");
/* resize */
gradInput.resize_as_(input);
gradInput.zero_();
/* backprop */
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "fractional_max_pool2d_backward_out_frame", [&] {
auto gradInput_data = gradInput.data_ptr<scalar_t>();
auto gradOutput_data = gradOutput.data_ptr<scalar_t>();
auto indices_data = indices.data_ptr<int64_t>();
fractional_max_pool2d_backward_out_frame<scalar_t>(
gradInput_data,
gradOutput_data,
indices_data,
numBatch, numPlanes,
inputW, inputH,
outputW, outputH
);
}
);
return gradInput;
}
} // namespace
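
/* Public entry points registered for the CPU backend. A minimal usage sketch
 * (an assumption inferred from the kernels above: randomSamples is expected
 * to have shape {numBatch, numPlanes, 2} with values in [0, 1) and the same
 * dtype as the input):
 *
 *   at::Tensor input = at::randn({1, 3, 8, 8});
 *   at::Tensor samples = at::rand({1, 3, 2}, input.options());
 *   at::Tensor output, indices;
 *   std::tie(output, indices) = at::native::fractional_max_pool2d_cpu(
 *       input, {2, 2}, {4, 4}, samples);  // pool_size {2, 2}, output_size {4, 4}
 */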
std::tuple<Tensor&, Tensor&> fractional_max_pool2d_out_cpu(
at::Tensor& output,
at::Tensor& indices,
const at::Tensor& input,
IntArrayRef pool_size,
IntArrayRef output_size,
const at::Tensor& randomSamples)
{
fractional_max_pool2d_out_cpu_template(
input,
output,
output_size,
pool_size,
indices,
randomSamples);
return std::tuple<Tensor&, Tensor&>(output, indices);
}
std::tuple<Tensor, Tensor> fractional_max_pool2d_cpu(
const at::Tensor& input,
IntArrayRef pool_size,
IntArrayRef output_size,
const at::Tensor& randomSamples)
{
Tensor output = at::empty({0}, input.options());
Tensor indices = at::empty({0}, input.options().dtype(kLong));
fractional_max_pool2d_out_cpu_template(
input,
output,
output_size,
pool_size,
indices,
randomSamples);
return std::tuple<Tensor, Tensor>(output, indices);
}
Tensor& fractional_max_pool2d_backward_out_cpu(
at::Tensor& gradInput,
const at::Tensor& gradOutput_,
const at::Tensor& input,
IntArrayRef pool_size,
IntArrayRef output_size,
const at::Tensor& indices)
{
gradInput.resize_as_(input);
fractional_max_pool2d_backward_out_cpu_template(
input,
gradOutput_,
gradInput,
output_size,
pool_size,
indices);
return gradInput;
}
Tensor fractional_max_pool2d_backward_cpu(
const at::Tensor& gradOutput_,
const at::Tensor& input,
IntArrayRef pool_size,
IntArrayRef output_size,
const at::Tensor& indices)
{
Tensor gradInput = at::empty({0}, input.options());
fractional_max_pool2d_backward_out_cpu_template(
input,
gradOutput_,
gradInput,
output_size,
pool_size,
indices);
return gradInput;
}
} // namespace native
} // namespace at