| #include <ATen/native/GridSampler.h> |
| #include <ATen/ATen.h> |
| #include <ATen/Device.h> |
| #include <ATen/NativeFunctions.h> |
| #include <ATen/Parallel.h> |
| #include <c10/core/Layout.h> |
| #include <ATen/cpu/vml.h> |
| #include <ATen/native/cpu/GridSamplerKernel.h> |
| #include <c10/util/Exception.h> |
| |
| namespace at { namespace native { |
| |
| using at::native::detail::GridSamplerInterpolation; |
| using at::native::detail::GridSamplerPadding; |
| |
| namespace { |
| |
| template<typename scalar_t> |
| Tensor grid_sampler_3d_cpu_impl(const Tensor& input, const Tensor& grid, |
| GridSamplerInterpolation interpolation_mode, |
| GridSamplerPadding padding_mode, |
| bool align_corners) { |
| int64_t N = input.size(0); |
| int64_t C = input.size(1); |
| int64_t inp_D = input.size(2); |
| int64_t inp_H = input.size(3); |
| int64_t inp_W = input.size(4); |
| int64_t out_D = grid.size(1); |
| int64_t out_H = grid.size(2); |
| int64_t out_W = grid.size(3); |
| auto output = at::empty({N, C, out_D, out_H, out_W}, input.options()); |
| int64_t inp_sN = input.stride(0); |
| int64_t inp_sC = input.stride(1); |
| int64_t inp_sD = input.stride(2); |
| int64_t inp_sH = input.stride(3); |
| int64_t inp_sW = input.stride(4); |
| int64_t grid_sN = grid.stride(0); |
| int64_t grid_sD = grid.stride(1); |
| int64_t grid_sH = grid.stride(2); |
| int64_t grid_sW = grid.stride(3); |
| int64_t grid_sCoor = grid.stride(4); |
| int64_t out_sN = output.stride(0); |
| int64_t out_sC = output.stride(1); |
| int64_t out_sD = output.stride(2); |
| int64_t out_sH = output.stride(3); |
| int64_t out_sW = output.stride(4); |
| scalar_t *inp_ptr = input.data_ptr<scalar_t>(); |
| scalar_t *out_ptr = output.data_ptr<scalar_t>(); |
| scalar_t *grid_ptr = grid.data_ptr<scalar_t>(); |
| // parallelize over the batch dimension; the nested loops cover every output voxel |
| at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) { |
| for (int64_t n = start; n < end; ++n) { |
| scalar_t *grid_ptr_N = grid_ptr + n * grid_sN; |
| scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; |
| for (int64_t d = 0; d < out_D; ++d) { |
| for (int64_t h = 0; h < out_H; ++h) { |
| for (int64_t w = 0; w < out_W; ++w) { |
| // get the corresponding input x, y, z coordinates from grid |
| scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; |
| scalar_t ix = *grid_ptr_NDHW; |
| scalar_t iy = grid_ptr_NDHW[grid_sCoor]; |
| scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor]; |
| |
| ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); |
| iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); |
| iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); |
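| // A rough sketch of what grid_sampler_compute_source_index does (see |
| // ATen/native/GridSampler.h): a normalized coordinate x in [-1, 1] is unnormalized |
| // roughly as |
| //   align_corners == true : ((x + 1) / 2) * (size - 1) |
| //   align_corners == false: ((x + 1) * size - 1) / 2 |
| // and then clipped or reflected into range depending on padding_mode. |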
| |
| if (interpolation_mode == GridSamplerInterpolation::Bilinear) { |
| // get indices of the eight corner voxels surrounding (ix, iy, iz) |
| // for 4d (two spatial dims) the corners are named north-east-south-west; |
| // for 5d (three spatial dims) we additionally distinguish top-bottom |
| int64_t ix_tnw = static_cast<int64_t>(std::floor(ix)); |
| int64_t iy_tnw = static_cast<int64_t>(std::floor(iy)); |
| int64_t iz_tnw = static_cast<int64_t>(std::floor(iz)); |
| |
| int64_t ix_tne = ix_tnw + 1; |
| int64_t iy_tne = iy_tnw; |
| int64_t iz_tne = iz_tnw; |
| |
| int64_t ix_tsw = ix_tnw; |
| int64_t iy_tsw = iy_tnw + 1; |
| int64_t iz_tsw = iz_tnw; |
| |
| int64_t ix_tse = ix_tnw + 1; |
| int64_t iy_tse = iy_tnw + 1; |
| int64_t iz_tse = iz_tnw; |
| |
| int64_t ix_bnw = ix_tnw; |
| int64_t iy_bnw = iy_tnw; |
| int64_t iz_bnw = iz_tnw + 1; |
| |
| int64_t ix_bne = ix_tnw + 1; |
| int64_t iy_bne = iy_tnw; |
| int64_t iz_bne = iz_tnw + 1; |
| |
| int64_t ix_bsw = ix_tnw; |
| int64_t iy_bsw = iy_tnw + 1; |
| int64_t iz_bsw = iz_tnw + 1; |
| |
| int64_t ix_bse = ix_tnw + 1; |
| int64_t iy_bse = iy_tnw + 1; |
| int64_t iz_bse = iz_tnw + 1; |
| |
| // compute the trilinear interpolation weight for each corner: the weight is the |
| // volume of the sub-cuboid between (ix, iy, iz) and the diagonally opposite corner |
| scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); |
| scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); |
| scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); |
| scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); |
| scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); |
| scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); |
| scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); |
| scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); |
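| // Worked example: with ix = 0.25, iy = 0.5, iz = 0.75 (all corners in bounds), |
| //   tnw = (1 - 0.25) * (1 - 0.5) * (1 - 0.75) = 0.09375 and |
| //   bse = 0.25 * 0.5 * 0.75 = 0.09375; the eight weights always sum to 1, so the |
| // output below is a convex combination of the eight corner values. |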
| |
| // compute the trilinearly weighted value and write the output voxel |
| scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; |
| scalar_t *inp_ptr_NC = inp_ptr_N; |
| for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) { |
| // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne |
| // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse |
| // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne |
| // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse |
| *out_ptr_NCDHW = static_cast<scalar_t>(0); |
| if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; |
| } |
| if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; |
| } |
| if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; |
| } |
| if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; |
| } |
| if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; |
| } |
| if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; |
| } |
| if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; |
| } |
| if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; |
| } |
| } |
| } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { |
| int64_t ix_nearest = static_cast<int64_t>(std::round(ix)); |
| int64_t iy_nearest = static_cast<int64_t>(std::round(iy)); |
| int64_t iz_nearest = static_cast<int64_t>(std::round(iz)); |
| |
| // assign nearest neighbor voxel value to output voxel |
| scalar_t *out_ptr_NCDHW = out_ptr + n * out_sN + d * out_sD + h * out_sH + w * out_sW; |
| scalar_t *inp_ptr_NC = inp_ptr_N; |
| for (int64_t c = 0; c < C; ++c, out_ptr_NCDHW += out_sC, inp_ptr_NC += inp_sC) { |
| if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { |
| *out_ptr_NCDHW = inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; |
| } else { |
| *out_ptr_NCDHW = static_cast<scalar_t>(0); |
| } |
| } |
| } |
| } |
| } |
| } |
| } |
| }); |
| return output; |
| } |
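| |
| // A minimal usage sketch (illustration only; shapes and values are hypothetical). |
| // The composite at::grid_sampler entry point defined at the bottom of this file |
| // routes 5-D inputs to the 3D implementation above: |
| // |
| //   auto input = at::randn({1, 2, 4, 4, 4});          // N, C, D, H, W |
| //   auto grid  = at::rand({1, 3, 3, 3, 3}) * 2 - 1;   // N, D_out, H_out, W_out, 3; in [-1, 1] |
| //   auto out   = at::grid_sampler(input, grid, |
| //                                 /*interpolation_mode=*/0,  // Bilinear |
| //                                 /*padding_mode=*/0,        // Zeros |
| //                                 /*align_corners=*/false);  // -> {1, 2, 3, 3, 3} |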
| |
| template<typename scalar_t> |
| std::tuple<Tensor, Tensor> |
| grid_sampler_3d_backward_cpu_impl(const Tensor& grad_output, |
| const Tensor& input, const Tensor& grid, |
| GridSamplerInterpolation interpolation_mode, |
| GridSamplerPadding padding_mode, |
| bool align_corners) { |
| auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| // If interpolation mode is Nearest, then grad_grid is not filled in the |
| // loop below. |
| if (interpolation_mode == GridSamplerInterpolation::Nearest) { |
| grad_grid.zero_(); |
| } |
| int64_t N = input.size(0); |
| int64_t C = input.size(1); |
| int64_t inp_D = input.size(2); |
| int64_t inp_H = input.size(3); |
| int64_t inp_W = input.size(4); |
| int64_t out_D = grid.size(1); |
| int64_t out_H = grid.size(2); |
| int64_t out_W = grid.size(3); |
| int64_t inp_sN = input.stride(0); |
| int64_t inp_sC = input.stride(1); |
| int64_t inp_sD = input.stride(2); |
| int64_t inp_sH = input.stride(3); |
| int64_t inp_sW = input.stride(4); |
| int64_t grid_sN = grid.stride(0); |
| int64_t grid_sD = grid.stride(1); |
| int64_t grid_sH = grid.stride(2); |
| int64_t grid_sW = grid.stride(3); |
| int64_t grid_sCoor = grid.stride(4); |
| int64_t gOut_sN = grad_output.stride(0); |
| int64_t gOut_sC = grad_output.stride(1); |
| int64_t gOut_sD = grad_output.stride(2); |
| int64_t gOut_sH = grad_output.stride(3); |
| int64_t gOut_sW = grad_output.stride(4); |
| int64_t gInp_sN = grad_input.stride(0); |
| int64_t gInp_sC = grad_input.stride(1); |
| int64_t gInp_sD = grad_input.stride(2); |
| int64_t gInp_sH = grad_input.stride(3); |
| int64_t gInp_sW = grad_input.stride(4); |
| int64_t gGrid_sN = grad_grid.stride(0); |
| int64_t gGrid_sW = grad_grid.stride(3); |
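| // grad_grid is allocated contiguous above, so only its N and W strides are needed: |
| // advancing by gGrid_sW (= 3, one step per output location) in the innermost loop |
| // walks the whole (D, H, W) volume in order. |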
| scalar_t *inp_ptr = input.data_ptr<scalar_t>(); |
| scalar_t *grid_ptr = grid.data_ptr<scalar_t>(); |
| scalar_t *gOut_ptr = grad_output.data_ptr<scalar_t>(); |
| scalar_t *gInp_ptr = grad_input.data_ptr<scalar_t>(); |
| scalar_t *gGrid_ptr = grad_grid.data_ptr<scalar_t>(); |
| // parallelize over the batch dimension; the nested loops cover every output voxel |
| at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) { |
| for (int64_t n = start; n < end; ++n) { |
| scalar_t *grid_ptr_N = grid_ptr + n * grid_sN; |
| scalar_t *inp_ptr_N = inp_ptr + n * inp_sN; |
| scalar_t *gGrid_ptr_NDHW = gGrid_ptr + n * gGrid_sN; |
| for (int64_t d = 0; d < out_D; ++d) { |
| for (int64_t h = 0; h < out_H; ++h) { |
| for (int64_t w = 0; w < out_W; ++w, gGrid_ptr_NDHW += gGrid_sW /* grad_grid is contiguous */ ) { |
| // get the corresponding input x, y, z coordinates from grid |
| scalar_t *grid_ptr_NDHW = grid_ptr_N + d * grid_sD + h * grid_sH + w * grid_sW; |
| scalar_t ix = *grid_ptr_NDHW; |
| scalar_t iy = grid_ptr_NDHW[grid_sCoor]; |
| scalar_t iz = grid_ptr_NDHW[2 * grid_sCoor]; |
| |
| // multipliers for gradients on ix, iy, and iz |
| scalar_t gix_mult, giy_mult, giz_mult; |
| ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult); |
| iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult); |
| iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult); |
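| // Note: the *_mult factors set above are the derivatives of the unnormalization, |
| // roughly (size - 1) / 2 when align_corners is true and size / 2 otherwise, with a |
| // sign flip wherever reflection padding folds the coordinate back into range. |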
| |
| if (interpolation_mode == GridSamplerInterpolation::Bilinear) { |
| // get indices of the eight corner voxels surrounding (ix, iy, iz) |
| // for 4d (two spatial dims) the corners are named north-east-south-west; |
| // for 5d (three spatial dims) we additionally distinguish top-bottom |
| int64_t ix_tnw = static_cast<int64_t>(std::floor(ix)); |
| int64_t iy_tnw = static_cast<int64_t>(std::floor(iy)); |
| int64_t iz_tnw = static_cast<int64_t>(std::floor(iz)); |
| |
| int64_t ix_tne = ix_tnw + 1; |
| int64_t iy_tne = iy_tnw; |
| int64_t iz_tne = iz_tnw; |
| |
| int64_t ix_tsw = ix_tnw; |
| int64_t iy_tsw = iy_tnw + 1; |
| int64_t iz_tsw = iz_tnw; |
| |
| int64_t ix_tse = ix_tnw + 1; |
| int64_t iy_tse = iy_tnw + 1; |
| int64_t iz_tse = iz_tnw; |
| |
| int64_t ix_bnw = ix_tnw; |
| int64_t iy_bnw = iy_tnw; |
| int64_t iz_bnw = iz_tnw + 1; |
| |
| int64_t ix_bne = ix_tnw + 1; |
| int64_t iy_bne = iy_tnw; |
| int64_t iz_bne = iz_tnw + 1; |
| |
| int64_t ix_bsw = ix_tnw; |
| int64_t iy_bsw = iy_tnw + 1; |
| int64_t iz_bsw = iz_tnw + 1; |
| |
| int64_t ix_bse = ix_tnw + 1; |
| int64_t iy_bse = iy_tnw + 1; |
| int64_t iz_bse = iz_tnw + 1; |
| |
| // compute the trilinear interpolation weight for each corner: the weight is the |
| // volume of the sub-cuboid between (ix, iy, iz) and the diagonally opposite corner |
| scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); |
| scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); |
| scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); |
| scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); |
| scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); |
| scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); |
| scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); |
| scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); |
| |
| scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0); |
| scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; |
| scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; |
| scalar_t *inp_ptr_NC = inp_ptr_N; |
| // accumulate gradients w.r.t. the input and the source coordinates over channels |
| for (int64_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) { |
| scalar_t gOut = *gOut_ptr_NCDHW; |
| |
| // calculate and set grad_input |
| safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut); |
| safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut); |
| safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut); |
| safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut); |
| safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut); |
| safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut); |
| safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut); |
| safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut); |
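| // safe_add_3d accumulates only when the corner lies inside the input volume, so |
| // out-of-bounds corners (possible with zeros padding) receive no gradient. |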
| |
| // calculate grad_grid |
| if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { |
| scalar_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW]; |
| gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; |
| giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; |
| giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; |
| } |
| if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { |
| scalar_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW]; |
| gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; |
| giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; |
| giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; |
| } |
| if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { |
| scalar_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW]; |
| gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; |
| giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; |
| giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; |
| } |
| if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { |
| scalar_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW]; |
| gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; |
| giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; |
| giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; |
| } |
| if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { |
| scalar_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW]; |
| gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; |
| giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; |
| giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; |
| } |
| if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { |
| scalar_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW]; |
| gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; |
| giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; |
| giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut; |
| } |
| if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { |
| scalar_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW]; |
| gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; |
| giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; |
| giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; |
| } |
| if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { |
| scalar_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW]; |
| gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; |
| giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; |
| giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; |
| } |
| } |
| |
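| // Chain rule: gix/giy/giz hold d(out)/d(ix, iy, iz) * gOut summed over channels; |
| // each in-bounds corner contributes its value times the derivative of its |
| // trilinear weight. The *_mult factors below map these back to gradients w.r.t. |
| // the normalized grid coordinates. |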
| // assuming grad_grid is contiguous |
| gGrid_ptr_NDHW[0] = gix_mult * gix; |
| gGrid_ptr_NDHW[1] = giy_mult * giy; |
| gGrid_ptr_NDHW[2] = giz_mult * giz; |
| } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { |
| int64_t ix_nearest = static_cast<int64_t>(std::round(ix)); |
| int64_t iy_nearest = static_cast<int64_t>(std::round(iy)); |
| int64_t iz_nearest = static_cast<int64_t>(std::round(iz)); |
| |
| // scatter the output gradient to the nearest neighbor input voxel |
| scalar_t *gOut_ptr_NCDHW = gOut_ptr + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; |
| scalar_t *gInp_ptr_NC = gInp_ptr + n * gInp_sN; |
| for (int64_t c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC) { |
| // calculate and set grad_input |
| safe_add_3d(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest, |
| gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, *gOut_ptr_NCDHW); |
| } |
| } |
| } |
| } |
| } |
| } |
| }); |
| return std::make_tuple(grad_input, grad_grid); |
| } |
| |
| } // namespace |
| |
| // No shape checking needed here. See NOTE [ grid_sampler Native Functions ]. |
| Tensor grid_sampler_2d_cpu(const Tensor& input, const Tensor& grid, |
| int64_t interpolation_mode, int64_t padding_mode, |
| bool align_corners) { |
| return grid_sampler_2d_cpu_kernel( |
| kCPU, input, grid, interpolation_mode, padding_mode, align_corners); |
| } |
| |
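| // DEFINE_DISPATCH defines the dispatch stub declared in |
| // ATen/native/cpu/GridSamplerKernel.h; the vectorized 2D CPU kernels register |
| // themselves into this stub from their own translation unit. |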
| DEFINE_DISPATCH(grid_sampler_2d_cpu_kernel); |
| |
| |
| // No shape checking needed here. See NOTE [ grid_sampler Native Functions ]. |
| Tensor grid_sampler_3d_cpu(const Tensor& input, const Tensor& grid, |
| int64_t interpolation_mode, int64_t padding_mode, |
| bool align_corners) { |
| return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler3d_cpu", [&] { |
| return grid_sampler_3d_cpu_impl<scalar_t>( |
| input, grid, static_cast<GridSamplerInterpolation>(interpolation_mode), |
| static_cast<GridSamplerPadding>(padding_mode), align_corners); |
| }); |
| } |
| |
| // No shape checking needed here. See NOTE [ grid_sampler Native Functions ]. |
| std::tuple<Tensor, Tensor> |
| grid_sampler_2d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, |
| int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { |
| return grid_sampler_2d_backward_cpu_kernel( |
| kCPU, grad_output, input, grid, interpolation_mode, padding_mode, align_corners); |
| } |
| |
| DEFINE_DISPATCH(grid_sampler_2d_backward_cpu_kernel); |
| |
| // No shape checking needed here. See NOTE [ grid_sampler Native Functions ]. |
| std::tuple<Tensor, Tensor> |
| grid_sampler_3d_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& grid, |
| int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { |
| return AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_sampler_3d_backward_cpu", [&] { |
| return grid_sampler_3d_backward_cpu_impl<scalar_t>( |
| grad_output, input, grid, |
| static_cast<GridSamplerInterpolation>(interpolation_mode), |
| static_cast<GridSamplerPadding>(padding_mode), align_corners); |
| }); |
| } |
| |
| Tensor grid_sampler(const Tensor& input, const Tensor& grid, |
| int64_t interpolation_mode, int64_t padding_mode, |
| bool align_corners) { |
| TORCH_CHECK( |
| input.defined() && grid.defined(), |
| "grid_sampler(): expected input and grid to not be undefined, but input " |
| "is ", input, " and grid is ", grid); |
| auto input_opt = input.options(); |
| auto grid_opt = grid.options(); |
| TORCH_CHECK( |
| input_opt.device() == grid_opt.device(), |
| "grid_sampler(): expected input and grid to be on same device, but input " |
| "is on ", input_opt.device(), " and grid is on ", grid_opt.device()); |
| TORCH_CHECK( |
| input_opt.dtype() == grid_opt.dtype(), |
| "grid_sampler(): expected input and grid to have same dtype, but input " |
| "has ", input_opt.dtype(), " and grid has ", grid_opt.dtype()); |
| TORCH_CHECK( |
| input_opt.layout() == kStrided && grid_opt.layout() == kStrided, |
| "grid_sampler(): expected input and grid to have torch.strided layout, but " |
| "input has ", input_opt.layout(), " and grid has ", grid_opt.layout()); |
| TORCH_CHECK( |
| (input.dim() == 4 || input.dim() == 5) && input.dim() == grid.dim(), |
| "grid_sampler(): expected 4D or 5D input and grid with same number of " |
| "dimensions, but got input with sizes ", input.sizes(), |
| " and grid with sizes ", grid.sizes()); |
| TORCH_CHECK( |
| input.size(0) == grid.size(0), |
| "grid_sampler(): expected grid and input to have same batch size, but got " |
| "input with sizes ", input.sizes(), " and grid with sizes ", grid.sizes()); |
| TORCH_CHECK( |
| grid.size(-1) == input.dim() - 2, |
| "grid_sampler(): expected grid to have size ", input.dim() - 2, " in last " |
| "dimension, but got grid with sizes ", grid.sizes()); |
| for (int64_t i = 2; i < input.dim(); i++) { |
| TORCH_CHECK(input.size(i) > 0, |
| "grid_sampler(): expected input to have non-empty spatial dimensions, " |
| "but input has sizes ", input.sizes(), " with dimension ", i, " being " |
| "empty"); |
| } |
| // take the cuDNN path only for the 4-D, bilinear, zeros-padding, align_corners case; |
| // cuDNN also does not support inputs with more than 1024 channels |
| if (at::native::cudnn_is_acceptable(input) && |
| at::native::cudnn_is_acceptable(grid) && |
| static_cast<GridSamplerInterpolation>(interpolation_mode) == GridSamplerInterpolation::Bilinear && |
| static_cast<GridSamplerPadding>(padding_mode) == GridSamplerPadding::Zeros && |
| align_corners && |
| input.dim() == 4 && |
| input.size(1) <= 1024) { |
| return cudnn_grid_sampler(input, grid); |
| } |
| if (input.dim() == 4) { |
| return at::grid_sampler_2d(input, grid, interpolation_mode, padding_mode, align_corners); |
| } else { |
| return at::grid_sampler_3d(input, grid, interpolation_mode, padding_mode, align_corners); |
| } |
| } |
| |
| }} // namespace at::native |