// blob: c3ec153a33cadb9892664282455897922c0e0418 [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#if !AT_CUDNN_ENABLED()
namespace at { namespace native {
// See Note [ATen preprocessor philosophy]
// Stub compiled when AT_CUDNN_ENABLED() is false. The operator cannot run
// in such a build, so both arguments are ignored and the call always
// throws std::runtime_error.
Tensor cudnn_grid_sampler_forward(
    const Tensor& input_t, const Tensor& grid_t) {
  const char* const message =
      "cudnn_grid_sampler_forward: ATen not compiled with cuDNN support";
  throw std::runtime_error(message);
}
// Stub compiled when AT_CUDNN_ENABLED() is false. All three arguments are
// ignored and the call always throws std::runtime_error; no tuple is ever
// returned.
std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
    const Tensor& input_t, const Tensor& grid_t,
    const Tensor& grad_output_t) {
  const char* const message =
      "cudnn_grid_sampler_backward: ATen not compiled with cuDNN support";
  throw std::runtime_error(message);
}
}}
#else // AT_CUDNN_ENABLED
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/Check.h>
// TODO: descriptor checking
namespace at { namespace native {
namespace {
// Configures `desc` from the sizes of `tensor`.
//
// The sizes are narrowed to int and copied into a 4-entry array that is
// handed to desc.set() together with `dataType` and a fixed rank of 4.
// Entries beyond tensor.dim() stay zero. One entry is written per tensor
// dimension, so `tensor` must have at most 4 dimensions; both callers in
// this file pass 4-d tensors.
void setSamplerDescriptor(SpatialTransformerDescriptor& desc, cudnnDataType_t dataType, const at::Tensor& tensor)
{
  int dims[4] = {0, 0, 0, 0};
  int axis = 0;
  while (axis < tensor.dim()) {
    dims[axis] = static_cast<int>(tensor.size(axis));
    ++axis;
  }
  desc.set(dataType, 4, dims);
}
// Validates that `grid` is a contiguous 4-d tensor of size n*h*w*2, where
// n must equal input->size(0). Failures are reported through the check*
// helpers using `c` as the calling context.
//
// FYI: grid coordinates live in [-1, 1]: -1 is the left-most pixel, 1 the
// right-most pixel (so 0 is the center pixel). Grid values outside that
// range are ignored.
void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input)
{
  checkContiguous(c, grid);
  checkDim(c, grid, 4);

  // TODO: Maybe more user friendly to report where the expected size
  // came from
  const int64_t expected_batch = input->size(0);
  checkSize(c, grid, 0, expected_batch);

  // The last dimension holds one coordinate pair per sampled location.
  const int64_t coords_per_location = 2;
  checkSize(c, grid, 3, coords_per_location);
}
} // namespace
// Samples `input_t` at the locations given by `grid_t` by calling
// cudnnSpatialTfSamplerForward.
//
// input_t: must be 4-d; passed through contiguousIfZeroInStrides first.
// grid_t:  made contiguous, then validated by checkGridSize (4-d,
//          size(0) == input.size(0), size(3) == 2).
// Both tensors must be on the same GPU and of the same type.
//
// Returns a freshly allocated tensor of input's type with sizes
// {input.size(0), input.size(1), grid.size(1), grid.size(2)}.
// Validation failures are raised by the check* helpers; a failing cuDNN
// call is reported through CUDNN_CHECK.
Tensor cudnn_grid_sampler_forward(
    const Tensor& input_t, const Tensor& grid_t)
{
  TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
            grid{ grid_t.contiguous(), "grid", 2 };
  CheckedFrom c = "cudnn_grid_sampler_forward";
  setCuDNNStreamToCurrent();
  checkAllSameGPU(c, {input, grid});
  checkAllSameType(c, {input, grid});
  // Validate input's rank *before* checkGridSize, which reads
  // input->size(0); otherwise a wrongly-shaped input is indexed before it
  // has been checked.
  checkDim(c, input, 4);
  checkGridSize(c, grid, input);

  auto output_t = input->type().tensor();
  output_t.resize_({input->size(0), input->size(1), grid->size(1), grid->size(2)});

  TensorDescriptor idesc{ *input };  // input descriptor
  TensorDescriptor odesc{ output_t };  // output descriptor
  SpatialTransformerDescriptor desc; // sampler descriptor

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);
  // The sampler descriptor is sized from the output, not the input.
  setSamplerDescriptor(desc, dataType, output_t);

  // Scaling constants handed to cuDNN (presumably the usual alpha/beta
  // blend factors: take the result as-is, overwrite the output).
  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  CUDNN_CHECK(cudnnSpatialTfSamplerForward(
    handle, desc.desc(),
    &one, idesc.desc(), input->data_ptr(),
    grid->data_ptr(),
    &zero, odesc.desc(), output_t.data_ptr()
  ));

  return output_t;
}
// NB: CuDNN does not support output mask; you always get both
// gradients.
std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
const Tensor& input_t, const Tensor& grid_t,
const Tensor& grad_output_t)
{
TensorArg input{ contiguousIfZeroInStrides(input_t), "input", 1 },
grid{ grid_t.contiguous(), "grid", 2 },
grad_output{ contiguousIfZeroInStrides(grad_output_t), "grad_output", 3 };
CheckedFrom c = "cudnn_grid_sampler_backward";
setCuDNNStreamToCurrent();
checkAllSameGPU(c, {input, grad_output, grid});
checkGridSize(c, grid, input);
checkDim(c, input, 4);
checkDim(c, grad_output, 4);
auto grad_input_t = input->type().tensor();
grad_input_t.resize_(input->sizes());
auto grad_grid_t = grid->type().tensor();
grad_grid_t.resize_(grid->sizes());
TensorDescriptor idesc{ *input }; // input descriptor
TensorDescriptor odesc{ *grad_output }; // grad_output descriptor
TensorDescriptor gdesc{ grad_input_t }; // grad_input descriptor
SpatialTransformerDescriptor desc; // sampler descriptor
auto handle = getCudnnHandle();
auto dataType = getCudnnDataType(*input);
setSamplerDescriptor(desc, dataType, *grad_output);
Constant one(dataType, 1);
Constant zero(dataType, 0);
CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
handle, desc.desc(),
&one, idesc.desc(), input->data_ptr(),
&zero, gdesc.desc(), grad_input_t.data_ptr(),
&one, odesc.desc(), grad_output->data_ptr(),
// intruigingly, the outputs don't need descriptors
grid->data_ptr(),
&zero, grad_grid_t.data_ptr()
));
return std::tuple<Tensor, Tensor>{ grad_input_t, grad_grid_t };
}
}} // namespace at::native
#endif