#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceAllOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/OpMathType.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>
namespace at::native {
namespace {
using namespace vec;
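
// Reduces the whole contiguous input to a single scalar. at::parallel_reduce partitions
// [0, numel) using internal::GRAIN_SIZE as the grain size; each chunk is reduced with the
// vectorized `vop` via vec::reduce_all, and the per-chunk partials are combined with `op`.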
template <typename scalar_t, typename func_t, typename vec_func_t>
inline void reduce_all_impl_vec(
Tensor& output,
const Tensor& input,
const scalar_t ident_v,
func_t op,
vec_func_t vop) {
using Vec = Vectorized<opmath_type<scalar_t>>;
const int64_t input_numel = input.numel();
auto input_data = input.const_data_ptr<scalar_t>();
  // NOTE: parallel_reduce does not support the bool type
scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
[&](int64_t start, int64_t end, const scalar_t /*ident*/) -> scalar_t {
scalar_t partial_out = vec::reduce_all<scalar_t>(
[=](Vec x, Vec y) { return vop(x, y); },
input_data + start,
end - start);
return partial_out;
}, op);
output.fill_(result);
}
// Scalar fallback for operations/dtypes not supported by the AVX/AVX2 vectorized path
template <typename scalar_t, typename func_t>
inline void reduce_all_impl(
Tensor& output,
const Tensor& input,
const scalar_t ident_v,
func_t op) {
const int64_t input_numel = input.numel();
auto input_data = input.const_data_ptr<scalar_t>();
scalar_t result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
[&](int64_t start, int64_t end, const scalar_t ident) -> scalar_t {
scalar_t partial_out = ident;
for (const auto i : c10::irange(start, end)) {
partial_out = op(partial_out, input_data[i]);
}
return partial_out;
}, op);
output.fill_(result);
}
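
// Reduces the whole tensor to its minimum. bool goes through a serial kernel (see the
// NOTE above about parallel_reduce), int64_t takes the scalar path, and all other
// supported dtypes use the vectorized reduction with upper_bound<scalar_t>() as the identity.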
static void min_all_kernel_impl(Tensor& result, const Tensor& input) {
if (input.scalar_type() == ScalarType::Bool) {
TensorIterator iter = TensorIteratorConfig()
.add_input(input)
.build();
bool result_data = true;
cpu_serial_kernel(iter, [&](const bool a) -> void {
result_data = result_data && a;
});
result.fill_(result_data);
  } else if (input.scalar_type() == ScalarType::Long) {
    // For int64_t the vectorized implementation has performance issues,
    // so just use the scalar path.
reduce_all_impl<int64_t>(result, input, upper_bound<int64_t>(),
[=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); });
} else {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] {
using Vec = Vectorized<opmath_type<scalar_t>>;
reduce_all_impl_vec<scalar_t>(result, input, upper_bound<scalar_t>(),
        [=](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[=](Vec a, Vec b) -> Vec { return minimum(a, b); });
});
}
}
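
// Reduces the whole tensor to its maximum; mirrors min_all_kernel_impl, with
// lower_bound<scalar_t>() as the identity.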
static void max_all_kernel_impl(Tensor& result, const Tensor& input) {
if (input.scalar_type() == ScalarType::Bool) {
TensorIterator iter = TensorIteratorConfig()
.add_input(input)
.build();
bool result_data = false;
cpu_serial_kernel(iter, [&](const bool a) -> void {
result_data = result_data || a;
});
result.fill_(result_data);
} else if (input.scalar_type() == ScalarType::Long) {
    // For int64_t the vectorized implementation has performance issues,
    // so just use the scalar path.
reduce_all_impl<int64_t>(result, input, lower_bound<int64_t>(),
[=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); });
} else {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] {
using Vec = Vectorized<opmath_type<scalar_t>>;
reduce_all_impl_vec<scalar_t>(result, input, lower_bound<scalar_t>(),
        [=](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); },
[=](Vec a, Vec b) -> Vec { return maximum(a, b); });
});
}
}
// Scalar two-output fallback for operations/dtypes not supported by the AVX/AVX2 vectorized path
template <typename scalar_t, typename func_t1, typename func_t2>
inline void reduce_all_impl_two_outputs(
Tensor& output1,
Tensor& output2,
const Tensor& input,
const std::pair<scalar_t, scalar_t>& ident_v,
func_t1 reduce_chunk_func,
func_t2 reduce_acc_func) {
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
const int64_t input_numel = input.numel();
auto input_data = input.const_data_ptr<scalar_t>();
scalar_t_pair result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
[&](int64_t start, int64_t end, const scalar_t_pair& ident) -> scalar_t_pair {
scalar_t_pair partial_out(ident);
for (const auto i : c10::irange(start, end)) {
partial_out = reduce_chunk_func(partial_out, input_data[i]);
}
return partial_out;
},
reduce_acc_func
);
output1.fill_(result.first);
output2.fill_(result.second);
}
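
// Vectorized two-output variant: vec::reduce2_all applies reduce_chunk_func1 and
// reduce_chunk_func2 to each chunk in a single pass over the data, and the per-chunk
// (first, second) pairs are then combined with reduce_acc_func.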
template <typename scalar_t, typename func_t, typename vec_func_t1, typename vec_func_t2>
inline void reduce_all_impl_vec_two_outputs(
Tensor& output1,
Tensor& output2,
const Tensor& input,
const std::pair<scalar_t, scalar_t>& ident_v,
func_t reduce_acc_func,
vec_func_t1 reduce_chunk_func1,
vec_func_t2 reduce_chunk_func2) {
using Vec = Vectorized<opmath_type<scalar_t>>;
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
const int64_t input_numel = input.numel();
auto input_data = input.const_data_ptr<scalar_t>();
  // NOTE: parallel_reduce does not support the bool type
std::pair<scalar_t, scalar_t> result = at::parallel_reduce(0, input_numel, internal::GRAIN_SIZE, ident_v,
[&](int64_t start, int64_t end, const scalar_t_pair& /* ident */) -> scalar_t_pair {
scalar_t_pair partial_out = vec::reduce2_all<scalar_t>(
[=](Vec x, Vec y) { return reduce_chunk_func1(x, y); },
[=](Vec x, Vec y) { return reduce_chunk_func2(x, y); },
input_data + start,
end - start);
return partial_out;
},
reduce_acc_func
);
output1.fill_(result.first);
output2.fill_(result.second);
}
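
// Computes the minimum and maximum of the whole tensor in one pass, using the same
// bool / int64_t / vectorized dispatch as the single-output kernels above.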
static void aminmax_allreduce_kernel(
const Tensor& input,
Tensor& min_result,
Tensor& max_result) {
if (input.scalar_type() == ScalarType::Bool) {
TensorIterator iter = TensorIteratorConfig()
.add_input(input)
.build();
bool min_result_data = true;
bool max_result_data = false;
cpu_serial_kernel(iter, [&](const bool a) -> void {
min_result_data = min_result_data && a;
max_result_data = max_result_data || a;
});
min_result.fill_(min_result_data);
max_result.fill_(max_result_data);
} else if (input.scalar_type() == ScalarType::Long) {
    // For int64_t the vectorized implementation has performance issues,
    // so just use the scalar path.
using int64_t_pair = std::pair<int64_t, int64_t>;
reduce_all_impl_two_outputs<int64_t>(min_result, max_result, input,
int64_t_pair(upper_bound<int64_t>(), lower_bound<int64_t>()),
// reduce over chunk
[=](int64_t_pair a, int64_t b) -> int64_t_pair {
return int64_t_pair(min_impl(a.first, b), max_impl(a.second, b));
},
// combine two inputs
[=](int64_t_pair a, int64_t_pair b) -> int64_t_pair {
return int64_t_pair(min_impl(a.first, b.first), max_impl(a.second, b.second));
}
);
} else {
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
using Vec = Vectorized<opmath_type<scalar_t>>;
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
reduce_all_impl_vec_two_outputs<scalar_t>(
min_result,
max_result,
input,
scalar_t_pair(upper_bound<scalar_t>(), lower_bound<scalar_t>()),
        [=](scalar_t_pair a, scalar_t_pair b) -> scalar_t_pair {
return scalar_t_pair(
min_impl(a.first, b.first), max_impl(a.second, b.second));
},
[=](Vec a, Vec b) -> Vec { return minimum(a, b); },
[=](Vec a, Vec b) -> Vec { return maximum(a, b); }
);
});
}
}
} // namespace
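
// Illustrative usage (a sketch, not part of this file): on CPU, the full-tensor
// reductions reach these kernels through the dispatch stubs registered below,
// roughly along the lines of
//   at::Tensor t = at::randn({8, 8});
//   at::Tensor lo = t.min();            // no dim argument -> min_all_stub
//   auto [mn, mx] = at::aminmax(t);     // no dim argument -> aminmax_allreduce_stub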
REGISTER_DISPATCH(min_all_stub, &min_all_kernel_impl);
REGISTER_DISPATCH(max_all_stub, &max_all_kernel_impl);
REGISTER_DISPATCH(aminmax_allreduce_stub, &aminmax_allreduce_kernel);
} // namespace at::native