#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/UnfoldBackward.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif
// Note on naming: it is unconventional.
// grad_in does not mean that it is a gradient w.r.t. the input;
// grad_in/grad_out are just the input/output of the unfold_backward kernel.
//
// unfold_backward, the algorithm.
//
// Consider out = in.unfold(dim, size, step), then
// out.shape[dim] == (in.shape[dim] - size) / step + 1,
// out.shape[-1] == size.
// out.dims() == in.dims() + 1.
//
// unfold_backward receives grad_in and returns grad_out such that
// grad_in.shape == out.shape,
// grad_out.shape == in.shape.
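//
// For instance (made-up shapes, for illustration only): for in.shape == [6],
// out = in.unfold(0, /*size=*/3, /*step=*/2) has out.shape == [2, 3] with
// out[0] == in[0:3] and out[1] == in[2:5]; grad_in then has shape [2, 3]
// and grad_out has shape [6].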
//
// unfold_backward considers the following two cases:
// case1. step >= size.
// case2. step < size.
//
// case1. step >= size.
// In this case the iteration runs over grad_in and performs the following copy:
// grad_out[..., i_out_dim,...] = grad_in[..., i_in_dim,..., i_in_last_dim],
// where i_out_dim = i_in_dim * step + i_in_last_dim.
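//
// For example (made-up numbers, for illustration only): with size == 2 and
// step == 3, grad_in[..., 1, ..., 0] is copied to grad_out[..., 3, ...]
// because i_out_dim == 1 * 3 + 0 == 3. Since i_in_last_dim < size <= step,
// no two (i_in_dim, i_in_last_dim) pairs map to the same i_out_dim,
// so a plain copy (without accumulation) suffices.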
//
// case2. step < size.
// In this case the iteration runs over grad_out,
// where grad_out[..., i_out_dim, ...] accumulates all the values
// grad_in[..., i_in_dim, ..., i_in_last_dim] with
// i_in_dim in [left_idx_fold, right_idx_fold],
// i_in_last_dim = i_out_dim - i_in_dim * step,
// left_idx_fold = l if i_out_dim is in [l * step, l * step + size),
//                 else l + 1, where l = (i_out_dim - size) / step,
// right_idx_fold = i_out_dim / step.
//
// Simply put, given i_out_dim, we find which folds of grad_in
// intersect with i_out_dim; these are precisely [left_idx_fold, right_idx_fold]
// (clamped to the folds that actually exist in grad_in), and then each
// corresponding value grad_in[..., i_in_dim, ..., i_in_last_dim]
// gets added to grad_out[..., i_out_dim, ...].
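//
// A worked example of case2 (reusing the made-up shapes from the
// in.unfold(0, 3, 2) example above): grad_in.shape == [2, 3],
// grad_out.shape == [6], size == 3, step == 2.
// For i_out_dim == 2: left_idx_fold == 0 (2 is in [0 * 2, 0 * 2 + 3)) and
// right_idx_fold == 2 / 2 == 1, so
// grad_out[2] = grad_in[0, 2] + grad_in[1, 0].
// For i_out_dim == 5: left_idx_fold == 2, while right_idx_fold == 5 / 2 == 2
// gets clamped by the kernel to the last existing fold 1, so no fold
// contributes and grad_out[5] stays 0.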
namespace at {
namespace native {
namespace {
template <typename scalar_t>
void _unfold_backward_internal_kernel(
    TensorIterator& iter,
    int64_t size,
    int64_t step,
    int64_t grad_in_dim_stride,
    int64_t grad_in_last_dim_stride,
    int64_t grad_in_dim_size,
    int64_t grad_out_dim_stride,
    bool is_step_ge_size
) {
  if (iter.numel() == 0) {
    return;
  }

  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* RESTRICT grad_out_ptr = data[0];
    auto* RESTRICT grad_in_ptr = data[1];
    auto* RESTRICT idx_dim_ptr = data[2];

    if (is_step_ge_size) {
      auto* RESTRICT idx_last_dim_ptr = data[3];

      for (const auto elem : c10::irange(nelems)) {
        (void)elem; // Suppress unused variable warning
        auto* RESTRICT grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr);
        auto* RESTRICT grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr);

        auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr);
        auto idx_last_dim = *reinterpret_cast<int64_t*>(idx_last_dim_ptr);

        auto grad_out_idx_dim = idx_dim * step + idx_last_dim;
        grad_out_data[grad_out_idx_dim * grad_out_dim_stride] = *grad_in_data;

        grad_out_ptr += strides[0];
        grad_in_ptr += strides[1];
        idx_dim_ptr += strides[2];
        idx_last_dim_ptr += strides[3];
      }
    }
    else {
      for (const auto elem : c10::irange(nelems)) {
        (void)elem; // Suppress unused variable warning
        auto* RESTRICT grad_out_data = reinterpret_cast<scalar_t*>(grad_out_ptr);
        auto* RESTRICT grad_in_data = reinterpret_cast<scalar_t*>(grad_in_ptr);

        auto idx_dim = *reinterpret_cast<int64_t*>(idx_dim_ptr);

        // left_fold potentially intersecting with idx_dim
        // is either (idx_dim - size) / step or the next integer.
        int64_t left_fold_idx = (idx_dim > size) ? (idx_dim - size) / step : 0;
        if (!(left_fold_idx * step <= idx_dim && idx_dim < left_fold_idx * step + size)) {
          ++left_fold_idx;
        }

        auto right_fold_idx = idx_dim / step;
        right_fold_idx = (right_fold_idx >= grad_in_dim_size)
          ? (grad_in_dim_size - 1) : right_fold_idx;

        for (auto fold_idx = left_fold_idx; fold_idx <= right_fold_idx; ++fold_idx) {
          auto idx_last_dim = idx_dim - fold_idx * step;
          *grad_out_data += grad_in_data[fold_idx * grad_in_dim_stride
                                       + idx_last_dim * grad_in_last_dim_stride];
        }

        grad_out_ptr += strides[0];
        grad_in_ptr += strides[1];
        idx_dim_ptr += strides[2];
      }
    }
  };

  iter.for_each(loop);
}
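// For reference, a naive 1-D, contiguous sketch of what the kernel above
// computes (a hypothetical helper for illustration only, not part of ATen;
// the real kernel iterates via a TensorIterator and splits into the two
// cases described at the top of this file):
//
//   template <typename scalar_t>
//   void unfold_backward_1d_reference(
//       const scalar_t* grad_in,  // contiguous, shape [n_folds, size]
//       scalar_t* grad_out,       // contiguous, shape [in_size], pre-zeroed
//       int64_t in_size, int64_t size, int64_t step) {
//     int64_t n_folds = (in_size - size) / step + 1;
//     for (int64_t fold = 0; fold < n_folds; ++fold) {
//       for (int64_t j = 0; j < size; ++j) {
//         // out[fold][j] was read from in[fold * step + j] in the forward,
//         // so its gradient flows back to that same position.
//         grad_out[fold * step + j] += grad_in[fold * size + j];
//       }
//     }
//   }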
void unfold_backward_cpu_kernel(
    Tensor& grad_out,
    const Tensor& grad_in,
    int64_t dim,
    int64_t size,
    int64_t step
) {
  dim = maybe_wrap_dim(dim, grad_out.dim());
  // last dim stores the folds
  auto last_dim = maybe_wrap_dim(-1, grad_in.dim());

  auto grad_in_dim_stride = ensure_nonempty_stride(grad_in, dim);
  auto grad_in_last_dim_stride = ensure_nonempty_stride(grad_in, last_dim);
  auto grad_in_dim_size = ensure_nonempty_size(grad_in, dim);

  auto grad_out_dim_stride = ensure_nonempty_stride(grad_out, dim);

  auto is_step_ge_size = (step >= size);

  TensorIterator iter =
    is_step_ge_size ?
      _make_unfold_backward_iter_over_grad_in(
        grad_out, grad_in, dim, size, step
      ) :
      _make_unfold_backward_iter_over_grad_out(
        grad_out, grad_in, dim, size, step
      );

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
    iter.dtype(),
    "unfold_backward_cpu", [&] {
      _unfold_backward_internal_kernel<scalar_t>(
        iter,
        size,
        step,
        grad_in_dim_stride,
        grad_in_last_dim_stride,
        grad_in_dim_size,
        grad_out_dim_stride,
        is_step_ge_size
      );
    }
  );
}
} // anonymous namespace
REGISTER_DISPATCH(unfold_backward_stub, &unfold_backward_cpu_kernel);
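// For illustration only (not part of this kernel): the stub registered above
// is what the autograd formula for Tensor::unfold ends up calling on CPU,
// e.g. with the libtorch C++ API:
//
//   torch::Tensor x = torch::ones({6}, torch::requires_grad());
//   x.unfold(0, /*size=*/3, /*step=*/2).sum().backward();
//   // x.grad() is [1, 1, 2, 1, 1, 0]: x[2] lies in both folds, x[5] in none.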
}} // namespace at::native