// blob: c524f0545473c8899130ab749ce8efec9a327fff (source-viewer header; not part of the C++ code)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/FractionalMaxPooling.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/fractional_max_pool3d_backward_native.h>
#include <ATen/ops/fractional_max_pool3d_native.h>
#endif
#include <vector>
namespace at {
namespace meta {
/*
 * Structured-kernel meta function for fractional_max_pool3d.
 *
 * Validates the arguments, declares both outputs (pooled values at output
 * index 0, int64 argmax indices at output index 1) and returns the geometry
 * packed into the TORCH_PRECOMPUTE_STRUCT so the implementation does not have
 * to recompute it.
 *
 * input_        : 4D (planes, T, H, W) or 5D (batch, planes, T, H, W) tensor.
 * pool_size     : pooling window extents, ordered {T, H, W}.
 * output_size   : output extents, ordered {T, H, W}.
 * randomSamples : not inspected by this function.
 *
 * Raises (via TORCH_CHECK) when pool_size / output_size do not hold exactly
 * three values, when a pool extent is not positive, when the input is not
 * 4D/5D, when any dimension after the first has size zero, or when
 * output + pool - 1 is not smaller than the input extent along some axis.
 */
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
  const at::Tensor& input_,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& randomSamples
) {
  TORCH_CHECK(
      pool_size.size() == 3,
      "fractional_max_pool3d: kernel_size must either be a single Int or tuple of three Ints");
  TORCH_CHECK(
      output_size.size() == 3,
      "fractional_max_pool3d: output_size must either be a single Int or tuple of three Ints");

  int64_t outputT = output_size[0];
  int64_t outputH = output_size[1];
  int64_t outputW = output_size[2];
  int64_t poolSizeT = pool_size[0];
  int64_t poolSizeH = pool_size[1];
  int64_t poolSizeW = pool_size[2];

  /* A non-positive window is meaningless: the kernel's window loops would
     never execute and every output element would stay at -infinity. */
  TORCH_CHECK(poolSizeT > 0 && poolSizeH > 0 && poolSizeW > 0,
    "fractional_max_pool3d_out(): pool sizes must be greater than zero, but got ",
    pool_size);

  /* dimension positions for the unbatched (4D) layout */
  int64_t numBatch = 1;
  int64_t planeDim = 0;
  int64_t timeDim = 1;
  int64_t heightDim = 2;
  int64_t widthDim = 3;

  int64_t ndims = input_.ndimension();
  TORCH_CHECK(ndims == 4 || ndims == 5,
    "fractional_max_pool3d_out(): Expected 4D or 5D tensor, but got: ",
    input_.sizes());

  /* dimension 0 is allowed to be empty; every later dimension must not be */
  for (const auto i : c10::irange(1, ndims)) {
    TORCH_CHECK(input_.size(i) > 0,
      "fractional_max_pool3d_out(): Expected input to have non-zero size for non-batch dimensions, but got ",
      input_.sizes(), " with dimension ", i, " being empty.");
  }

  /* a leading batch dimension shifts every other dimension by one */
  if (ndims == 5) {
    numBatch = input_.size(0);
    planeDim++;
    timeDim++;
    heightDim++;
    widthDim++;
  }

  /* sizes */
  int64_t numPlanes = input_.size(planeDim);
  int64_t inputT = input_.size(timeDim);
  int64_t inputH = input_.size(heightDim);
  int64_t inputW = input_.size(widthDim);

  TORCH_CHECK(outputT + poolSizeT - 1 < inputT,
    "fractional_max_pool3d_out(): pool time ", poolSizeT,
    " too large relative to input time ", inputT);
  TORCH_CHECK(outputW + poolSizeW - 1 < inputW,
    "fractional_max_pool3d_out(): pool width ", poolSizeW,
    " too large relative to input width ", inputW);
  TORCH_CHECK(outputH + poolSizeH - 1 < inputH,
    "fractional_max_pool3d_out(): pool height ", poolSizeH,
    " too large relative to input height ", inputH);

  if (ndims == 4) {
    /* resize output */
    set_output_raw_strided(0, {numPlanes, outputT, outputH, outputW}, {}, input_.options());
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong));
  } else {
    set_output_raw_strided(0, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options());
    /* indices will contain the locations for each output point */
    set_output_raw_strided(1, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong));
  }

  return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)()
      .set_numBatch(numBatch).set_numPlanes(numPlanes)
      .set_inputT(inputT).set_inputH(inputH).set_inputW(inputW)
      .set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW)
      .set_outputT(outputT).set_outputH(outputH).set_outputW(outputW);
}
} // namespace meta
namespace native {
namespace {
/*
 * Forward pass for a single sample; planes are processed independently
 * through at::parallel_for.
 *
 * input / output / indices are addressed as dense (numPlanes, T, H, W)
 * buffers. randomSamples supplies three values per plane: element 0 drives
 * the T axis, element 1 the H axis and element 2 the W axis. For each output
 * cell the pooling window starts at the offsets returned by
 * generate_intervals and covers poolSize{T,H,W} input elements per axis.
 * The window maximum is written to `output`; its offset inside the plane
 * (t * inputH * inputW + h * inputW + w) is written to `indices`.
 * A NaN found in the window always replaces the current maximum.
 */
template<typename scalar_t>
static void fractional_max_pool3d_out_single_batch_frame(
  scalar_t* input,
  scalar_t* output,
  int64_t* indices,
  scalar_t* randomSamples,
  int64_t numPlanes,
  int64_t inputT, int64_t inputH, int64_t inputW,
  int64_t outputT, int64_t outputH, int64_t outputW,
  int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {

  /* element counts of one input plane and one output plane */
  const int64_t inputPlaneStride = inputT * inputH * inputW;
  const int64_t outputPlaneStride = outputT * outputH * outputW;

  at::parallel_for(0, numPlanes, 0, [&](int64_t begin, int64_t finish) {
    for (const auto plane : c10::irange(begin, finish)) {
      /* the three random samples belonging to this plane */
      scalar_t* planeSamples = randomSamples + plane * 3;

      /* window start offsets along each axis */
      const auto startsT = generate_intervals<scalar_t>(
          planeSamples[0], inputT, outputT, poolSizeT);
      const auto startsH = generate_intervals<scalar_t>(
          planeSamples[1], inputH, outputH, poolSizeH);
      const auto startsW = generate_intervals<scalar_t>(
          planeSamples[2], inputW, outputW, poolSizeW);

      scalar_t* planeInput = input + plane * inputPlaneStride;
      scalar_t* planeOutput = output + plane * outputPlaneStride;
      int64_t* planeIndices = indices + plane * outputPlaneStride;

      for (int64_t t = 0; t < outputT; ++t) {
        const int64_t inputTStart = startsT[t];

        for (int64_t h = 0; h < outputH; ++h) {
          const int64_t inputHStart = startsH[h];

          for (int64_t w = 0; w < outputW; ++w) {
            const int64_t inputWStart = startsW[w];

            /* start from -inf and from the offset of the window's first element */
            scalar_t maxVal = -std::numeric_limits<scalar_t>::infinity();
            int64_t maxIndex =
                inputTStart * inputH * inputW + inputHStart * inputW + inputWStart;

            const int64_t inputTEnd = inputTStart + poolSizeT;
            const int64_t inputHEnd = inputHStart + poolSizeH;
            const int64_t inputWEnd = inputWStart + poolSizeW;

            for (int64_t t2 = inputTStart; t2 < inputTEnd; ++t2) {
              for (int64_t h2 = inputHStart; h2 < inputHEnd; ++h2) {
                for (int64_t w2 = inputWStart; w2 < inputWEnd; ++w2) {
                  AT_ASSERT(t2 >= 0 && t2 < inputT);
                  AT_ASSERT(h2 >= 0 && h2 < inputH);
                  AT_ASSERT(w2 >= 0 && w2 < inputW);

                  const int64_t planeIndex = t2 * inputH * inputW + h2 * inputW + w2;
                  const scalar_t val = planeInput[planeIndex];
                  /* NaN compares false against everything, so it is accepted explicitly */
                  if (val > maxVal || std::isnan(val)) {
                    maxVal = val;
                    maxIndex = planeIndex;
                  }
                }
              }
            }

            const int64_t outputIndex = t * outputH * outputW + h * outputW + w;
            planeOutput[outputIndex] = maxVal;
            planeIndices[outputIndex] = maxIndex;
          }
        }
      }
    }
  });
}
/*
 * Forward pass over the whole batch.
 *
 * A batch of one is handed straight to the single-sample kernel. Larger
 * batches are split with at::parallel_for; each sample receives pointers
 * advanced by its own offset into input, output, indices and randomSamples
 * (three random samples per plane).
 */
template<typename scalar_t>
static void fractional_max_pool3d_out_frame(
  scalar_t* input,
  scalar_t* output,
  int64_t* indices,
  scalar_t* randomSamples,
  int64_t numBatch, int64_t numPlanes,
  int64_t inputT, int64_t inputH, int64_t inputW,
  int64_t outputT, int64_t outputH, int64_t outputW,
  int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {

  if (numBatch != 1) {
    at::parallel_for(0, numBatch, 0, [&](int64_t begin, int64_t finish) {
      for (const auto batch : c10::irange(begin, finish)) {
        /* element offsets of this sample inside the dense buffers */
        const int64_t inputOffset = batch * numPlanes * inputW * inputH * inputT;
        const int64_t outputOffset = batch * numPlanes * outputW * outputH * outputT;
        const int64_t samplesOffset = batch * numPlanes * 3;

        fractional_max_pool3d_out_single_batch_frame<scalar_t>(
          input + inputOffset,
          output + outputOffset,
          indices + outputOffset,
          randomSamples + samplesOffset,
          numPlanes,
          inputT, inputH, inputW,
          outputT, outputH, outputW,
          poolSizeT, poolSizeH, poolSizeW);
      }
    });
    return;
  }

  /* single sample: no per-batch offsets needed */
  fractional_max_pool3d_out_single_batch_frame<scalar_t>(
    input, output, indices, randomSamples,
    numPlanes,
    inputT, inputH, inputW,
    outputT, outputH, outputW,
    poolSizeT, poolSizeH, poolSizeW);
}
} // anonymous namespace
/*
 * CPU implementation of the fractional_max_pool3d structured kernel.
 *
 * Receives the original input and random samples, the precomputed geometry
 * (pool, output and input extents, batch and plane counts) and the two
 * outputs. After the shape check done by fractional_max_pool_check_shape it
 * returns immediately for an empty output; otherwise it dispatches on the
 * floating-point dtype of the input and runs the forward kernel on
 * contiguous copies of the input and the samples.
 */
TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)(
  const at::Tensor& input_,
  int64_t poolSizeT,
  int64_t poolSizeH,
  int64_t poolSizeW,
  int64_t outputT,
  int64_t outputH,
  int64_t outputW,
  const at::Tensor& randomSamples_,
  int64_t numBatch,
  int64_t numPlanes,
  int64_t inputT,
  int64_t inputH,
  int64_t inputW,
  const at::Tensor& output,
  const at::Tensor& indices) {

  fractional_max_pool_check_shape</*ndim*/ 3>(input_, randomSamples_);

  /* nothing to compute when the output holds no elements */
  if (output.numel() == 0) {
    return;
  }

  /* the kernel addresses its buffers densely, so make both inputs contiguous */
  const auto input = input_.contiguous();
  const auto randomSamples = randomSamples_.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "fractional_max_pool3d_out_frame",
    [&] {
      scalar_t* inputData = input.data_ptr<scalar_t>();
      scalar_t* outputData = output.data_ptr<scalar_t>();
      int64_t* indicesData = indices.data_ptr<int64_t>();
      scalar_t* samplesData = randomSamples.data_ptr<scalar_t>();

      fractional_max_pool3d_out_frame<scalar_t>(
        inputData, outputData, indicesData, samplesData,
        numBatch, numPlanes,
        inputT, inputH, inputW,
        outputT, outputH, outputW,
        poolSizeT, poolSizeH, poolSizeW);
    }
  );
}
namespace {
/*
 * Backward pass for a single sample; planes are processed independently
 * through at::parallel_for.
 *
 * For every output cell the stored index names the position inside the
 * corresponding input plane that receives the gradient: the value from
 * gradOutput is added to gradInput at that position. Indices outside
 * [0, inputT * inputH * inputW) trip an assertion.
 */
template<typename scalar_t>
static void fractional_max_pool3d_backward_out_single_batch_frame(
  scalar_t* gradInput,
  scalar_t* gradOutput,
  int64_t* indices,
  int64_t numPlanes,
  int64_t inputT, int64_t inputH, int64_t inputW,
  int64_t outputT, int64_t outputH, int64_t outputW) {

  /* number of elements in one output plane */
  const int64_t outputPlaneStride = outputT * outputH * outputW;

  at::parallel_for(0, numPlanes, 0, [&](int64_t begin, int64_t finish) {
    for (const auto plane : c10::irange(begin, finish)) {
      scalar_t* planeGradInput = gradInput + plane * inputT * inputH * inputW;
      scalar_t* planeGradOutput = gradOutput + plane * outputPlaneStride;
      int64_t* planeIndices = indices + plane * outputPlaneStride;

      for (int64_t t = 0; t < outputT; ++t) {
        for (int64_t h = 0; h < outputH; ++h) {
          /* offset of the first element of output row (t, h) */
          const int64_t rowOffset = t * outputH * outputW + h * outputW;

          for (int64_t w = 0; w < outputW; ++w) {
            const int64_t outputIndex = rowOffset + w;
            const int64_t index = planeIndices[outputIndex];
            AT_ASSERT(index >= 0 && index < inputT * inputH * inputW);

            /* accumulate: several output cells may point at the same input element */
            planeGradInput[index] += planeGradOutput[outputIndex];
          }
        }
      }
    }
  });
}
/*
 * Backward pass over the whole batch.
 *
 * A batch of one goes directly to the single-sample kernel. Larger batches
 * are split with at::parallel_for, each sample working on pointers advanced
 * by its own offset into gradInput, gradOutput and indices.
 */
template<typename scalar_t>
static void fractional_max_pool3d_backward_out_frame(
  scalar_t* gradInput,
  scalar_t* gradOutput,
  int64_t* indices,
  int64_t numBatch, int64_t numPlanes,
  int64_t inputT, int64_t inputH, int64_t inputW,
  int64_t outputT, int64_t outputH, int64_t outputW) {

  if (numBatch != 1) {
    at::parallel_for(0, numBatch, 0, [&](int64_t begin, int64_t finish) {
      for (const auto batch : c10::irange(begin, finish)) {
        /* element offsets of this sample inside the dense buffers */
        const int64_t inputOffset = batch * numPlanes * inputW * inputH * inputT;
        const int64_t outputOffset = batch * numPlanes * outputW * outputH * outputT;

        fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
          gradInput + inputOffset,
          gradOutput + outputOffset,
          indices + outputOffset,
          numPlanes,
          inputT, inputH, inputW,
          outputT, outputH, outputW);
      }
    });
    return;
  }

  /* single sample: no per-batch offsets needed */
  fractional_max_pool3d_backward_out_single_batch_frame<scalar_t>(
    gradInput, gradOutput, indices,
    numPlanes,
    inputT, inputH, inputW,
    outputT, outputH, outputW);
}
void fractional_max_pool3d_backward_out_cpu_template(
const Tensor& input,
const Tensor& gradOutput_,
Tensor& gradInput,
IntArrayRef output_size,
IntArrayRef pool_size /* unused */,
const Tensor& indices) {
int64_t outputT = output_size[0];
int64_t outputH = output_size[1];
int64_t outputW = output_size[2];
int64_t numBatch = 1;
int64_t planeDim = 0;
int64_t timeDim = 1;
int64_t heightDim = 2;
int64_t widthDim = 3;
int64_t ndims = input.ndimension();
if (ndims == 5) {
numBatch = input.size(0);
planeDim = 1;
heightDim++;
widthDim++;
timeDim++;
}
/* sizes */
int64_t numPlanes = input.size(planeDim);
int64_t inputT = input.size(timeDim);
int64_t inputH = input.size(heightDim);
int64_t inputW = input.size(widthDim);
TORCH_CHECK(outputT == gradOutput_.size(timeDim),
"fractional_max_pool3d_backward_out(): gradOutput time unexpected");
TORCH_CHECK(outputH == gradOutput_.size(heightDim),
"fractional_max_pool3d_backward_out(): ",
"gradOutput height unexpected");
TORCH_CHECK(outputW == gradOutput_.size(widthDim),
"fractional_max_pool3d_backward_out(): gradOutput width unexpected");
/* get contiguous gradOutput */
auto gradOutput = gradOutput_.contiguous();
/* resize */
gradInput.resize_as_(input);
gradInput.zero_();
/* backprop */
AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"fractional_max_pool3d_backward_out_frame",
[&]{
fractional_max_pool3d_backward_out_frame<scalar_t>(
gradInput.data_ptr<scalar_t>(),
gradOutput.data_ptr<scalar_t>(),
indices.data_ptr<int64_t>(),
numBatch, numPlanes,
inputT, inputH, inputW,
outputT, outputH, outputW
);
}
);
}
}// anonymous namespace
/*
 * out= variant of the CPU backward pass: writes the input gradient into the
 * caller-supplied gradInput and returns a reference to that same tensor.
 */
Tensor& fractional_max_pool3d_backward_out_cpu(
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices,
  at::Tensor& gradInput) {

  /* the shared template takes its arguments in a different order than this entry point */
  fractional_max_pool3d_backward_out_cpu_template(
    input, gradOutput_, gradInput, output_size, pool_size, indices);

  return gradInput;
}
/*
 * Functional variant of the CPU backward pass: creates the gradient tensor
 * itself, lets the shared template fill it, and returns it by value.
 */
Tensor fractional_max_pool3d_backward_cpu(
  const at::Tensor& gradOutput_,
  const at::Tensor& input,
  IntArrayRef pool_size,
  IntArrayRef output_size,
  const at::Tensor& indices) {

  /* empty placeholder with the options of input, handed to the template as gradInput */
  auto gradInput = at::empty({0}, input.options());

  fractional_max_pool3d_backward_out_cpu_template(
    input, gradOutput_, gradInput, output_size, pool_size, indices);

  return gradInput;
}
}// native
}// at