blob: a525165bf8df4ba0a5b6b30830f2770cb589425f [file] [log] [blame]
#pragma once
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/cpu/vml.h>
#include <tuple>
namespace at { namespace native {
using forward_2d_fn = Tensor(*)(const Tensor &, const Tensor &, int64_t, int64_t, bool);
using backward_2d_fn = std::tuple<Tensor, Tensor>(*)(const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, bool);
DECLARE_DISPATCH(forward_2d_fn, grid_sampler_2d_cpu_kernel);
DECLARE_DISPATCH(backward_2d_fn, grid_sampler_2d_backward_cpu_kernel);
}} // namespace at::native