Use some `if constexpr` in the code (#90483)
As PyTorch is C++17 project now. Replace `c10::guts::if_constexpr` with `if constexpr`
Deliberately delaying changes in headers until at least one nightly
cycle is complete.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90483
Approved by: https://github.com/kit1980, https://github.com/Skylion007
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index a10f6c7..549611a 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1219,6 +1219,18 @@
return at::native::nansum_out(self, dim, keepdim, dtype, result);
}
+namespace {
+template<typename scalar_t, typename accscalar_t = at::acc_type<scalar_t, false>>
+void inline set_result(Tensor& result, accscalar_t sum)
+{
+ if constexpr (std::is_integral_v<accscalar_t>) {
+ // all integer types get promoted to kLong
+ *result.data_ptr<int64_t>() = sum;
+ } else {
+ *result.data_ptr<scalar_t>() = sum;
+ }
+}
+}
// NOTE: this could be implemented via diag and sum, but this has perf problems,
// see https://github.com/pytorch/pytorch/pull/47305,
Tensor trace_cpu(const Tensor& self) {
@@ -1244,12 +1256,8 @@
for (const auto i : c10::irange(t_diag_size)) {
sum += t_data[i * (t_stride_0 + t_stride_1)];
}
+ set_result<scalar_t>(result, sum);
- c10::guts::if_constexpr<std::is_integral<accscalar_t>::value>(
- // all integer types get promoted to kLong
- [&] (auto _) { *result.data_ptr<int64_t>() = _(sum); }, // then-case, invalid for non-integral types
- [&] (auto _) { *result.data_ptr<scalar_t>() = _(sum); } // else-case, invalid for integral types
- );
});
return result;
diff --git a/aten/src/ATen/native/cpu/SumKernel.cpp b/aten/src/ATen/native/cpu/SumKernel.cpp
index d4a6eb3..89fdca1 100644
--- a/aten/src/ATen/native/cpu/SumKernel.cpp
+++ b/aten/src/ATen/native/cpu/SumKernel.cpp
@@ -552,18 +552,16 @@
char* ptrs[3] = { data[0], data[0], data[1] };
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
- c10::guts::if_constexpr<ignore_nan>(
- [&](auto) {
+ if constexpr (ignore_nan) {
basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) {
auto a_notnan = at::_isnan(a) ? scalar_t(0) : a;
auto b_notnan = at::_isnan(b) ? scalar_t(0) : b;
return a_notnan + b_notnan;
});
- },
- [&](auto) {
+ } else {
basic_loop(ptrs, inner_strides, 0, size0,
[](scalar_t a, scalar_t b) { return a + b; });
- });
+ }
});
return;
}