Use cascade summation in nll_loss on CPU (#55841)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/55657
This also avoids summing `total_weight_val` when weights aren't supplied, which eliminates that source of accumulated error entirely.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55841
Reviewed By: jbschlosser
Differential Revision: D27751492
Pulled By: ngimel
fbshipit-source-id: 2c2dc48f31c25dfa9db48693e3f765b179771a3c
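
For reference, the cascade-summation pattern the patch adds to both kernels can be illustrated in isolation. The sketch below is not part of the PR; `cascade_sum`, its parameters, and the example input are made up for illustration, but the level/mask carry logic mirrors the structure added to `nll_loss_out_frame` and `nll_loss2d_forward_out_frame`.

```cpp
#include <cstdint>
#include <cstdio>
#include <iterator>
#include <numeric>
#include <vector>

// Sum `vals` in float while limiting rounding error: each value is added into
// level 0, and whenever the low bits of the index cross a 2^level_power
// boundary the partial sum is carried up one level, so no single accumulator
// has to absorb the whole input.
float cascade_sum(const std::vector<float>& vals, int64_t level_power = 4) {
  constexpr int64_t num_levels = 8;
  float partial_sums[num_levels] = {0};
  const int64_t level_step = int64_t(1) << level_power;
  const int64_t level_mask = level_step - 1;

  for (int64_t i = 0; i < static_cast<int64_t>(vals.size()); ++i) {
    partial_sums[0] += vals[i];
    // Carry level j into level j+1 only while bits [j*p, (j+1)*p) of the
    // index are zero; otherwise stop, just like the loop in the patch.
    for (int64_t j = 0; j + 1 < num_levels; ++j) {
      const int64_t mask = level_mask << (j * level_power);
      if ((i & mask) != 0) {
        break;
      }
      partial_sums[j + 1] += partial_sums[j];
      partial_sums[j] = 0;
    }
  }
  // Fold whatever remains in each level into the final result.
  return std::accumulate(std::begin(partial_sums), std::end(partial_sums), 0.0f);
}

int main() {
  // A plain running float accumulator picks up O(n) worst-case rounding error
  // on a sum like this; the cascade keeps every accumulator short, so the
  // result stays close to the mathematically exact sum.
  std::vector<float> vals(1000000, 1e-3f);
  std::printf("cascade sum: %.6f\n", cascade_sum(vals));
  return 0;
}
```

The patch applies the same carry schedule to both the loss and the weight accumulators, and when no `weight` tensor is given it skips the weight cascade entirely: the total weight is then just the count of non-ignored targets.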
diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp
index 378d364..76371b2 100644
--- a/aten/src/ATen/native/LossNLL.cpp
+++ b/aten/src/ATen/native/LossNLL.cpp
@@ -3,6 +3,7 @@
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
+#include <ATen/native/cpu/utils.h>
namespace at {
namespace native {
@@ -23,6 +24,7 @@
return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
}
+
template <typename scalar_t>
static void nll_loss_out_frame(
Tensor& output,
@@ -82,43 +84,67 @@
const scalar_t* input_data = input_contiguous.data_ptr<scalar_t>();
const int64_t* target_data = target_contiguous.data_ptr<int64_t>();
- scalar_t output_val = 0;
- scalar_t total_weight_val = 0;
+ const int64_t ndim = input.dim();
+ TORCH_CHECK(ndim <= 2);
+ const int64_t batch_size = ndim == 1 ? 1 : input.size(0);
+ TORCH_CHECK(target.size(0) == batch_size);
- if (input.dim() == 1) {
- const auto cur_target = target_data[0];
- if (cur_target != ignore_index) {
- TORCH_CHECK_INDEX(
- cur_target >= 0 && cur_target < n_classes,
- "Target ",
- cur_target,
- " is out of bounds.");
- total_weight_val =
- weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
- output_val = -input_data[cur_target] * total_weight_val;
+ constexpr int64_t cascade_sum_num_levels = 8;
+ const int64_t level_power =
+ std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
+ const int64_t level_step = (1 << level_power);
+ const int64_t level_mask = level_step - 1;
+
+ int64_t num_ignored = 0;
+
+ scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
+ scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
+ for (int64_t b = 0; b < batch_size; b++) {
+ const int64_t cur_target = target_data[b];
+ if (cur_target == ignore_index) {
+ ++num_ignored;
+ continue;
}
- } else if (input.dim() == 2) {
- const auto batch_size = input.size(0);
- TORCH_CHECK(target.size(0) == batch_size);
- const auto n_target = input.size(1);
- for (int64_t i = 0; i < batch_size; i++) {
- const auto cur_target = target_data[i];
- if (cur_target != ignore_index) {
- TORCH_CHECK_INDEX(
- cur_target >= 0 && cur_target < n_classes,
- "Target ",
- cur_target,
- " is out of bounds.");
+ TORCH_CHECK_INDEX(
+ cur_target >= 0 && cur_target < n_classes,
+ "Target ",
+ cur_target,
+ " is out of bounds.");
- scalar_t cur_weight =
- weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
- total_weight_val += cur_weight;
- output_val -= input_data[i * n_target + cur_target] * cur_weight;
+ const auto data = input_data[b * n_classes + cur_target];
+ if (weight_data) {
+ const scalar_t weight_val = weight_data[cur_target];
+ loss_partial_sums[0] -= data * weight_val;
+ weight_partial_sums[0] += weight_val;
+ } else {
+ loss_partial_sums[0] -= data;
+ }
+
+ for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
+ const auto mask = (level_mask << (j * level_power));
+ if (C10_LIKELY((b & mask) != 0)) {
+ break;
}
+
+ weight_partial_sums[j + 1] += weight_partial_sums[j];
+ loss_partial_sums[j + 1] += loss_partial_sums[j];
+
+ weight_partial_sums[j] = 0;
+ loss_partial_sums[j] = 0;
}
}
+ const scalar_t total_weight_val = !weight_data ?
+ static_cast<scalar_t>(batch_size - num_ignored) :
+ std::accumulate(std::begin(weight_partial_sums),
+ std::end(weight_partial_sums),
+ scalar_t{0});
+
+ scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
+ std::end(loss_partial_sums),
+ scalar_t{0});
+
if (reduction == Reduction::Mean &&
(total_weight_val != 0 || input.numel() == 0)) {
// allow NaN result for total_weight_val == 0 case, see #15870
diff --git a/aten/src/ATen/native/LossNLL2d.cpp b/aten/src/ATen/native/LossNLL2d.cpp
index 6f98702..3587d02 100644
--- a/aten/src/ATen/native/LossNLL2d.cpp
+++ b/aten/src/ATen/native/LossNLL2d.cpp
@@ -3,6 +3,7 @@
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
+#include <ATen/native/cpu/utils.h>
namespace at {
namespace native {
@@ -78,6 +79,7 @@
target.sizes());
}
+
template <typename scalar_t>
static void nll_loss2d_forward_out_frame(
Tensor& output,
@@ -147,14 +149,22 @@
const int64_t batch_size = input.size(0);
const int64_t map_size = input.size(2) * input.size(3);
const int64_t sample_size = map_size * n_classes;
+ const int64_t numiter = batch_size * map_size;
- scalar_t total_weight_val = 0;
- scalar_t output_val = 0;
+ constexpr int64_t cascade_sum_num_levels = 8;
+ scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
+ scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
+ const int64_t level_power =
+ std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
+ const int64_t level_step = (1 << level_power);
+ const int64_t level_mask = level_step - 1;
+
+ int64_t num_ignored = 0;
for (int64_t b = 0; b < batch_size; b++) {
for (int64_t elem = 0; elem < map_size; elem++) {
const int64_t cur_target = target_data[b * map_size + elem];
-
if (cur_target == ignore_index) {
+ ++num_ignored;
continue;
}
@@ -164,14 +174,42 @@
cur_target,
" is out of bounds.");
- const scalar_t weight_val =
- weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
- total_weight_val += weight_val;
- output_val -= input_data[b * sample_size + cur_target * map_size + elem] *
- weight_val;
+ const auto data = input_data[b * sample_size + cur_target * map_size + elem];
+ if (weight_data) {
+ const scalar_t weight_val = weight_data[cur_target];
+ loss_partial_sums[0] -= data * weight_val;
+ weight_partial_sums[0] += weight_val;
+ } else {
+ loss_partial_sums[0] -= data;
+ }
+
+ const int64_t linear_idx = b * map_size + elem;
+ for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
+ const auto mask = (level_mask << (j * level_power));
+ if (C10_LIKELY((linear_idx & mask) != 0)) {
+ break;
+ }
+
+ weight_partial_sums[j + 1] += weight_partial_sums[j];
+ loss_partial_sums[j + 1] += loss_partial_sums[j];
+
+ weight_partial_sums[j] = 0;
+ loss_partial_sums[j] = 0;
+ }
}
}
+
+ const scalar_t total_weight_val = !weight_data ?
+ static_cast<scalar_t>(numiter - num_ignored) :
+ std::accumulate(std::begin(weight_partial_sums),
+ std::end(weight_partial_sums),
+ scalar_t{0});
+
+ scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
+ std::end(loss_partial_sums),
+ scalar_t{0});
+
if (reduction == Reduction::Mean &&
(total_weight_val != 0 || input.numel() == 0)) {
// allow NaN result for total_weight_val == 0 case, see #15870
diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h
index 32d1de5..5cdc3eb 100644
--- a/aten/src/ATen/native/cpu/utils.h
+++ b/aten/src/ATen/native/cpu/utils.h
@@ -1,5 +1,8 @@
#pragma once
+#include <ATen/cpu/vec256/vec256.h>
+#include <c10/util/llvmMathExtras.h>
+
namespace at { namespace native { namespace {
template <typename T>
@@ -27,4 +30,21 @@
return false;
}
-}}} // namespace at::native::<anonymous>
+} // namespace
+
+namespace utils {
+
+template <typename T>
+T CeilLog2(const T& x) {
+ if (x <= 2) {
+ return 1;
+ }
+ // Last set bit is floor(log2(x)), floor + 1 is ceil
+ // except when x is an exact power of 2, so subtract 1 first
+ return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
+}
+
+} // namespace utils
+
+} // namespace native
+} // namespace at
diff --git a/test/test_nn.py b/test/test_nn.py
index 8414c48..d93c5cf 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10132,6 +10132,17 @@
# TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
self.assertEqualIgnoreType(input.grad, inputf.grad, atol=1e-1, rtol=0)
+ def test_cross_entropy_loss_precision(self):
+ # Regression test for #55657
+ loss_cpu = nn.CrossEntropyLoss().cpu()
+ inputf = torch.randn(128, 2, 768, 768, device="cpu", dtype=torch.float)
+ inputd = inputf.double()
+ target = torch.randint(2, (128, 768, 768), dtype=torch.long)
+
+ outf = loss_cpu(inputf, target)
+ outd = loss_cpu(inputd, target)
+ self.assertEqual(outf, outd, exact_dtype=False)
+
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_convert_sync_batchnorm(self):
module = torch.nn.Sequential(