// blob: 7d1ae8020bd65fff0fca319bd62c5748047efbee [file] [log] [blame]
#include <ATen/native/ReduceOps.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/TensorDimApply.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/core/grad_mode.h>
#include <c10/util/irange.h>
#include <c10/util/SmallBuffer.h>
#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>
#include <map>
#include <cmath>
#include <cfloat>
#include <type_traits>
namespace at {
namespace meta {
// Validates the arguments of `all`/`any` and selects the dtype of the output.
//
// `name` is the operator name used as the prefix of error messages.
// `self` must use the strided layout.
// `result` may be undefined (no out tensor was supplied); when defined its
// dtype must be Bool or Byte and that dtype is returned.
// `keepdim` is accepted for signature symmetry and is not read here.
//
// Without an out tensor the dtype is Byte for Byte inputs and Bool otherwise.
ScalarType check_allany_and_get_output_dtype(
    const char* name,
    const Tensor& self,
    const Tensor& result,
    bool keepdim) {
  // Refer [all, any : uint8 compatibility]
  TORCH_CHECK(
      self.layout() == Layout::Strided,
      name, " only supports strided layout, got: ",
      self.layout());

  if (!result.defined()) {
    // No out tensor: uint8 inputs keep their dtype, everything else is bool.
    const ScalarType self_dtype = self.scalar_type();
    return self_dtype == ScalarType::Byte ? self_dtype : ScalarType::Bool;
  }

  // Refer [all, any : uint8 compatibility]
  TORCH_CHECK(
      result.scalar_type() == ScalarType::Bool ||
      result.scalar_type() == ScalarType::Byte,
      name, " only supports bool tensor for result, got: ",
      result.scalar_type());
  return result.scalar_type();
}
// Shared meta implementation for the `dim` overloads of `all` and `any`:
// wraps `dim`, validates the arguments, declares the output shape/dtype on
// `meta`, and propagates dimension names to the output.
void check_allany_for_meta(
    impl::MetaBase& meta,
    const char* name,
    const Tensor& self,
    int64_t dim,
    bool keepdim) {
  const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim());
  // Reference to the (possibly still undefined) output held by `meta`.
  const auto& result = meta.maybe_get_output();
  const ScalarType out_dtype =
      check_allany_and_get_output_dtype(name, self, result, keepdim);
  const auto out_shape = get_reduction_shape(self, wrapped_dim, keepdim);
  meta.set_output(out_shape, self.options().dtype(out_dtype));
  namedinference::propagate_names_for_reduction(
      result, self, wrapped_dim, keepdim);
}
// Meta function declared through TORCH_META_FUNC2 for `all` (overload `dim`).
// All validation and output setup is delegated to check_allany_for_meta, with
// "all" used as the operator name in error messages.
TORCH_META_FUNC2(all, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_allany_for_meta(*this, "all", self, dim, keepdim);
}
// Meta function declared through TORCH_META_FUNC2 for `any` (overload `dim`).
// All validation and output setup is delegated to check_allany_for_meta, with
// "any" used as the operator name in error messages.
TORCH_META_FUNC2(any, dim)(const Tensor& self, int64_t dim, bool keepdim) {
check_allany_for_meta(*this, "any", self, dim, keepdim);
}
// Meta function for `argmax`: computes the output shape and declares a kLong
// output with the options of `self`.
//
// With an explicit `dim` the reduced dimension is validated and removed (or
// kept with size 1 when `keepdim` is set, as decided by get_reduction_shape).
// Without `dim` the output shape is left empty and the input must be
// non-empty.
TORCH_META_FUNC(argmax)
(const Tensor& self, c10::optional<int64_t> dim, bool keepdim) {
  DimVector out_shape;
  if (!dim.has_value()) {
    TORCH_CHECK_INDEX(
        self.numel() != 0,
        "argmax(): Expected reduction dim to be specified for input.numel() == 0.");
  } else {
    const auto wrapped_dim = maybe_wrap_dim(dim.value(), self.dim());
    native::zero_numel_check_dims(self, wrapped_dim, "argmax()");
    out_shape = get_reduction_shape(self, wrapped_dim, keepdim);
  }
  set_output(out_shape, self.options().dtype(kLong));
}
} // namespace meta
namespace native {
// Definitions of the dispatch stubs for the reduction and cumulative kernels
// (one DEFINE_DISPATCH per stub). The stubs are presumably declared in the
// headers included above - TODO confirm. Each definition is a mutable global,
// hence the per-line lint suppression.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(sum_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(nansum_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(std_var_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(prod_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(norm_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(mean_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(and_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(or_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(min_values_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(max_values_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(argmax_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(argmin_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(cumsum_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(cumprod_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(logcumsumexp_stub);
// Functional CPU entry point for logcumsumexp: allocates a contiguous output
// shaped like `self` and lets the out variant fill it.
Tensor _logcumsumexp_cpu(const Tensor& self, int64_t dim) {
  Tensor out = at::empty_like(self, MemoryFormat::Contiguous);
  _logcumsumexp_out_cpu(self, dim, out);
  return out;
}
// Out variant of the CPU logcumsumexp entry point. Invokes logcumsumexp_stub
// for the device type of `self` with (result, self, dim) - presumably the
// kernel writes into `result` - and returns `result`.
Tensor& _logcumsumexp_out_cpu(const Tensor& self, int64_t dim, Tensor& result) {
logcumsumexp_stub(self.device().type(), result, self, dim);
return result;
}
// Named-tensor aware wrapper around at::_logcumsumexp: the computation runs
// with names disabled, then the names of `self` are propagated to the output.
Tensor logcumsumexp(const Tensor& self, int64_t dim) {
  Tensor output;
  {
    NoNamesGuard guard;
    output = at::_logcumsumexp(self, dim);
  }
  namedinference::propagate_names(output, self);
  return output;
}
// Out variant of logcumsumexp. Checks that `result` and `self` agree on
// scalar type, device and layout, runs the computation with names disabled
// on `self` converted to the dtype of `result`, then propagates names.
// Returns `result`.
Tensor& logcumsumexp_out(const Tensor& self, int64_t dim, Tensor& result) {
  check_scalar_type_device_layout_equal(result, self);
  {
    NoNamesGuard guard;
    const Tensor converted = self.toType(result.scalar_type());
    at::_logcumsumexp_out(result, converted, dim);
  }
  namedinference::propagate_names(result, self);
  return result;
}
// Functional CPU entry point for cumsum: allocates a contiguous output shaped
// like `self` and hands it to cumsum_stub for the device type of `self`.
Tensor _cumsum_cpu(const Tensor& self, int64_t dim) {
  auto out = at::empty_like(self, MemoryFormat::Contiguous);
  const auto device_type = self.device().type();
  cumsum_stub(device_type, out, self, dim);
  return out;
}
// Out variant of the CPU cumsum entry point. Invokes cumsum_stub for the
// device type of `self` with (result, self, dim) - presumably the kernel
// writes into `result` - and returns `result`.
Tensor& _cumsum_out_cpu(const Tensor& self, int64_t dim, Tensor& result) {
cumsum_stub(self.device().type(), result, self, dim);
return result;
}
// Cumulative sum along `dim`. The input first goes through integer_upcast
// with the optional `dtype`; the computation runs with names disabled and the
// names of `self` are propagated to the output afterwards.
Tensor cumsum(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype) {
  Tensor output;
  {
    NoNamesGuard guard;
    output = at::_cumsum(integer_upcast(self, dtype), dim);
  }
  namedinference::propagate_names(output, self);
  return output;
}
// In-place cumulative sum along `dim`.
// An in-place op cannot change the dtype of `self`, so a provided `dtype`
// must equal the dtype of `self`; otherwise TORCH_CHECK raises.
// Returns whatever at::_cumsum_out returns when `self` is passed both as the
// output and as the input.
Tensor& cumsum_(Tensor& self, int64_t dim, c10::optional<ScalarType> dtype) {
TORCH_CHECK(
!dtype.has_value() || (self.scalar_type() == dtype.value()),
"provided dtype must match the dtype of self tensor in cumsum. Got ",
toString(self.scalar_type()),
" and ",
toString(dtype.value()),
".");
return at::_cumsum_out(self, self, dim);
}
// Out variant of cumsum. The dtype of `result` decides the computation dtype;
// a provided `dtype` must match it. The input is converted to that dtype, the
// computation runs with names disabled, and names are propagated from `self`.
// Returns `result`.
Tensor& cumsum_out(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, Tensor& result) {
  // result type is favored over dtype; check that they match if provided (NumPy doesn't check)
  TORCH_CHECK(
      !dtype.has_value() || (result.scalar_type() == dtype.value()),
      "provided dtype must match dtype of result in cumsum. Got ",
      toString(result.scalar_type()),
      " and ",
      toString(dtype.value()),
      ".");
  {
    NoNamesGuard guard;
    const Tensor converted = self.toType(result.scalar_type());
    at::_cumsum_out(result, converted, dim);
  }
  namedinference::propagate_names(result, self);
  return result;
}
// Functional CPU entry point for cumprod: allocates a contiguous output
// shaped like `self` and hands it to cumprod_stub for the device type of
// `self`.
Tensor _cumprod_cpu(const Tensor& self, int64_t dim) {
  auto out = at::empty_like(self, MemoryFormat::Contiguous);
  const auto device_type = self.device().type();
  cumprod_stub(device_type, out, self, dim);
  return out;
}
// Out variant of the CPU cumprod entry point. Invokes cumprod_stub for the
// device type of `self` with (result, self, dim) - presumably the kernel
// writes into `result` - and returns `result`.
Tensor& _cumprod_out_cpu(const Tensor& self, int64_t dim, Tensor& result) {
cumprod_stub(self.device().type(), result, self, dim);
return result;
}
// Cumulative product along `dim`. The input first goes through integer_upcast
// with the optional `dtype`; the computation runs with names disabled and the
// names of `self` are propagated to the output afterwards.
Tensor cumprod(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype) {
  Tensor output;
  {
    NoNamesGuard guard;
    output = at::_cumprod(integer_upcast(self, dtype), dim);
  }
  namedinference::propagate_names(output, self);
  return output;
}
// In-place cumulative product along `dim`.
// An in-place op cannot change the dtype of `self`, so a provided `dtype`
// must equal the dtype of `self`; otherwise TORCH_CHECK raises.
// Returns whatever at::_cumprod_out returns when `self` is passed both as the
// output and as the input.
Tensor& cumprod_(Tensor& self, int64_t dim, c10::optional<ScalarType> dtype) {
TORCH_CHECK(
!dtype.has_value() || (self.scalar_type() == dtype.value()),
"provided dtype must match the dtype of self tensor in cumprod. Got ",
toString(self.scalar_type()),
" and ",
toString(dtype.value()),
".");
return at::_cumprod_out(self, self, dim);
}
// Out variant of cumprod. The dtype of `result` decides the computation
// dtype; a provided `dtype` must match it. The input is converted to that
// dtype, the computation runs with names disabled, and names are propagated
// from `self`. Returns `result`.
Tensor& cumprod_out(const Tensor& self, int64_t dim, c10::optional<ScalarType> dtype, Tensor& result) {
  // result type is favored over dtype; check that they match if provided (NumPy doesn't check)
  TORCH_CHECK(
      !dtype.has_value() || (result.scalar_type() == dtype.value()),
      "provided dtype must match dtype of result in cumprod. Got ",
      toString(result.scalar_type()),
      " and ",
      toString(dtype.value()),
      ".");
  {
    NoNamesGuard guard;
    const Tensor converted = self.toType(result.scalar_type());
    at::_cumprod_out(result, converted, dim);
  }
  namedinference::propagate_names(result, self);
  return result;
}
// Cumulative sum taken from the end of `dim` towards its start:
// flip along `dim`, cumsum along `dim`, flip back.
Tensor reversed_cumsum(const Tensor& w, int64_t dim) {
  const Tensor flipped = w.flip(dim);
  const Tensor running = flipped.cumsum(dim);
  return running.flip(dim);
}
// Backward of cumprod along `dim`.
//   grad   - gradient with respect to the cumprod output
//   input  - the tensor cumprod was applied to
//   output - the result of cumprod(input, dim)
// Returns the gradient with respect to `input`. Three strategies are used:
// a closed form when `input` has no zeros, an O(n) masked computation when it
// has zeros and grad mode is disabled, and an O(n^2) differentiable fallback
// when grad mode is enabled.
Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, const Tensor& output) {
/*
We show here how to derive an O(n) gradient formula for
arbitrary inputs. It follows via a basic application of the
chain rule together with a number of observations for different
cases. We assume that x is an n-dimensional vector and y = cumprod(x).
In the actual implementation we will need to play a bit with masks
to be able to implement the formulas deduced here for tensors.
We will first deduce the formula for the case when
x[i] != 0 for 1 <= i <= n.
For F : R^n -> R the cost function (we will look at the complex case later),
we have
dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k) (1)
The term dF / dy_j is just grad_output[j] (assuming again
everything is one-dimensional).
The term (dy_j / dx_k) is easily seen to be
if j >= k
dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i
else:
dy_j / dx_k = 0
Note that the indicator (j>=k) can be taken out
by replacing the sum in (1) with a sum from
k <= j <= n.
Thus,
dF / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k)
with
dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i (2)
Note that this last term is just the cumulative product
with k omitted. Thus, if x_k (the input) is nonzero, we can
just express this as
dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k
= y_j / x_k
So therefore,
dF / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k
This formula just makes sense when input[i] != 0 for every i.
Assume now that there exists at least a zero in the input.
Denote by z1 the first element 1 <= z1 <= n with input[z1] = 0
and z2 the second element z1 < z2 <= n with input[z2] = 0,
(or z2 = n if there is just one zero in input)
We have three cases.
k > z1:
Looking at (2), we see that dy_j / dx_k = 0, for j >= k, as these terms
all include a x_{z1} which is zero. As such, dF / dx_k = 0 in this case
k < z1:
Reasoning as in the previous case, we see that for these elements we have that
dF / dx_k = sum_{k <= j < z1} grad_output[j] * (dy_j / dx_k)
as the terms of the sum for j in z1 <= j <= n are all zero
k = z1:
Similar to the case k < z1, we have that
dF / dx_z1 = sum_{z1 <= j < z2} grad_output[j] * (dy_j / dx_z1)
This case has a subtlety though. To compute (dy_j / dx_z1), we cannot use the formula
dy_j / dx_z1 = y_j / x_z1
as, y_j = x_z1 = 0 for j >= z1. We need to compute it with the formula for its derivative,
that is:
dF / dx_z1 = prod(x[:z1]) * (grad_output[z1] + sum(grad_output[z1+1:z2] * cumprod(x[z1+1:z2])))
When the inputs are complex, this map is holomorphic. As such, its
backward is just the conjugate of the usual backward. This simplifies to
conjugating the input. We may also reuse the output as, since the map is holomorphic,
cumprod(input.conj()) = cumprod(input).conj()
*/
// cumprod over at most one element is the identity, so the gradient passes through.
if (input.numel() <= 1) {
return grad;
}
dim = at::maybe_wrap_dim(dim, input.dim());
const int64_t dim_size = input.sizes()[dim];
// Same identity argument when the reduced dimension has a single element.
if (dim_size == 1) {
return grad;
}
// To enable complex support.
// From this line on `input_conj` and `output_conj`
// are interchangeable with `input` and `output`.
auto input_conj = input.conj();
auto output_conj = output.conj();
const auto w = output_conj * grad;
const auto is_zero = input == 0;
// Fast path: no zeros anywhere, so dF / dx_k = reversed_cumsum(grad * y)_k / x_k.
if (!(is_zero.any().item<uint8_t>())) {
return reversed_cumsum(w, dim).div(input_conj);
}
// If we are not computing a second order gradient, we can use an
// O(n) implementation. The derivative of this implementation is _not_
// the second derivative of cumprod. As such, we fallback to a less efficient
// O(n^2) implementation when at::GradMode::is_enabled().
Tensor grad_input = at::zeros(input.sizes(), grad.options());
if (!at::GradMode::is_enabled()) {
// n.b. This could probably be implemented much faster with a kernel
// From here on we need to use some mask gymnastics to
// account for the tensorial dimensions
// We do a cumsum of the zeros along the dimension.
// For a vector is_zero = [False, True, False, True, False]
// we would have cumsum = [0, 1, 1, 2, 2]
// As such we have (in python code for simplicity)
// The mask for the range [0, z1):
// cumsum == 0
// The indices of the first zero z1 and zeros when
// there is no first zero:
// indices = (cumsum == 1).max(dim, keepdim=True).indices
// The mask for the first zero:
// zeros_like(indices).scatter_(dim, indices, 1.) & cumsum == 1
// Note that the logical_and with cumsum == 1 accounts
// for the case when there is no first zero
const auto cumsum = is_zero.cumsum(dim);
// case k < z1
// select everything before the first zero [0, z1)
auto mask = cumsum == 0;
// equiv to grad_input[mask] = deriv[grad]
grad_input.masked_scatter_(mask,
reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input_conj).masked_select(mask));
// select everything from the first zero to the second zero [z1, z2)
mask = cumsum == 1;
// case k = z1
// We start by selecting the first zero [z1]
// We locate the indices of the first zero using the max function
// We then go from the indices to a mask using scatter_
// When there is no zero in the slice, max will return the index 0.
// To account for this, we need to do an intersection with mask,
// which is true in the range [z1, z2)
const auto first_zero_index = std::get<1>(mask.max(dim, /*keepdim*/ true));
const auto first_zero_mask = at::zeros_like(mask)
.scatter_(dim, first_zero_index, /*src*/ 1)
.logical_and_(mask);
// select everything between the first zero and the second zero (z1, z2)
mask &= ~first_zero_mask;
// here we compute
// dy_j / dx_z1 = sum(cumprod(input[z1+1:z2] * grad[z1+1:z2])) * prod(output[z1-1])
// relu_() necessary as gather does not support negative indices
// finally, we do grad_input[z1] = dy_j / dx_z1
grad_input.masked_scatter_(first_zero_mask,
input_conj.masked_fill(~mask, 1.).cumprod(dim)
.mul_(grad.masked_fill(cumsum != 1, 0.))
.sum(dim, /*keepdim*/true)
.mul_(at::gather(output_conj, dim, (first_zero_index - 1).relu_())
.masked_fill_(first_zero_index == 0, 1.))
.masked_select(first_zero_mask));
} else { // GradMode::enabled()
/*
This branch is only reached when the input contains at least one zero
(the zero-free case returned above), so y_j / x_k cannot be used and we
need to calculate dy_j / dx_k by using the formula (2), called in the
code omitted_products.
The way the code calculates it is simply by noting that
prod_{1 <= i <= j, i != k} x_i
= (prod_{1 <= i < k} x_i) * (prod_{k + 1 <= i <= j} x_i)
the first term is calculated as prods_until_k, which since
doesn't depend on j is easy to vectorize.
The second term (indexed by j) is the cumulative product of
x_{k+1}, x_{k+2}, ..., x_n, and it's named in the code
prods_from_k_plus_1, and it's calculated as a cumprod.
In order to vectorize this properly, we need to add to
omitted_products the dimensions where k > j, and therefore
dy_j / dx_k = 0, which is done right after the assert.
*/
auto ones_size = input.sizes().vec();
ones_size[dim] = 1;
const Tensor ones = at::ones({1}, grad.options()).expand(ones_size);
Tensor prods_from_k_plus_1;
Tensor omitted_products;
for (const auto k : c10::irange(dim_size)) {
if (k == 0) {
// Nothing precedes k, so the leading factor is 1.
prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k + 1), dim);
omitted_products = at::cat({ones, prods_from_k_plus_1}, dim);
} else if (k == dim_size - 1) {
// Nothing follows k, so only the product of the preceding elements remains.
const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
omitted_products = prods_until_k;
} else {
const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k+1), dim);
omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
omitted_products = at::cat({prods_until_k, omitted_products}, dim);
}
// At this point omitted_products is the same size
// as input, except on the dimension dim where it's
// dim_size - k
TORCH_CHECK(omitted_products.size(dim) == dim_size - k);
// Only j >= k contributes, hence grad is sliced from k before the weighted sum.
grad_input.select(dim, k).copy_(
at::sum(grad.slice(dim, k) * omitted_products,dim));
}
}
return grad_input;
}
// NaN test usable with both integral and floating-point element types.
// Under MSVC the integral case is answered directly (an integer is never NaN)
// instead of being forwarded to std::isnan; presumably MSVC's std::isnan does
// not accept integral arguments - TODO confirm.
namespace {
#ifdef _MSC_VER
template <typename T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
inline bool isnan_(T /*x*/) {
  return false;
}
template <typename T, std::enable_if_t<!std::is_integral<T>::value, int> = 0>
inline bool isnan_(T x) {
  return std::isnan(x);
}
#else
template <typename T>
inline bool isnan_(T x) {
  const bool is_nan = std::isnan(x);
  return is_nan;
}
#endif
} // namespace
// Running max/min scan over one 1-D slice.
//
//   self_data     - first element of the input slice
//   values_data   - first element of the output slice receiving running values
//   indices_data  - first element of the output slice receiving, for every
//                   position, the index at which the running value was taken
//   self_dim_size - number of elements in the slice
//   *_stride      - element strides of the three slices
//   Operation     - default-constructible binary predicate; op(curr, best)
//                   returning true makes `curr` the new running value
//
// A NaN element always becomes the running value, and once the running value
// is NaN only another NaN replaces it.
//
// Element offsets are computed in 64 bits: `i * stride` evaluated in `int`
// overflows (undefined behavior) once the offset exceeds INT_MAX, even though
// `i` and the stride individually fit in an int.
template<typename T1, typename T2, typename Operation>
void cummax_cummin_helper(const T1* self_data, T1* values_data, T2* indices_data,
    int self_dim_size, int self_stride, int values_stride, int indices_stride) {
  // Nothing to scan; also avoids reading self_data[0] from an empty slice.
  if (self_dim_size <= 0) {
    return;
  }
  Operation op;
  T1 out = self_data[0];
  int idx = 0;
  for (int i = 0; i < self_dim_size; i++) {
    const int64_t pos = i;
    const T1 curr_elem = self_data[pos * self_stride];
    if (isnan_(curr_elem) || (!isnan_(out) && op(curr_elem, out))) {
      out = curr_elem;
      idx = i;
    }
    values_data[pos * values_stride] = out;
    indices_data[pos * indices_stride] = idx;
  }
}
// CPU implementation behind cummax: dispatches on the scalar type of `self`
// (all types covered by AT_DISPATCH_ALL_TYPES_AND plus Bool) and applies the
// running scan to each slice along `dim` via tensor_dim_apply3.
// std::greater_equal makes an element equal to the running maximum replace
// it, so on ties the reported index is the latest one.
void cummax_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
self.scalar_type(), "cummax_cpu",
[&] {
at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::greater_equal<scalar_t>>);
});
}
// Out variant of cummax.
// `values` must match `self` in scalar type, device and layout; `indices`
// must match a kLong tensor with the options of `self`. Both outputs are
// resized to the shape of `self`. A 0-dim input is copied to `values` with
// index 0; an empty input is left untouched after the resize. Names are
// propagated from `self` to both outputs.
// Returns references to (`values`, `indices`).
std::tuple<Tensor&, Tensor&> cummax_out(const Tensor& self, int64_t dim, Tensor& values, Tensor& indices) {
  check_scalar_type_device_layout_equal(values, self);
  check_scalar_type_device_layout_equal(indices, at::empty({0}, self.options().dtype(at::kLong)));
  {
    NoNamesGuard guard;
    at::native::resize_output(values, self.sizes());
    at::native::resize_output(indices, self.sizes());
    const bool is_zero_dim = (self.dim() == 0);
    if (is_zero_dim) {
      values.fill_(self);
      indices.fill_(0);
    } else if (self.numel() != 0) {
      const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim());
      at::_cummax_helper(self, values, indices, wrapped_dim);
    }
  }
  namedinference::propagate_names(values, self);
  namedinference::propagate_names(indices, self);
  return std::tie(values, indices);
}
// Functional cummax: allocates `values` (options of `self`) and `indices`
// (kLong) with the shape of `self`, fills them through at::cummax_out and
// returns them as a pair.
std::tuple<Tensor, Tensor> cummax(const Tensor& self, int64_t dim) {
  const auto out_sizes = self.sizes();
  auto values = at::empty(out_sizes, self.options());
  auto indices = at::empty(out_sizes, self.options().dtype(at::kLong));
  at::cummax_out(values, indices, self, dim);
  return std::tuple<Tensor, Tensor>(values, indices);
}
// CPU implementation behind cummin: dispatches on the scalar type of `self`
// (all types covered by AT_DISPATCH_ALL_TYPES_AND plus Bool) and applies the
// running scan to each slice along `dim` via tensor_dim_apply3.
// std::less_equal makes an element equal to the running minimum replace it,
// so on ties the reported index is the latest one.
void cummin_helper_cpu(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
self.scalar_type(), "cummin_cpu",
[&] {
at::native::tensor_dim_apply3<scalar_t, int64_t>(self, values, indices, dim, cummax_cummin_helper<scalar_t, int64_t, std::less_equal<scalar_t>>);
});
}
// Out variant of cummin.
// `values` must match `self` in scalar type, device and layout; `indices`
// must match a kLong tensor with the options of `self`. Both outputs are
// resized to the shape of `self`. A 0-dim input is copied to `values` with
// index 0; an empty input is left untouched after the resize. Names are
// propagated from `self` to both outputs.
// Returns references to (`values`, `indices`).
std::tuple<Tensor&, Tensor&> cummin_out(const Tensor& self, int64_t dim, Tensor& values, Tensor& indices) {
  check_scalar_type_device_layout_equal(values, self);
  check_scalar_type_device_layout_equal(indices, at::empty({0}, self.options().dtype(at::kLong)));
  {
    NoNamesGuard guard;
    at::native::resize_output(values, self.sizes());
    at::native::resize_output(indices, self.sizes());
    const bool is_zero_dim = (self.dim() == 0);
    if (is_zero_dim) {
      values.fill_(self);
      indices.fill_(0);
    } else if (self.numel() != 0) {
      const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim());
      at::_cummin_helper(self, values, indices, wrapped_dim);
    }
  }
  namedinference::propagate_names(values, self);
  namedinference::propagate_names(indices, self);
  return std::tie(values, indices);
}
// Functional cummin: allocates `values` (options of `self`) and `indices`
// (kLong) with the shape of `self`, fills them through at::cummin_out and
// returns them as a pair.
std::tuple<Tensor, Tensor> cummin(const Tensor& self, int64_t dim) {
  const auto out_sizes = self.sizes();
  auto values = at::empty(out_sizes, self.options());
  auto indices = at::empty(out_sizes, self.options().dtype(at::kLong));
  at::cummin_out(values, indices, self, dim);
  return std::tuple<Tensor, Tensor>(values, indices);
}
// Shared backward of cummax/cummin: every output position forwards its
// gradient to the input position recorded in `indices`, accumulated with
// scatter_add_ into a zero tensor shaped like `input`.
// An empty `input` is returned unchanged.
Tensor cummaxmin_backward(const Tensor& grad, const Tensor& input, const Tensor& indices, int64_t dim) {
  if (input.numel() == 0) {
    return input;
  }
  Tensor grad_input = at::zeros(input.sizes(), input.options());
  grad_input.scatter_add_(dim, indices, grad);
  return grad_input;
}
// Helper for diff that concatenates the optional `prepend` before and the
// optional `append` after `self` along `dim`. At least one of the two must
// be present.
static Tensor prepend_append_on_dim(const Tensor& self, const c10::optional<Tensor>& prepend, const c10::optional<Tensor>& append, int64_t dim) {
  TORCH_INTERNAL_ASSERT(prepend.has_value() || append.has_value(), "either prepend or append must be have value");
  const bool has_prepend = prepend.has_value();
  const bool has_append = append.has_value();
  if (has_prepend && has_append) {
    return at::cat({prepend.value(), self, append.value()}, dim);
  }
  if (has_prepend) {
    return at::cat({prepend.value(), self}, dim);
  }
  return at::cat({self, append.value()}, dim);
}
// Helper for diff: when `other` (the tensor to prepend or append) is present,
// it must have as many dimensions as `self` and the same size in every
// dimension except the differencing dimension `dim`.
static inline void diff_check_compatible_shape(const Tensor& self, const c10::optional<Tensor>&other, int64_t dim) {
  if (!other.has_value()) {
    return;
  }
  const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim(), false);
  TORCH_CHECK(
      other.value().dim() == self.dim(),
      "diff expects prepend or append to be the same dimension as input");
  const int64_t other_ndim = other.value().dim();
  for (int i = 0; i < other_ndim; ++i) {
    TORCH_CHECK(
        other.value().size(i) == self.size(i) || i == wrapped_dim,
        "diff expects the shape of tensor to prepend or append to match that of"
        " input except along the differencing dimension;"
        " input.size(", i, ") = ", self.size(i), ", but got"
        " tensor.size(", i, ") = ", other.value().size(i));
  }
}
// Validates the arguments of diff: only first-order differences (n == 1) are
// supported, the input must have at least one dimension, and `prepend` /
// `append` (when present) must be shape-compatible with `self` along every
// dimension other than `dim`.
static inline void diff_check(const Tensor& self, int64_t n, int64_t dim, const c10::optional<Tensor>&prepend, const c10::optional<Tensor>& append) {
// Helper for diff that checks whether its parameters are valid
TORCH_CHECK(
n == 1,
"diff only supports n = 1 currently. Please file an issue at"
" https://github.com/pytorch/pytorch/issues/new?assignees=&labels=&template=feature-request.md"
" if your use case requires supporting higher-order differences");
TORCH_CHECK(
self.dim() >= 1,
"diff expects input to be at least one-dimensional");
diff_check_compatible_shape(self, prepend, dim);
diff_check_compatible_shape(self, append, dim);
}
// First-order forward difference along `dim`: element i+1 minus element i.
// Bool tensors use logical_xor in place of subtraction. `n` is not read here.
static inline Tensor diff_helper(const Tensor& self, int64_t n, int64_t dim) {
  const auto out_len = self.size(dim) - 1;
  const Tensor later = at::narrow(self, dim, 1, out_len);
  const Tensor earlier = at::narrow(self, dim, 0, out_len);
  if (self.dtype() == at::kBool) {
    return at::logical_xor(later, earlier);
  }
  return later - earlier;
}
// diff along `dim`, with optional tensors concatenated before (`prepend`)
// and/or after (`append`) the input prior to differencing.
Tensor diff(const Tensor& self, int64_t n, int64_t dim, const c10::optional<Tensor>& prepend, const c10::optional<Tensor>& append) {
  diff_check(self, n, dim, prepend, append);
  if (prepend.has_value() || append.has_value()) {
    const auto extended = prepend_append_on_dim(self, prepend, append, dim);
    return diff_helper(extended, n, dim);
  }
  return diff_helper(self, n, dim);
}
// Out variant of the first-order forward difference along `dim`: writes
// element i+1 minus element i (logical_xor for bool tensors) into `result`
// and returns the reference produced by the out op. `n` is not read here.
static inline Tensor& diff_out_helper(const Tensor& self, int64_t n, int64_t dim, Tensor& result) {
  const auto out_len = self.size(dim) - 1;
  const Tensor later = at::narrow(self, dim, 1, out_len);
  const Tensor earlier = at::narrow(self, dim, 0, out_len);
  if (self.dtype() == at::kBool) {
    return at::logical_xor_out(result, later, earlier);
  }
  return at::sub_out(result, later, earlier);
}
// Out variant of diff: same contract as diff, writing into `result`.
Tensor& diff_out(const Tensor& self, int64_t n, int64_t dim, const c10::optional<Tensor>& prepend, const c10::optional<Tensor>& append, Tensor& result) {
  diff_check(self, n, dim, prepend, append);
  if (prepend.has_value() || append.has_value()) {
    const auto extended = prepend_append_on_dim(self, prepend, append, dim);
    return diff_out_helper(extended, n, dim, result);
  }
  return diff_out_helper(self, n, dim, result);
}
// Helper for the gradient overloads that validates their arguments.
//
//   self         - input tensor; uint8 is rejected
//   spacing_size - number of spacing entries supplied by the caller, or
//                  nullopt when spacing was left unspecified
//   dim          - dimensions to differentiate along, or nullopt for all
//   edge_order   - must be 1 or 2
//
// Raises (via TORCH_CHECK) when any prerequisite is violated.
void pre_check_gradient(const Tensor& self, c10::optional<int64_t> spacing_size, c10::optional<IntArrayRef> dim, int64_t edge_order) {
  TORCH_CHECK(self.scalar_type() != ScalarType::Byte, "torch.gradient does not support uint8 input.");
  if (spacing_size.has_value() && !dim.has_value()) {
    // Without `dim` the gradient is taken along every dimension and the
    // helpers index the spacing once per dimension, so a shorter list would
    // be read out of bounds. The scalar-spacing overloads expand the scalar
    // to a vector of the required length before calling this check.
    TORCH_CHECK(spacing_size.value() == self.dim(), "torch.gradient expected spacing to be unspecified, a scalar or a list of length ", self.dim(), " but got a list of length ", spacing_size.value());
  }
  if (spacing_size.has_value() && dim.has_value()) {
    TORCH_CHECK(spacing_size.value() == static_cast<int64_t>(dim.value().size()), "torch.gradient expected spacing to be unspecified, a scalar or it's spacing and dim arguments to have the same length, but got a spacing argument of length ", spacing_size.value(), " and a dim argument of length ", dim.value().size(), "." );
  }
  TORCH_CHECK(edge_order == 1 || edge_order == 2, "torch.gradient only supports edge_order=1 and edge_order=2.");
  if (dim.has_value()) {
    // The following function get called to check whether dim argument satisfies prerequisites.
    // The output of the function is not used for the computation of gradient.
    dim_list_to_bitset(dim.value(), self.dim());
    // Only the dimensions that are differentiated need enough points for the
    // edge stencil; other dimensions may have any size.
    for (const auto d : dim.value()) {
      TORCH_CHECK(self.size(d) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
    }
  } else {
    for (const auto i : c10::irange(self.dim())) {
      TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
    }
  }
}
std::vector<Tensor> gradient_helper(const Tensor& self, TensorList coordinates, IntArrayRef dim, int64_t edge_order) {
for (const auto i : c10::irange(coordinates.size())) {
TORCH_CHECK(self.device() == coordinates[i].device(), "torch.gradient expected each tensor to be on the same device, but got devices ", self.device(), " and ", coordinates[i].device(), "!");
}
std::vector<Tensor> result;
for (const auto i : c10::irange(dim.size())) {
TORCH_CHECK( coordinates[i].dim() == 1, "torch.gradient expected each element of spacing to have one dimension, but got an element with ", coordinates[i].dim(), " dimensions!");
int64_t direction = maybe_wrap_dim(dim[i], self.dim());
Tensor prepend, append;
std::vector<int64_t> shape(self.dim(),1);
shape[ direction ] = -1;
auto ax_dx = coordinates[i].diff(1,0);
auto dx1 = at::slice(ax_dx, 0, 0, -1);
auto dx2 = at::slice(ax_dx, 0, 1);
auto a = ( -dx2 / (dx1*(dx1+dx2)) ).reshape(shape);
auto b = ( (dx2-dx1) / (dx1*dx2) ).reshape(shape);
auto c = ( dx1 / (dx2*(dx1+dx2)) ).reshape(shape);
auto center = a * at::slice(self, direction, 0, -2) + b * at::slice(self, direction , 1, -1) + c * at::slice(self, direction, 2);
if (edge_order == 1) {
prepend = (at::slice(self, direction, 1, 2 ) - at::slice(self, direction, 0, 1 )) / ax_dx[0] ;
append = (at::slice(self, direction, -1 ) - at::slice(self, direction, -2, -1 )) / ax_dx[-1] ;
} else if (edge_order == 2) {
a =-(2.0 * ax_dx[0] + ax_dx[1]) / (ax_dx[0] * (ax_dx[0] + ax_dx[1])) ;
b = ( ax_dx[0] + ax_dx[1]) / (ax_dx[0] * ax_dx[1]) ;
c = ( -ax_dx[0] ) / (ax_dx[1] * (ax_dx[0] + ax_dx[1]));
prepend = a * at::slice(self, direction, 0, 1) + b * at::slice(self, direction, 1, 2) + c * at::slice(self, direction, 2, 3);
a = ( ax_dx[-1] ) / (ax_dx[-2] * (ax_dx[-1] + ax_dx[-2]));
b =-( ax_dx[-1] + ax_dx[-2]) / (ax_dx[-1] * ax_dx[-2]);
c = (2 * ax_dx[-1] + ax_dx[-2]) / (ax_dx[-1] * (ax_dx[-1] + ax_dx[-2]));
append = a * at::slice(self, direction, -3, -2) + b * at::slice(self, direction, -2, -1) + c * at::slice(self, direction, -1);
}
result.emplace_back(prepend_append_on_dim(center, prepend, append, direction));
}
return result;
}
std::vector<Tensor> gradient_helper_float(const Tensor& self, ArrayRef<Scalar> spacing, IntArrayRef dim, int64_t edge_order) {
std::vector<Tensor> result;
for (const auto i : c10::irange(dim.size())) {
int64_t direction = maybe_wrap_dim(dim[i], self.dim());
auto ax_dx = spacing[i];
Tensor prepend, append;
auto center = (at::slice(self,direction, 2 ) - at::slice(self, direction, 0, -2 ) ) / ax_dx;
if (edge_order==1) {
prepend = (at::slice(self,direction, 1, 2) - at::slice(self, direction, 0, 1 ) ) / ax_dx;
append = (at::slice(self,direction, -1 ) - at::slice(self, direction, -2, -1) ) / ax_dx ;
} else if (edge_order==2) {
prepend = (-1.5 * at::slice(self, direction, 0, 1) + 2 * at::slice(self, direction, 1, 2) - 0.5 * at::slice(self, direction, 2, 3))/ ax_dx;
append = (0.5 * at::slice(self, direction, -3, -2) - 2 * at::slice(self, direction, -2, -1) + 1.5 * at::slice(self, direction, -1)) / ax_dx;
}
result.emplace_back(prepend_append_on_dim(center/2, prepend, append, direction));
}
return result;
}
// Normalizes the optional integer `dim` of the gradient overloads into a list
// of dimensions: a provided value yields a single-element list, an absent
// value yields every dimension of `self` (0 .. self.dim() - 1).
// The list-of-ints form of `dim` is not expected to go through this function.
std::vector<int64_t> gradient_dim_preprocess(const Tensor& self, c10::optional<int64_t> dim) {
  if (!dim.has_value()) {
    std::vector<int64_t> all_axes(self.dim());
    std::iota(all_axes.begin(), all_axes.end(), 0);
    return all_axes;
  }
  return std::vector<int64_t>(1, dim.value());
}
// gradient overload: coordinate tensors + explicit list of dimensions.
std::vector<Tensor> gradient(const Tensor& self, TensorList coordinates, IntArrayRef dim, int64_t edge_order) {
  const auto spacing_len = c10::optional<int64_t>(coordinates.size());
  const auto checked_dim = c10::optional<IntArrayRef>(dim);
  pre_check_gradient(self, spacing_len, checked_dim, edge_order);
  return gradient_helper(self, coordinates, dim, edge_order);
}
std::vector<Tensor> gradient(const Tensor& self, TensorList coordinates, c10::optional<int64_t> dim, int64_t edge_order) {
const auto processed_dim = gradient_dim_preprocess(self, dim);
pre_check_gradient(self,
c10::optional<int64_t>(coordinates.size()),
dim.has_value() ? c10::optional<IntArrayRef>(processed_dim) : c10::nullopt,
edge_order);
return gradient_helper(self, coordinates, processed_dim, edge_order);
}
std::vector<Tensor> gradient(const Tensor& self, c10::ArrayRef<Scalar> spacing, IntArrayRef dim, int64_t edge_order) {
pre_check_gradient(self,
c10::optional<int64_t>(spacing.size()),
c10::optional<IntArrayRef>(dim),
edge_order);
return gradient_helper_float(self, spacing, dim, edge_order);
}
std::vector<Tensor> gradient(const Tensor& self, ArrayRef<Scalar> spacing, c10::optional<int64_t> dim, int64_t edge_order) {
const auto processed_dim = gradient_dim_preprocess(self, dim);
pre_check_gradient(self,
c10::optional<int64_t>(spacing.size()),
dim.has_value() ? c10::optional<IntArrayRef>(processed_dim) : c10::nullopt,
edge_order);
return gradient_helper_float(self, spacing, processed_dim, edge_order);
}
std::vector<Tensor> gradient(const Tensor& self, const Scalar& unit_size, IntArrayRef dim, int64_t edge_order) {
// When spacing is given as scalar, while dim is given as IntArrayRef, scalar value need to
// be taken as unit size at every given dimension element of - dim.
std::vector<Scalar> spacing(dim.size(), unit_size);
pre_check_gradient(self,
c10::optional<int64_t>(spacing.size()),
c10::optional<IntArrayRef>(dim),
edge_order);
return gradient_helper_float(self, spacing, dim, edge_order);
}
// gradient overload: optional scalar spacing + optional single dimension.
// A missing `unit_size` means a spacing of 1. A present `dim` selects that
// one direction; a missing `dim` selects every dimension of `self`.
std::vector<Tensor> gradient(const Tensor& self, const c10::optional<Scalar>& unit_size, c10::optional<int64_t> dim, int64_t edge_order) {
  const auto processed_dim = gradient_dim_preprocess(self, dim);

  // One spacing entry per direction that will be differentiated.
  const int64_t num_directions = dim.has_value() ? 1 : self.dim();
  const Scalar step = unit_size.has_value() ? unit_size.value() : Scalar(1.0);
  const std::vector<Scalar> spacing(num_directions, step);

  // The check must still see which arguments the caller actually supplied.
  c10::optional<int64_t> spacing_len = c10::nullopt;
  if (unit_size.has_value()) {
    spacing_len = c10::optional<int64_t>(spacing.size());
  }
  c10::optional<IntArrayRef> checked_dim = c10::nullopt;
  if (dim.has_value()) {
    checked_dim = c10::optional<IntArrayRef>(processed_dim);
  }
  pre_check_gradient(self, spacing_len, checked_dim, edge_order);
  return gradient_helper_float(self, spacing, processed_dim, edge_order);
}
std::vector<Tensor> gradient(const Tensor& self, IntArrayRef dim, int64_t edge_order) {
std::vector<Scalar> spacing(dim.size(), 1.0) ;
pre_check_gradient(self,
c10::optional<int64_t>(spacing.size()),
c10::optional<IntArrayRef>(dim),
edge_order);
return gradient_helper_float(self, spacing, dim, edge_order);
}
// ALL REDUCE #################################################################
// Picks the dtype for an out-variant reduction: the explicit `dtype` when given,
// otherwise the dtype of `result`. Fails if `result` is undefined.
inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dtype) {
  TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  return dtype.has_value() ? dtype.value() : result.scalar_type();
}
// Picks the dtype for a functional reduction: the explicit `dtype` when given;
// otherwise kLong if `promote_integers` is set and `self` is integral (bool
// included); otherwise the dtype of `self`.
inline ScalarType get_dtype_from_self(const Tensor& self, optional<ScalarType> dtype,
                                      bool promote_integers) {
  if (dtype.has_value()) {
    return dtype.value();
  }
  const ScalarType src_type = self.scalar_type();
  const bool promote =
      promote_integers && at::isIntegralType(src_type, /*includeBool=*/true);
  return promote ? kLong : src_type;
}
// Out-variant of sum over `dim`. When the reduction iterator has no elements
// the result is zero-filled; otherwise the work is done by sum_stub.
Tensor& sum_out(const Tensor& self, IntArrayRef dim,
                bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
  const ScalarType dtype = get_dtype_from_result(result, opt_dtype);
  auto iter = make_reduction("sum", result, self, dim, keepdim, dtype);
  if (iter.numel() != 0) {
    sum_stub(iter.device_type(), iter);
    return result;
  }
  // Nothing to reduce: the sum of no elements is 0.
  result.zero_();
  return result;
}
// Sum over all elements: forwards to the dim-list overload with an empty
// dim list and keepdim=false.
Tensor sum(const Tensor &self, c10::optional<ScalarType> dtype) {
  const std::vector<int64_t> no_dims{};
  return at::native::sum(self, no_dims, /*keepdim=*/false, dtype);
}
// Functional sum over `dim`: resolves the output dtype (integral inputs are
// promoted), allocates the result and delegates to sum_out.
Tensor sum(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> opt_dtype) {
  const ScalarType dtype =
      get_dtype_from_self(self, opt_dtype, /*promote_integers=*/true);
  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
  at::native::sum_out(self, dim, keepdim, dtype, result);
  return result;
}
// Named-dimension sum: translates the names to positions and forwards to at::sum.
Tensor sum(const Tensor& self, DimnameList dim, bool keepdim, c10::optional<ScalarType> dtype) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::sum(self, positions, keepdim, dtype);
}
// Named-dimension out-variant of sum: translates the names to positions and
// forwards to at::sum_out.
Tensor& sum_out(const Tensor& self, DimnameList dim,
                bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::sum_out(result, self, positions, keepdim, opt_dtype);
}
// Out-variant of nansum over `dim`.
// - Complex inputs are rejected.
// - Integral inputs (bool included) cannot hold NaN, so they are forwarded to
//   the regular sum.
// - Otherwise an empty reduction zero-fills `result`, and a non-empty one is
//   computed by nansum_stub.
Tensor& nansum_out(const Tensor& self, IntArrayRef dim,
                   bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
  TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum does not support complex inputs");
  // For integral types, use existing sum as
  // integral types don't have `NaN`.
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
    return at::sum_out(result, self, dim, keepdim, opt_dtype);
  }

  ScalarType dtype = get_dtype_from_result(result, opt_dtype);
  auto iter = make_reduction("nansum", result, self, dim, keepdim, dtype);
  if (iter.numel() == 0) {
    // zero_() already works in place; no need to assign its return value back
    // to `result` (kept consistent with sum_out).
    result.zero_();
  } else {
    nansum_stub(iter.device_type(), iter);
  }
  return result;
}
// nansum over all elements: forwards to the dim-list overload with an empty
// dim list and keepdim=false.
Tensor nansum(const Tensor &self, c10::optional<ScalarType> dtype) {
  const std::vector<int64_t> no_dims{};
  return at::native::nansum(self, no_dims, /*keepdim=*/false, dtype);
}
// Functional nansum over `dim`: resolves the output dtype (integral inputs are
// promoted), allocates the result and delegates to nansum_out.
Tensor nansum(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> opt_dtype) {
  const ScalarType dtype =
      get_dtype_from_self(self, opt_dtype, /*promote_integers=*/true);
  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
  at::native::nansum_out(self, dim, keepdim, dtype, result);
  return result;
}
// Shared implementation of the prod out-variants. When the reduction iterator
// has no elements the result is filled with 1; otherwise prod_stub does the work.
static Tensor& prod_out_impl(Tensor& result, const Tensor& self, IntArrayRef dim,
                             bool keepdim, c10::optional<ScalarType> opt_dtype) {
  const ScalarType dtype = get_dtype_from_result(result, opt_dtype);
  auto iter = make_reduction("prod", result, self, dim, keepdim, dtype);
  if (iter.numel() != 0) {
    prod_stub(iter.device_type(), iter);
    return result;
  }
  // Nothing to reduce: the product of no elements is 1.
  result.fill_(1);
  return result;
}
// CPU implementation of trace: the sum of the main-diagonal elements of a matrix.
// NOTE: this could be implemented via diag and sum, but this has perf problems,
// see https://github.com/pytorch/pytorch/pull/47305,
Tensor trace_cpu(const Tensor& self) {
  // The output keeps the dtype of `self` for non-integral inputs; integral
  // inputs yield at::kLong because promote_integers is true.
  const ScalarType out_dtype = get_dtype_from_self(self, c10::nullopt, true);
  Tensor result = at::empty({}, self.options().dtype(out_dtype));
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] {
    using accscalar_t = at::acc_type<scalar_t, false>;
    const auto* t_data = self.data_ptr<scalar_t>();
    TORCH_CHECK(self.dim() == 2, "trace: expected a matrix, but got tensor with dim ", self.dim());

    // Stepping one row and one column at a time walks along the diagonal.
    const int64_t diag_step = self.stride(0) + self.stride(1);
    const int64_t diag_len = std::min(self.size(0), self.size(1));

    accscalar_t sum = 0;
    for (int64_t i = 0; i < diag_len; ++i) {
      sum += t_data[i * diag_step];
    }

    // Store through the pointer type that matches the output dtype chosen above.
    c10::guts::if_constexpr<std::is_integral<accscalar_t>::value>(
      // all integer types get promoted to kLong
      [&] (auto _) { *result.data_ptr<int64_t>() = _(sum); },  // then-case, invalid for non-integral types
      [&] (auto _) { *result.data_ptr<scalar_t>() = _(sum); }  // else-case, invalid for integral types
    );
  });
  return result;
}
// Functional prod along a single dimension: resolves the output dtype
// (integral inputs are promoted), allocates the result and runs prod_out_impl.
Tensor prod(const Tensor& self, int64_t dim, bool keepdim, c10::optional<ScalarType> opt_dtype) {
  const ScalarType dtype =
      get_dtype_from_self(self, opt_dtype, /*promote_integers=*/true);
  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
  return at::native::prod_out_impl(result, self, dim, keepdim, dtype);
}
// Functional prod over all elements (empty dim list, keepdim=false).
Tensor prod(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
  const ScalarType dtype =
      get_dtype_from_self(self, opt_dtype, /*promote_integers=*/true);
  Tensor result = create_reduction_result(self, {}, false, dtype);
  at::native::prod_out_impl(result, self, {}, false, dtype);
  return result;
}
// Out-variant of prod along a single dimension; a thin adapter that reorders
// the arguments for prod_out_impl.
Tensor& prod_out(const Tensor& self, int64_t dim, bool keepdim, c10::optional<ScalarType> dtype, Tensor& result) {
  Tensor& out = at::native::prod_out_impl(result, self, dim, keepdim, dtype);
  return out;
}
// Named-dimension prod: translates the name to a position and forwards to at::prod.
Tensor prod(const Tensor& self, Dimname dim, bool keepdim, c10::optional<ScalarType> dtype) {
  const int64_t position = dimname_to_position(self, dim);
  return at::prod(self, position, keepdim, dtype);
}
// Named-dimension out-variant of prod: translates the name to a position and
// forwards to at::prod_out.
Tensor& prod_out(const Tensor& self, Dimname dim,
                 bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
  const int64_t position = dimname_to_position(self, dim);
  return at::prod_out(result, self, position, keepdim, opt_dtype);
}
// Out-variant of mean over `dim`.
// The effective dtype (explicit `opt_dtype`, else the dtype of `self`) must be
// floating point or complex.
// - CPU tensors: computed as sum / (number of reduced elements).
// - Other devices: computed by mean_stub; an empty reduction is filled with NaN.
Tensor &mean_out_cpu_gpu(const Tensor &self, IntArrayRef dim,
                         bool keepdim, c10::optional<ScalarType> opt_dtype, Tensor &result) {
  const ScalarType scalarType =
      opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
  const bool dtype_supported =
      at::isFloatingType(scalarType) || at::isComplexType(scalarType);
  TORCH_CHECK(
      dtype_supported,
      "Can only calculate the mean of floating types. Got ",
      toString(scalarType),
      " instead.");
  const ScalarType dtype = get_dtype_from_result(result, opt_dtype);

  // TODO: the TensorIterator reduction implementation of mean
  // (mean_kernel_impl()) is unvectorized and leads to very poor performance
  // for production workloads. Until that is fixed, CPU tensors use the
  // sum + divide implementation in this branch instead of the stub-based
  // code further down.
  if (self.device().is_cpu()) {
    int64_t reduced_count = 1;
    const bool reduce_everything = dim.size() == 0 || self.ndimension() == 0;
    if (reduce_everything) {
      reduced_count = self.numel();
    } else {
      for (const auto d : dim) {
        reduced_count *= self.size(d);
      }
    }
    at::sum_out(result, self, dim, keepdim, dtype).div_(reduced_count);
    return result;
  }

  auto iter = make_reduction("mean", result, self, dim, keepdim, dtype);
  if (iter.numel() != 0) {
    mean_stub(iter.device_type(), iter);
    return result;
  }
  // Mean of no elements is NaN.
  result.fill_(std::numeric_limits<double>::quiet_NaN());
  return result;
}
// Mean over all elements: forwards to the dim-list overload with an empty
// dim list and keepdim=false.
Tensor mean_cpu_gpu(const Tensor &self, optional<ScalarType> dtype) {
  const IntArrayRef no_dims{};
  return at::native::mean_cpu_gpu(self, no_dims, /*keepdim=*/false, dtype);
}
// Functional mean over `dim`: resolves the output dtype via get_dtype_from_self,
// allocates the result and delegates to mean_out_cpu_gpu (which validates the dtype).
Tensor mean_cpu_gpu(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
  const ScalarType dtype =
      get_dtype_from_self(self, opt_dtype, /*promote_integers=*/true);
  Tensor result = create_reduction_result(self, dim, keepdim, dtype);
  at::native::mean_out_cpu_gpu(self, dim, keepdim, dtype, result);
  return result;
}
// Named-dimension mean: translates the names to positions and forwards to at::mean.
Tensor mean(const Tensor& self, DimnameList dim, bool keepdim, optional<ScalarType> dtype) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::mean(self, positions, keepdim, dtype);
}
// Named-dimension out-variant of mean: translates the names to positions and
// forwards to at::mean_out.
Tensor& mean_out(const Tensor& self, DimnameList dim,
                 bool keepdim, c10::optional<ScalarType> opt_dtype, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::mean_out(result, self, positions, keepdim, opt_dtype);
}
// Squeezes every dimension listed in `dims` out of `self`.
// Dimensions are processed from the highest index down so that squeezing one
// does not shift the indices of those still to be processed.
static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
  const int ndims = self.sizes().size();
  const auto squeeze_mask = at::dim_list_to_bitset(dims, ndims);
  Tensor result = self;
  int i = ndims;
  while (i-- > 0) {
    if (squeeze_mask[i]) {
      result = result.squeeze(i);
    }
  }
  return result;
}
// Writes log(sum(exp(self))) over `dims` into `result` and returns `result`.
// For non-empty input the per-slice maximum is subtracted before exponentiating
// and added back after the log. `keepdim` controls whether the reduced
// dimensions are kept in `result`.
static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
// can't take max of empty tensor
if (self.numel() != 0) {
// Maxima are computed with keepdim=true; they are subtracted from `self` below.
auto maxes = at::amax(self, dims, true);
// Version of the maxima shaped for `result` (reduced dims squeezed unless keepdim).
auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
// Infinite maxima are replaced by 0 so that no +/-inf shift is applied.
// NOTE(review): this in-place fill presumably also changes `maxes` (shared
// data with `maxes_squeezed`), which the subtraction below relies on — confirm.
maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
// Undo the shift: log(sum(exp(x - m))) + m.
result.log_().add_(maxes_squeezed);
} else {
// Empty input: no maximum exists, so reduce exp(self) directly.
at::sum_out(result, at::exp(self), dims, keepdim);
result.log_();
}
return result;
}
// Out-variant of logsumexp. The computation runs while a NoNamesGuard is
// alive; the guard is destroyed before names are propagated from `self` to
// `result`.
Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
  {
    NoNamesGuard names_disabled;
    logsumexp_out_impl(result, self, dims, keepdim);
  }  // guard released here
  namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
  return result;
}
// Functional logsumexp: starts from an empty result with the options of `self`
// and delegates to logsumexp_out.
Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
  Tensor result = at::empty({0}, self.options());
  at::native::logsumexp_out(self, dims, keepdim, result);
  return result;
}
// Named-dimension logsumexp: translates the names to positions and forwards
// to at::logsumexp.
Tensor logsumexp(const Tensor& self, DimnameList dims, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dims);
  return at::logsumexp(self, positions, keepdim);
}
// Named-dimension out-variant of logsumexp: translates the names to positions
// and forwards to at::logsumexp_out.
Tensor& logsumexp_out(const Tensor& self, DimnameList dims, bool keepdim, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dims);
  return at::logsumexp_out(result, self, positions, keepdim);
}
// Shared implementation of the p-norm out-variants.
// - `opt_p` defaults to 2.
// - Only CPU/CUDA strided tensors are accepted, and the effective input dtype
//   (explicit `opt_dtype`, else the dtype of `self`) must be floating point or complex.
// - An empty reduction zero-fills `result`; otherwise norm_stub does the work.
static Tensor& norm_out(Tensor &result, const Tensor &self, const optional<Scalar>& opt_p,
                        IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
  const auto p = opt_p.value_or(2.0).to<double>();
  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
              "norm only supports CPU and CUDA device types, but got: ", self.device().type());
  TORCH_CHECK(self.layout() == Layout::Strided,
              "norm only supports strided layout, but got: ", self.layout());

  const ScalarType in_dtype =
      opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
  const bool dtype_supported =
      at::isFloatingType(in_dtype) || at::isComplexType(in_dtype);
  TORCH_CHECK(
      dtype_supported,
      "Can only calculate the norm of floating point and complex dtypes. Got ",
      toString(in_dtype),
      " instead.");

  // Output dtype preference: dtype of a defined `result`, then the explicit
  // dtype, then the value type of `self`.
  ScalarType out_dtype;
  if (result.defined()) {
    out_dtype = result.scalar_type();
  } else if (opt_dtype.has_value()) {
    out_dtype = opt_dtype.value();
  } else {
    out_dtype = toValueType(self.scalar_type());
  }

  // omit in_dtype in the non-complex call, to avoid make_reduction explicitly casting input to out_dtype
  auto iter = isComplexType(self.scalar_type())
      ? make_reduction("norm", result, self, dim, keepdim, in_dtype, out_dtype)
      : make_reduction("norm", result, self, dim, keepdim, out_dtype);

  if (iter.numel() != 0) {
    norm_stub(iter.device_type(), iter, p);
    return result;
  }
  result.zero_();
  return result;
}
// p-norm over all elements of `self`.
// Sparse tensors are routed to at::native_norm; everything else must be a
// CPU/CUDA strided tensor of floating-point or complex dtype and goes through
// the shared norm_out implementation.
static inline Tensor _norm(const Tensor &self, const Scalar& p) {
  // Sparse tensors need a different implementation because their values
  // are accessed with a different API than strided tensors
  if (self.is_sparse()) {
    return at::native_norm(self, p);
  }

  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
              "norm only supports CPU AND CUDA device type, got: ", self.device().type());
  TORCH_CHECK(self.layout() == Layout::Strided,
              "norm only supports strided layout, got: ", self.layout());
  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
              "norm only supports floating-point dtypes");

  // The result uses the value type of the input dtype.
  const ScalarType dtype = toValueType(self.scalar_type());
  Tensor result = create_reduction_result(self, IntArrayRef{}, false, dtype);
  return at::native::norm_out(result, self, p, IntArrayRef{}, false, c10::nullopt);
}
// Out-variant of norm with an explicit dtype: wraps `dtype` in an optional
// and forwards to the shared implementation.
Tensor &norm_out(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, ScalarType dtype, Tensor& result) {
  const optional<ScalarType> requested_dtype(dtype);
  return at::native::norm_out(result, self, p, dim, keepdim, requested_dtype);
}
// Out-variant of norm without an explicit dtype: forwards to the shared
// implementation with no dtype override.
Tensor &norm_out(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, Tensor& result) {
  const optional<ScalarType> no_dtype = c10::nullopt;
  return at::native::norm_out(result, self, p, dim, keepdim, no_dtype);
}
// Shared implementation of the functional norm overloads.
// Sparse tensors are routed to at::native_norm; otherwise a result is allocated
// with the explicit dtype (or the value type of `self`) and filled by norm_out.
static Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim,
                   optional<ScalarType> opt_dtype) {
  // Sparse tensors need a different implementation because their values
  // are accessed with a different API than strided tensors
  if (self.is_sparse()) {
    return at::native_norm(self, p, dim, keepdim, opt_dtype);
  }
  const ScalarType out_dtype =
      value_or_else(opt_dtype, [&] { return toValueType(self.scalar_type()); });
  Tensor result = create_reduction_result(self, dim, keepdim, out_dtype);
  return at::native::norm_out(result, self, p, dim, keepdim, opt_dtype);
}
// Functional norm over `dim` with an explicit dtype.
Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim, ScalarType dtype) {
  const optional<ScalarType> requested_dtype(dtype);
  return at::native::norm(self, p, dim, keepdim, requested_dtype);
}
// Functional norm over all elements (empty dim list, keepdim=false) with an
// explicit dtype.
Tensor norm(const Tensor& self, const optional<Scalar>& p, ScalarType dtype) {
  const optional<ScalarType> requested_dtype(dtype);
  return at::native::norm(self, p, IntArrayRef{}, /*keepdim=*/false, requested_dtype);
}
// Functional norm over `dim` without an explicit dtype.
Tensor norm(const Tensor& self, const optional<Scalar>& p, IntArrayRef dim, bool keepdim) {
  const optional<ScalarType> no_dtype = c10::nullopt;
  return at::native::norm(self, p, dim, keepdim, no_dtype);
}
// Functional norm over all elements with a mandatory `p`.
// leave it so we support sparse tensors
Tensor norm(const Tensor& self, const Scalar& p) {
  Tensor result = at::native::_norm(self, p);
  return result;
}
// Note [all, any : uint8 compatibility]:
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// For NumPy compatibility, `all` and `any` return
// a Tensor of dtype `bool`. However, for compatibility reasons,
// for `uint8` inputs they return a Tensor of the same dtype, `uint8`.
// Reference: https://github.com/pytorch/pytorch/pull/47878#issuecomment-747108561
// Runs the `all` reduction described by `iter` into `result`.
// A reduction over no elements yields 1 (true); otherwise and_stub does the work.
inline const Tensor & _all(const Tensor & result, TensorIterator & iter) {
  if (iter.numel() != 0) {
    and_stub(iter.device_type(), iter);
    return result;
  }
  result.fill_(1);
  return result;
}
// Builds the reduction iterator shared by `all` and `any`.
inline TensorIterator get_allany_iter(
    const Tensor& self,
    const Tensor& result,
    IntArrayRef dims,
    bool keepdim) {
  if (!self.is_cuda()) {
    // Non-CUDA path: this overload casts the input to the result type
    // (an extra operation).
    return meta::make_reduction_from_out_ty(
        self, result, dims, keepdim, result.scalar_type());
  }
  // As CUDA supports dynamic type casting, we use this overload of
  // `make_reduction`, which doesn't cast input to the result type i.e. kBool.
  return meta::make_reduction(self, result, dims, keepdim, self.scalar_type());
}
// `all` over every element of `self`: validates the input, allocates a result
// of the appropriate dtype (see [all, any : uint8 compatibility]) and reduces.
Tensor all(const Tensor& self) {
  // `result` is still undefined here, so the dtype is derived from `self`.
  Tensor result;
  const auto out_dtype =
      meta::check_allany_and_get_output_dtype("all", self, result, false);
  const auto shape = meta::get_reduction_shape(self, {}, false);
  result = at::empty(shape, self.options().dtype(out_dtype));

  auto iter = get_allany_iter(self, result, {}, false);
  _all(result, iter);
  return result;
}
// Structured-kernel implementation of `all` along one dimension.
// Trivial inputs are handled by _dimreduce_return_trivial with identity 1;
// everything else runs the reduction iterator.
TORCH_IMPL_FUNC(all_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
  dim = maybe_wrap_dim(dim, self.dim());
  auto iter = get_allany_iter(self, result, dim, keepdim);
  // The helper below takes a mutable Tensor&, hence the cast.
  auto mut_result = const_cast<Tensor&>(result);
  const bool handled_trivially =
      _dimreduce_return_trivial(mut_result, self, 1, dim, keepdim);
  if (handled_trivially) {
    return;
  }
  _all(mut_result, iter);
}
// Runs the `any` reduction described by `iter` into `result`.
// A reduction over no elements yields 0 (false); otherwise or_stub does the work.
inline const Tensor & _any(const Tensor & result, TensorIterator & iter) {
  if (iter.numel() != 0) {
    or_stub(iter.device_type(), iter);
    return result;
  }
  result.fill_(0);
  return result;
}
// `any` over every element of `self`: validates the input, allocates a result
// of the appropriate dtype (see [all, any : uint8 compatibility]) and reduces.
Tensor any(const Tensor& self) {
  // `result` is still undefined here, so the dtype is derived from `self`.
  Tensor result;
  const auto out_dtype =
      meta::check_allany_and_get_output_dtype("any", self, result, false);
  const auto shape = meta::get_reduction_shape(self, {}, false);
  result = at::empty(shape, self.options().dtype(out_dtype));

  auto iter = get_allany_iter(self, result, {}, false);
  _any(result, iter);
  return result;
}
// Structured-kernel implementation of `any` along one dimension.
// Trivial inputs are handled by _dimreduce_return_trivial with identity 0;
// everything else runs the reduction iterator.
TORCH_IMPL_FUNC(any_out)
(const Tensor& self, int64_t dim, bool keepdim, const Tensor& result) {
  dim = maybe_wrap_dim(dim, self.dim());
  auto iter = get_allany_iter(self, result, dim, keepdim);
  // The helper below takes a mutable Tensor&, hence the cast.
  auto mut_result = const_cast<Tensor&>(result);
  const bool handled_trivially =
      _dimreduce_return_trivial(mut_result, self, 0, dim, keepdim);
  if (handled_trivially) {
    return;
  }
  _any(mut_result, iter);
}
// Out-variant of amin (minimum values over `dim`).
// Input and out dtypes must match. For empty inputs the reduced dimensions are
// validated by zero_numel_check_dims; the kernel runs only when the iterator
// has elements.
Tensor &amin_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
  const auto in_dtype = self.scalar_type();
  const auto out_dtype = result.scalar_type();
  TORCH_CHECK(in_dtype == out_dtype, "Expected the dtype for input and out to match, but got ",
              in_dtype, " for input's dtype and ", out_dtype, " for out's dtype.");
  if (self.numel() == 0) {
    zero_numel_check_dims(self, dim, "amin()");
  }

  auto iter = make_reduction("amin", result, self, dim, keepdim, self.scalar_type());
  if (iter.numel() == 0) {
    return result;
  }
  min_values_stub(iter.device_type(), iter);
  return result;
}
// Functional amin: starts from an empty result with the options of `self`
// and delegates to at::amin_out.
Tensor amin(const Tensor& self, IntArrayRef dim, bool keepdim) {
  Tensor result = at::empty({0}, self.options());
  at::amin_out(result, self, dim, keepdim);
  return result;
}
// Out-variant of amax (maximum values over `dim`).
// Input and out dtypes must match. For empty inputs the reduced dimensions are
// validated by zero_numel_check_dims; the kernel runs only when the iterator
// has elements.
Tensor &amax_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
  const auto in_dtype = self.scalar_type();
  const auto out_dtype = result.scalar_type();
  TORCH_CHECK(in_dtype == out_dtype, "Expected the dtype for input and out to match, but got ",
              in_dtype, " for input's dtype and ", out_dtype, " for out's dtype.");
  if (self.numel() == 0) {
    zero_numel_check_dims(self, dim, "amax()");
  }

  auto iter = make_reduction("amax", result, self, dim, keepdim, self.scalar_type());
  if (iter.numel() == 0) {
    return result;
  }
  max_values_stub(iter.device_type(), iter);
  return result;
}
// Functional amax: starts from an empty result with the options of `self`
// and delegates to at::amax_out.
Tensor amax(const Tensor& self, IntArrayRef dim, bool keepdim) {
  Tensor result = at::empty({0}, self.options());
  at::amax_out(result, self, dim, keepdim);
  return result;
}
// Structured-kernel implementation of argmax.
// - With `dim`: reduces along that dimension; a size-1 dimension short-circuits
//   to an all-zero result.
// - Without `dim`: reduces over the flattened input, ignoring `keepdim`.
TORCH_IMPL_FUNC(argmax_out)
(const Tensor& self,
 c10::optional<int64_t> dim,
 bool keepdim,
 const Tensor& result) {
  c10::MaybeOwned<Tensor> in;
  DimVector dims;

  if (!dim.has_value()) {
    // Flatten so that the reduction covers every element.
    in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
    keepdim = false;
  } else {
    const int64_t wrapped_dim = maybe_wrap_dim(dim.value(), self.dim());
    if (self.sizes()[wrapped_dim] == 1) {
      // Only one candidate along this dimension: its index is 0.
      result.fill_(0);
      return;
    }
    dims = IntArrayRef(wrapped_dim);
    in = c10::MaybeOwned<Tensor>::borrowed(self);
  }

  auto iter =
      meta::make_reduction(*in, result, dims, keepdim, self.scalar_type());
  if (iter.numel() == 0) {
    return;
  }
  argmax_stub(iter.device_type(), iter);
}
// Out-variant of argmin: writes the index of the minimum into `result`.
// - With `dim`: reduces along that dimension (validated for empty inputs by
//   zero_numel_check_dims). A size-1 dimension short-circuits to all zeros.
// - Without `dim`: the input must be non-empty; it is flattened and `keepdim`
//   is ignored.
// Returns `result`.
Tensor& argmin_out(const Tensor& self, c10::optional<int64_t> dim, bool keepdim, Tensor& result) {
  c10::MaybeOwned<Tensor> in;
  if (dim) {
    const auto sizes = self.sizes();
    zero_numel_check_dims(self, dim.value(), "argmin()");
    const auto wrap_dim = maybe_wrap_dim(dim.value(), self.dim());
    if (sizes[wrap_dim] == 1) {
      // Only one candidate along this dimension, so every index is 0.
      // Resize and fill the caller's out tensor in place (as argmax_out does)
      // instead of rebinding `result` to a freshly allocated tensor, which
      // would leave the buffer the caller passed in untouched.
      auto out_sizes = sizes.vec();
      if (!keepdim) {
        out_sizes.erase(out_sizes.begin() + wrap_dim);
      }
      at::native::resize_output(result, out_sizes);
      result.zero_();
      return result;
    }
    in = c10::MaybeOwned<Tensor>::borrowed(self);
  } else {
    TORCH_CHECK_INDEX(self.numel() != 0, "argmin_out(): Expected reduction dim to be specified for input.numel() == 0.");
    // Flatten so that the reduction covers every element.
    in = c10::MaybeOwned<Tensor>::owned(self.reshape({-1}));
    keepdim = false;
  }
  auto itr = make_reduction("argmin", result, *in, dim.value_or(0), keepdim,
                            self.scalar_type(), at::kLong);
  if (itr.numel() != 0) {
    argmin_stub(itr.device_type(), itr);
  }
  return result;
}
// Functional argmin: starts from an empty kLong result with the options of
// `self` and delegates to argmin_out.
Tensor argmin(const Tensor& self, c10::optional<int64_t> dim, bool keepdims) {
  const auto index_options = self.options().dtype(at::kLong);
  Tensor result = at::empty({0}, index_options);
  return at::native::argmin_out(self, dim, keepdims, result);
}
// Computes the variance (or, when `take_sqrt` is set, the standard deviation)
// over all elements of a float or double CPU tensor and returns it as a double.
// `correction` is subtracted from the element count in the divisor.
// Any dtype other than kFloat/kDouble is rejected.
static double std_var_all_cpu(const Tensor& self, int64_t correction, bool take_sqrt) {
const auto dtype = self.scalar_type();
TORCH_CHECK(dtype == kDouble || dtype == kFloat,
"std_var_all: Unsupported dtype ", dtype);
// The mean is computed up front and the squared deviations are summed afterwards.
auto mean = self.mean().item<double>();
auto iter = TensorIteratorConfig()
.add_input(self)
.build();
// Per-range work item for parallel_reduce: adds the squared deviations of the
// elements in [begin, end) to `thread_sum` and returns the updated total.
auto reduction = [&](int64_t begin, int64_t end, double thread_sum) {
AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "std_var_all_cpu", [&] {
iter.serial_for_each([&] (char** data, const int64_t* strides, int64_t size0, int64_t size1) {
const double local_mean = mean;
// strides[0] steps the inner loop (size0 items), strides[1] the outer loop (size1 rows).
const int64_t inner_stride = strides[0];
const int64_t outer_stride = strides[1];
double local_sum = 0.0;
for (int64_t i = 0; i < size1; ++i) {
const char* row_ptr = data[0] + outer_stride * i;
for (int64_t j = 0; j < size0; ++j) {
const auto ptr = reinterpret_cast<const scalar_t*>(row_ptr + inner_stride * j);
// Accumulate in double regardless of the element type.
auto dx = (static_cast<double>(*ptr) - local_mean);
local_sum += dx * dx;
}
}
thread_sum += local_sum;
}, {begin, end});
});
return thread_sum;
};
// ((x - mean)**2).sum()
const double sum_dx2 = at::parallel_reduce(
0, iter.numel(), at::internal::GRAIN_SIZE, 0.0, reduction, std::plus<>{});
// The divisor is clamped at 0, so the division may be by zero; the lambda
// carries the annotation that silences the corresponding UBSan check.
const auto var = [&] () __ubsan_ignore_float_divide_by_zero__ {
return sum_dx2 / std::max(int64_t{0}, self.numel() - correction);
}();
const auto result = take_sqrt ? std::sqrt(var) : var;
if (dtype == kFloat) {
// Convert to infinity if out of range for a float.
// Doing it now prevents checked_convert failing later
return static_cast<float>(result);
}
return result;
}
// Shared implementation of std/var (out-variant).
// - Accepts CPU/CUDA strided tensors of floating-point or complex dtype.
// - Complex input: the variances of the real and imaginary parts are computed
//   separately (recursively) and added; the square root is applied at the end
//   when `take_sqrt` is set.
// - Real input: an empty reduction yields NaN; a full reduction to one element
//   on CPU (non-BFloat16/Half) uses std_var_all_cpu; everything else goes
//   through std_var_stub.
// `correction_opt` defaults to 1. Returns `result`.
static Tensor& std_var_out(
    const char* fname, Tensor& result, const Tensor& self,
    c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction_opt,
    bool keepdim, bool take_sqrt) {
  TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda(),
              "std and var only supports tensors on a CPU or CUDA device, but got: ",
              self.device().type());
  TORCH_CHECK(self.layout() == Layout::Strided,
              "std and var only supports strided layout, got: ", self.layout());
  TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
              "std and var only support floating point and complex dtypes");

  if (at::isComplexType(self.scalar_type())) {
    // For complex, calculate variance of real and imaginary components
    // separately then add to get overall variance.
    const ScalarType value_dtype =
        c10::toValueType(get_dtype_from_result(result, {}));
    // Variance (never the square root) of one real-valued component.
    const auto component_variance = [&](const Tensor& component) {
      Tensor component_out = at::empty({0}, self.options().dtype(value_dtype));
      std_var_out(
          fname,
          component_out,
          component,
          dim,
          correction_opt,
          keepdim,
          /*take_sqrt=*/false);
      return component_out;
    };
    const Tensor real_out = component_variance(at::real(self));
    const Tensor imag_out = component_variance(at::imag(self));

    at::add_out(result, real_out, imag_out);
    if (take_sqrt) {
      at::sqrt_out(result, result);
    }
    return result;
  }

  // Computation for floating point
  const auto correction = correction_opt.value_or(1);
  const ScalarType dtype = get_dtype_from_result(result, {});
  auto iter = make_reduction(fname, result, self, dim, keepdim, dtype);

  if (iter.numel() == 0) {
    // Trivial reduction
    result.fill_(std::numeric_limits<double>::quiet_NaN());
    return result;
  }

  const bool use_cpu_all_reduce =
      result.numel() == 1 && iter.device_type() == kCPU &&
      iter.common_dtype() != kBFloat16 && iter.common_dtype() != kHalf;
  if (use_cpu_all_reduce) {
    // NOTE: CPU performance significantly regressed when attempting to port to
    // ATen, so all-reduce has a custom implementation.
    // See https://github.com/pytorch/pytorch/pull/43858.
    result.fill_(std_var_all_cpu(self, correction, take_sqrt));
  } else {
    std_var_stub(iter.device_type(), iter, correction, take_sqrt);
  }
  return result;
}
static std::tuple<Tensor&, Tensor&> std_var_mean_out(
const char* fname, Tensor& result1, Tensor& result2, const Tensor& self,
c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction_opt,
bool keepdim, bool take_sqrt) {
AT_ASSERT(result1.defined() && result2.defined());
TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
fname, " only supports tensors on a CPU or CUDA device, got: ",
self.device().type());
TORCH_CHECK(self.layout() == Layout::Strided,
fname, " only supports strided layout, got: ", self.layout());
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
fname, " only support floating point and complex dtypes");
TORCH_CHECK(result1.scalar_type() == c10::toValueType(result2.scalar_type()),
fname, " expected result1 to be real and match the precision of result2. Got ",
result1.scalar_type(), " and ", result2.scalar_type(), ".");
if (at::isComplexType(self.scalar_type())) {
// For complex, calculate for real and imaginary components seperately then combine as:
// variance = var_real + var_imag
// mean = mean_real + j * mean_imag
ScalarType dtype = c10::toValueType(get_dtype_from_result(result1, {}));
Tensor real_in = at::real(self);
Tensor real_out_var = at::empty({0}, self.options().dtype(dtype));
Tensor real_out_mean = at::empty({0}, self.options().dtype(dtype));
std_var_mean_out(
fname,
real_out_var,
real_out_mean,
real_in,
dim,
correction_opt,
keepdim,
/*take_sqrt=*/false);
Tensor imag_in = at::imag(self);
Tensor imag_out_var = at::empty({0}, self.options().dtype(dtype));
Tensor imag_out_mean = at::empty({0}, self.options().dtype(dtype));
std_var_mean_out(
fname,
imag_out_var,
imag_out_mean,
imag_in,
dim,
correction_opt,
keepdim,
/*take_sqrt=*/false);
at::add_out(result1, real_out_var, imag_out_var);
if (take_sqrt) {
at::sqrt_out(result1, result1);
}
at::complex_out(result2, real_out_mean, imag_out_mean);
return std::tuple<Tensor&, Tensor&>(result1, result2);
}
// Computation for floating point
const auto correction = correction_opt.value_or(1);
ScalarType dtype = get_dtype_from_result(result1, {});
auto iter =
make_reduction(fname, result1, result2, self, dim, keepdim, dtype);
if (iter.numel() == 0) {
// Trivial reduction
result1.fill_(std::numeric_limits<double>::quiet_NaN());
result2.fill_(std::numeric_limits<double>::quiet_NaN());
} else {
std_var_stub(iter.device_type(), iter, correction, take_sqrt);
}
return std::tuple<Tensor&, Tensor&>(result1, result2);
}
std::tuple<Tensor, Tensor> var_mean(
const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
return at::var_mean(self, /*dim=*/c10::optional<IntArrayRef>(dim),
/*correction=*/int64_t{unbiased ? 1 : 0}, keepdim);
}
std::tuple<Tensor, Tensor> std_mean(
const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
return at::std_mean(self, /*dim=*/c10::optional<IntArrayRef>(dim),
/*correction=*/int64_t{unbiased ? 1 : 0}, keepdim);
}
std::tuple<Tensor, Tensor> std_mean(const Tensor& self, bool unbiased) {
return at::std_mean(
self, /*dim=*/c10::nullopt, /*correction=*/int64_t{unbiased ? 1 : 0});
}
std::tuple<Tensor, Tensor> var_mean(const Tensor& self, bool unbiased) {
return at::var_mean(
self, /*dim=*/c10::nullopt, /*correction=*/int64_t{unbiased ? 1 : 0});
}
// Out-variant of var_mean: `result1` receives the variance, `result2` the mean.
std::tuple<Tensor&, Tensor&> var_mean_out(
    Tensor& result1, Tensor& result2, const Tensor& self, IntArrayRef dim,
    int64_t correction, bool keepdim) {
  const c10::optional<IntArrayRef> opt_dim(dim);
  const c10::optional<int64_t> opt_correction(correction);
  return std_var_mean_out(
      "var_mean", result1, result2, self, opt_dim, opt_correction, keepdim,
      /*take_sqrt=*/false);
}
// Returns `opts` with its dtype replaced by the corresponding value type
// (as given by c10::toValueType).
static TensorOptions options_to_value_type(TensorOptions opts) {
  const auto value_type = c10::toValueType(typeMetaToScalarType(opts.dtype()));
  return opts.dtype(value_type);
}
std::tuple<Tensor, Tensor> var_mean(
const Tensor& self, c10::optional<IntArrayRef> dim,
c10::optional<int64_t> correction, bool keepdim) {
Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
Tensor result2 = at::empty({0}, self.options());
return std_var_mean_out(
"var_mean", result1, result2, self, dim, correction, keepdim, false);
}
std::tuple<Tensor, Tensor> std_mean(
const Tensor& self, c10::optional<IntArrayRef> dim,
c10::optional<int64_t> correction, bool keepdim) {
Tensor result1 = at::empty({0}, options_to_value_type(self.options()));
Tensor result2 = at::empty({0}, self.options());
return std_var_mean_out(
"std_mean", result1, result2, self, dim, correction, keepdim, true);
}
// var over all elements with the legacy `unbiased` flag, mapped to a
// correction of 1 (unbiased) or 0 (biased).
Tensor var(const Tensor& self, bool unbiased) {
  const int64_t correction = unbiased ? 1 : 0;
  return at::var(self, /*dim=*/c10::nullopt, correction);
}
// var over `dim` with the legacy `unbiased` flag, mapped to a correction of
// 1 (unbiased) or 0 (biased).
Tensor var(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
  const int64_t correction = unbiased ? 1 : 0;
  return at::var(
      self, /*dim=*/c10::optional<IntArrayRef>(dim), correction, keepdim);
}
// Out-variant of var over `dim` with the legacy `unbiased` flag, mapped to a
// correction of 1 (unbiased) or 0 (biased).
Tensor& var_out(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
  const int64_t correction = unbiased ? 1 : 0;
  return at::var_out(
      result, self, /*dim=*/c10::optional<IntArrayRef>(dim), correction, keepdim);
}
// std over all elements with the legacy `unbiased` flag, mapped to a
// correction of 1 (unbiased) or 0 (biased).
Tensor std(const Tensor& self, bool unbiased) {
  const int64_t correction = unbiased ? 1 : 0;
  return at::std(self, /*dim=*/c10::nullopt, correction);
}
// std over `dim` with the legacy `unbiased` flag, mapped to a correction of
// 1 (unbiased) or 0 (biased).
Tensor std(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim) {
  const int64_t correction = unbiased ? 1 : 0;
  return at::std(
      self, /*dim=*/c10::optional<IntArrayRef>(dim), correction, keepdim);
}
// Out-variant of std over `dim` with the legacy `unbiased` flag, mapped to a
// correction of 1 (unbiased) or 0 (biased).
Tensor& std_out(const Tensor& self, IntArrayRef dim, bool unbiased, bool keepdim, Tensor& result) {
  const int64_t correction = unbiased ? 1 : 0;
  return at::std_out(
      result, self, /*dim=*/c10::optional<IntArrayRef>(dim), correction, keepdim);
}
// Functional std with an optional dim list and optional correction. The result
// uses the value type of `self`'s dtype.
Tensor std(const Tensor& self, c10::optional<IntArrayRef> dim,
           c10::optional<int64_t> correction, bool keepdim) {
  Tensor result = at::empty({0}, options_to_value_type(self.options()));
  std_var_out("std", result, self, dim, correction, keepdim, /*take_sqrt=*/true);
  return result;
}
// Out-variant of std with an optional dim list and optional correction.
Tensor& std_out(
    const Tensor& self, c10::optional<IntArrayRef> dim,
    c10::optional<int64_t> correction, bool keepdim, Tensor& result) {
  Tensor& out =
      std_var_out("std", result, self, dim, correction, keepdim, /*take_sqrt=*/true);
  return out;
}
// Out-variant of var with an optional dim list and optional correction.
Tensor& var_out(
    const Tensor& self, c10::optional<IntArrayRef> dim,
    c10::optional<int64_t> correction, bool keepdim, Tensor& result) {
  Tensor& out =
      std_var_out("var", result, self, dim, correction, keepdim, /*take_sqrt=*/false);
  return out;
}
// Functional var with an optional dim list and optional correction. The result
// uses the value type of `self`'s dtype.
Tensor var(
    const Tensor& self, c10::optional<IntArrayRef> dim,
    c10::optional<int64_t> correction, bool keepdim) {
  Tensor result = at::empty({0}, options_to_value_type(self.options()));
  std_var_out("var", result, self, dim, correction, keepdim, /*take_sqrt=*/false);
  return result;
}
// Named-dimension std (legacy `unbiased` flag): translates the names to
// positions and forwards to at::std.
Tensor std(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::std(self, positions, unbiased, keepdim);
}
// Named-dimension out-variant of std (legacy `unbiased` flag): translates the
// names to positions and forwards to at::std_out.
Tensor& std_out(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::std_out(result, self, positions, unbiased, keepdim);
}
// Named-dimension var (legacy `unbiased` flag): translates the names to
// positions and forwards to at::var.
Tensor var(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::var(self, positions, unbiased, keepdim);
}
// Named-dimension out-variant of var (legacy `unbiased` flag): translates the
// names to positions and forwards to at::var_out.
Tensor& var_out(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::var_out(result, self, positions, unbiased, keepdim);
}
// Named-dimension var_mean (legacy `unbiased` flag): translates the names to
// positions and forwards to at::var_mean.
std::tuple<Tensor,Tensor> var_mean(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::var_mean(self, positions, unbiased, keepdim);
}
// Named-dimension std_mean (legacy `unbiased` flag): translates the names to
// positions and forwards to at::std_mean.
std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim, bool unbiased, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::std_mean(self, positions, unbiased, keepdim);
}
// Named-dimension std with an optional correction: translates the names to
// positions and forwards to at::std.
Tensor std(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::std(self, positions, correction, keepdim);
}
// Named-dimension out-variant of std with an optional correction: translates
// the names to positions and forwards to at::std_out.
Tensor& std_out(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction,
                bool keepdim, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::std_out(result, self, positions, correction, keepdim);
}
// Named-dimension var with an optional correction: translates the names to
// positions and forwards to at::var.
Tensor var(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::var(self, positions, correction, keepdim);
}
// Named-dimension out-variant of var with an optional correction: translates
// the names to positions and forwards to at::var_out.
Tensor& var_out(const Tensor& self, DimnameList dim, c10::optional<int64_t> correction,
                bool keepdim, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::var_out(result, self, positions, correction, keepdim);
}
// Dimname overload of var_mean (optional correction): resolves `dim` through
// dimnames_to_positions and returns whatever at::var_mean produces.
std::tuple<Tensor,Tensor> var_mean(const Tensor& self, DimnameList dim,
                                   c10::optional<int64_t> correction, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::var_mean(self, positions, correction, keepdim);
}
// Dimname overload of std_mean (optional correction): resolves `dim` through
// dimnames_to_positions and returns whatever at::std_mean produces.
std::tuple<Tensor,Tensor> std_mean(const Tensor& self, DimnameList dim,
                                   c10::optional<int64_t> correction, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::std_mean(self, positions, correction, keepdim);
}
// Dimname overload of norm_out with an explicit dtype: resolves `dim` through
// dimnames_to_positions and forwards to at::norm_out, writing into `result`.
Tensor& norm_out(const Tensor& self, const optional<Scalar>& p, DimnameList dim, bool keepdim, ScalarType dtype, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::norm_out(result, self, p, positions, keepdim, dtype);
}
// Dimname overload of norm_out without a dtype argument: resolves `dim`
// through dimnames_to_positions and forwards to at::norm_out, writing into
// `result`.
Tensor& norm_out(const Tensor& self, const optional<Scalar>& p, DimnameList dim, bool keepdim, Tensor& result) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::norm_out(result, self, p, positions, keepdim);
}
// Dimname overload of norm with an explicit dtype: resolves `dim` through
// dimnames_to_positions and forwards the result to at::norm.
Tensor norm(const Tensor& self, const optional<Scalar>& p, DimnameList dim, bool keepdim, ScalarType dtype) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::norm(self, p, positions, keepdim, dtype);
}
// Dimname overload of norm without a dtype argument: resolves `dim` through
// dimnames_to_positions and forwards the result to at::norm.
Tensor norm(const Tensor& self, const optional<Scalar>& p, DimnameList dim, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dim);
  return at::norm(self, p, positions, keepdim);
}
// Dimname overload of any: not implemented. None of the arguments are used;
// every call is handed to reportNYIDimnameOverload("any"). There is no return
// statement, so that helper presumably never returns normally
// (NOTE(review): confirm it is declared noreturn / always throws).
Tensor any(const Tensor& self, Dimname dim, bool keepdim) {
reportNYIDimnameOverload("any");
}
// Dimname overload of any_out: not implemented. None of the arguments
// (including `result`) are used; every call is handed to
// reportNYIDimnameOverload("any"). There is no return statement, so that
// helper presumably never returns normally (NOTE(review): confirm it is
// declared noreturn / always throws).
Tensor& any_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
reportNYIDimnameOverload("any");
}
// Dimname overload of all: not implemented. None of the arguments are used;
// every call is handed to reportNYIDimnameOverload("all"). There is no return
// statement, so that helper presumably never returns normally
// (NOTE(review): confirm it is declared noreturn / always throws).
Tensor all(const Tensor& self, Dimname dim, bool keepdim) {
reportNYIDimnameOverload("all");
}
// Dimname overload of all_out: not implemented. None of the arguments
// (including `result`) are used; every call is handed to
// reportNYIDimnameOverload("all"). There is no return statement, so that
// helper presumably never returns normally (NOTE(review): confirm it is
// declared noreturn / always throws).
Tensor& all_out(const Tensor &self, Dimname dim, bool keepdim, Tensor& result) {
reportNYIDimnameOverload("all");
}
// Dimname overload of logcumsumexp: resolves `dim` through
// dimname_to_position and forwards the result to at::logcumsumexp.
Tensor logcumsumexp(const Tensor& self, Dimname dim) {
  const auto axis = dimname_to_position(self, dim);
  return at::logcumsumexp(self, axis);
}
// Dimname overload of logcumsumexp_out: resolves `dim` through
// dimname_to_position and forwards to at::logcumsumexp_out, writing into
// `result`.
Tensor& logcumsumexp_out(const Tensor& self, Dimname dim, Tensor& result) {
  const auto axis = dimname_to_position(self, dim);
  return at::logcumsumexp_out(result, self, axis);
}
// Dimname overload of cumsum: resolves `dim` through dimname_to_position and
// forwards the result, together with the optional dtype, to at::cumsum.
Tensor cumsum(const Tensor& self, Dimname dim, c10::optional<ScalarType> dtype) {
  const auto axis = dimname_to_position(self, dim);
  return at::cumsum(self, axis, dtype);
}
// Dimname overload of the in-place cumsum_: resolves `dim` through
// dimname_to_position and forwards to native::cumsum_ on the same `self`.
Tensor& cumsum_(Tensor& self, Dimname dim, c10::optional<ScalarType> dtype) {
  const auto axis = dimname_to_position(self, dim);
  return native::cumsum_(self, axis, dtype);
}
// Dimname overload of cumsum_out: resolves `dim` through dimname_to_position
// and forwards to at::cumsum_out, writing into `result`.
Tensor& cumsum_out(const Tensor& self, Dimname dim, c10::optional<ScalarType> dtype, Tensor& result) {
  const auto axis = dimname_to_position(self, dim);
  return at::cumsum_out(result, self, axis, dtype);
}
// Dimname overload of cumprod: resolves `dim` through dimname_to_position and
// forwards the result, together with the optional dtype, to at::cumprod.
Tensor cumprod(const Tensor& self, Dimname dim, c10::optional<ScalarType> dtype) {
  const auto axis = dimname_to_position(self, dim);
  return at::cumprod(self, axis, dtype);
}
// Dimname overload of the in-place cumprod_: resolves `dim` through
// dimname_to_position and forwards to native::cumprod_ on the same `self`.
Tensor& cumprod_(Tensor& self, Dimname dim, c10::optional<ScalarType> dtype) {
  const auto axis = dimname_to_position(self, dim);
  return native::cumprod_(self, axis, dtype);
}
// Dimname overload of cumprod_out: resolves `dim` through dimname_to_position
// and forwards to at::cumprod_out, writing into `result`.
Tensor& cumprod_out(const Tensor& self, Dimname dim, c10::optional<ScalarType> dtype, Tensor& result) {
  const auto axis = dimname_to_position(self, dim);
  return at::cumprod_out(result, self, axis, dtype);
}
// Dimname overload of cummax: resolves `dim` through dimname_to_position and
// returns the tuple produced by at::cummax.
std::tuple<Tensor, Tensor> cummax(const Tensor& self, Dimname dim) {
  const auto axis = dimname_to_position(self, dim);
  return at::cummax(self, axis);
}
// Dimname overload of cummax_out: resolves `dim` through dimname_to_position
// and forwards to at::cummax_out with `values` and `indices` as outputs.
std::tuple<Tensor&, Tensor&> cummax_out(const Tensor& self, Dimname dim, Tensor& values, Tensor& indices) {
  const auto axis = dimname_to_position(self, dim);
  return at::cummax_out(values, indices, self, axis);
}
// Dimname overload of cummin: resolves `dim` through dimname_to_position and
// returns the tuple produced by at::cummin.
std::tuple<Tensor, Tensor> cummin(const Tensor& self, Dimname dim) {
  const auto axis = dimname_to_position(self, dim);
  return at::cummin(self, axis);
}
// Dimname overload of cummin_out: resolves `dim` through dimname_to_position
// and forwards to at::cummin_out with `values` and `indices` as outputs.
std::tuple<Tensor&, Tensor&> cummin_out(const Tensor& self, Dimname dim, Tensor& values, Tensor& indices) {
  const auto axis = dimname_to_position(self, dim);
  return at::cummin_out(values, indices, self, axis);
}
// Returns at::norm, with exponent argument `p`, of the element-wise
// difference `self - other`.
Tensor dist(const Tensor &self, const Tensor& other, const Scalar& p){
  const Tensor difference = self - other;
  return at::norm(difference, p);
}
bool cpu_equal(const Tensor& self, const Tensor& other) {
if (!at::namedinference::are_names_equal(
self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {
return false;
}
at::NoNamesGuard guard;
TORCH_CHECK(self.device() == other.device(), "Cannot compare two tensors on "
"different devices. Got: ", self.device(), " and ", other.device());
TORCH_CHECK(self.dtype() == other.dtype(),
"Expected object of scalar type ", self.dtype(), " but got scalar type ",
other.dtype(), " for argument 'other'");
if (!self.is_same_size(other)) {
return false;
}
std::atomic<bool> result{true};
auto iter = TensorIteratorConfig()
.add_input(self)
.add_input(other)
.allow_cpu_scalars(true)
.promote_inputs_to_common_dtype(true)
.build();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "equal_cpu", [&] {
iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
if (!result) {
return;
}
char* self_data = data[0];
char* other_data = data[1];
for (int64_t i = 0; i < dim_size; ++i) {
if (*((scalar_t*)self_data) != *((scalar_t*)other_data)) {
result = false;
return;
}
self_data += strides[0];
other_data += strides[1];
}
});
});
return result.load();
}
// Backward for reductions that select values along a dimension, such as
// max(dim), min(dim), topk(dim) and mode(dim). The incoming `grad` is
// scattered into a zero tensor of shape `sizes` at the positions given by
// `indices` along `dim`. When the forward dropped the reduced dimension
// (keepdim == false) and `sizes` is non-empty, `grad` and `indices` are first
// unsqueezed at `dim`.
Tensor value_selecting_reduction_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
  const bool needs_unsqueeze = !keepdim && !sizes.empty();
  if (!needs_unsqueeze) {
    return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad);
  }
  auto expanded_grad = grad.unsqueeze(dim);
  auto expanded_indices = indices.unsqueeze(dim);
  return at::zeros(sizes, expanded_grad.options())
      .scatter_(dim, expanded_indices, expanded_grad);
}
}} // namespace at::native