blob: c9b2e3d999ad233983f1df0ba7beef998a49f92e [file] [log] [blame]
#pragma once
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <tuple>
namespace at {
namespace native {

/// Argument preprocessing shared by the sparse softmax forward kernels
/// (judging by the name; the definition is not in this header).
///
/// NOTE(review): behaviour described below is inferred from the signature
/// only — confirm against the implementation file.
///
/// @param input_         Input tensor (presumably a sparse tensor — confirm).
/// @param dim_           Dimension along which softmax is taken; may be
///                       negative and wrapped by the callee — TODO confirm.
/// @param half_to_float  Whether a half-precision input should yield a float
///                       result (support for sparse inputs — TODO confirm).
/// @param function_name  Name of the calling op; presumably used in
///                       argument-check diagnostics.
/// @return Two tensors; their meaning and order (e.g. prepared input and
///         output buffer) are defined by the implementation — TODO confirm.
///
/// Top-level `const` is intentionally omitted on the by-value parameters:
/// it is not part of the function type, so the definition may still declare
/// them `const` without changing this declaration.
TORCH_API std::pair<Tensor, Tensor> softmax_sparse_input_preprocessing(
    const Tensor& input_,
    int64_t dim_,
    bool half_to_float,
    CheckedFrom function_name);

/// Argument preprocessing shared by the sparse softmax backward kernels
/// (judging by the name; the definition is not in this header).
///
/// NOTE(review): behaviour described below is inferred from the signature
/// only — confirm against the implementation file.
///
/// @param grad_          Incoming gradient tensor.
/// @param output_        Output tensor of the forward pass (presumably).
/// @param dim_           Dimension along which softmax was taken.
/// @param input_         Input tensor of the forward pass (presumably).
/// @param function_name  Name of the calling op; presumably used in
///                       argument-check diagnostics.
/// @return Three tensors; their meaning and order are defined by the
///         implementation — TODO confirm.
TORCH_API std::tuple<Tensor, Tensor, Tensor> softmax_backward_sparse_input_preprocessing(
    const Tensor& grad_,
    const Tensor& output_,
    int64_t dim_,
    const Tensor& input_,
    CheckedFrom function_name);

} // namespace native
} // namespace at