// blob: ec4a2d7bf64c7ea4f6e31a207ba2f013431c3e4e [file] [log] [blame]
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/UnfoldBackward.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/unfold_backward_native.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {
// Defines the `unfold_backward_stub` dispatch stub. `unfold_backward` in this
// file invokes it with the gradient's device type as the first argument.
// NOTE(review): the stub's declaration presumably comes from
// ATen/native/UnfoldBackward.h — confirm.
DEFINE_DISPATCH(unfold_backward_stub);
// Builds the gradient with respect to the input of
// `input.unfold(dim, size, step)`.
//
// Parameters:
//   grad        - incoming gradient; its options (dtype/device/layout) are
//                 used for the result. Presumably shaped like the output of
//                 the forward unfold — confirm against the caller.
//   input_sizes - sizes of the tensor that was unfolded in the forward pass;
//                 the returned gradient has exactly these sizes.
//   dim, size, step - the same arguments that are passed on to `unfold` / the
//                 dispatch stub below.
//
// Returns a freshly allocated, zero-initialised tensor of `input_sizes` that
// has been filled in with the gradient.
Tensor unfold_backward(
    const Tensor& grad,
    IntArrayRef input_sizes,
    int64_t dim,
    int64_t size,
    int64_t step) {
  // Start from zeros so that positions not written below stay zero.
  Tensor input_grad = at::zeros(input_sizes, grad.options());

  if (size <= step) {
    // Windows of length `size` taken every `step` elements cannot overlap
    // here, so `grad` is copied straight into an unfolded view of the result;
    // no accumulation is needed.
    input_grad.unfold(dim, size, step).copy_(grad);
  } else {
    // Otherwise hand off to the device-specific kernel selected by the
    // gradient's device type.
    // NOTE(review): presumably it accumulates contributions from overlapping
    // windows — confirm against the kernel implementation.
    unfold_backward_stub(
        grad.device().type(), input_grad, grad, dim, size, step);
  }

  return input_grad;
}
} // namespace at::native