Use cascade summation in nll_loss on CPU (#55841)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/55657
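
For context: `nll_loss_out_frame` previously folded every element's loss and weight into a single running `scalar_t` accumulator. With float inputs and tens of millions of elements (the new regression test sums roughly 128 * 768 * 768 ≈ 75M terms) the running total quickly dwarfs each new term, and most of its value is rounded away. The patch replaces the single accumulator with a small cascade of partial sums: every element is added to level 0, and once a level has absorbed `2^level_power` terms it is flushed into the level above it. A minimal standalone sketch of that pattern (the `cascade_sum` helper below is illustrative only, not code from this PR):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>

// Illustrative cascade summation mirroring the pattern used in
// nll_loss_out_frame: values land in level 0, and whenever a level has
// taken 2^level_power terms its partial sum is flushed one level up, so
// no single accumulator ever absorbs the whole input.
float cascade_sum(const std::vector<float>& vals) {
  constexpr int64_t num_levels = 8;
  const int64_t n = static_cast<int64_t>(vals.size());
  const int64_t ceil_log2 = n <= 2
      ? 1
      : static_cast<int64_t>(std::ceil(std::log2(static_cast<double>(n))));
  const int64_t level_power = std::max(int64_t(4), ceil_log2 / num_levels);
  const int64_t level_mask = (int64_t(1) << level_power) - 1;

  float partial_sums[num_levels] = {0};
  for (int64_t i = 0; i < n; i++) {
    partial_sums[0] += vals[i];
    for (int64_t j = 0; j + 1 < num_levels; ++j) {
      const int64_t mask = level_mask << (j * level_power);
      if ((i & mask) != 0) {
        break;  // this level is not full yet
      }
      partial_sums[j + 1] += partial_sums[j];
      partial_sums[j] = 0;
    }
  }
  // Whatever remains in each level (including the final partial block) is
  // combined at the end, just as the patch does with std::accumulate.
  return std::accumulate(std::begin(partial_sums), std::end(partial_sums), 0.0f);
}
```

Because each partial sum only ever combines terms of comparable magnitude, the rounding error grows far more slowly than with a single running sum, which is what lets the float and double results in the new `test_cross_entropy_loss_precision` test agree.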

This also avoids summing `total_weight_val` when weights aren't supplied, eliminating that source of accumulated error entirely.
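
As a small illustration of why this matters (again, not code from the PR): with no weight tensor every kept element contributes exactly 1, and a float32 accumulator stops changing at all once it reaches 2^24, so counting ignored elements in an integer and converting once at the end sidesteps the problem:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Naively summing 30 million unit weights in float stalls at 2^24:
  // 16777216.0f + 1.0f rounds back to 16777216.0f. Counting in int64_t
  // and converting once at the end does not drift.
  const int64_t n = 30'000'000;
  float running = 0.f;
  for (int64_t i = 0; i < n; i++) {
    running += 1.f;
  }
  std::printf("running sum: %.1f\n", running);               // 16777216.0
  std::printf("from count:  %.1f\n", static_cast<float>(n)); // 30000000.0
  return 0;
}
```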

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55841

Reviewed By: jbschlosser

Differential Revision: D27751492

Pulled By: ngimel

fbshipit-source-id: 2c2dc48f31c25dfa9db48693e3f765b179771a3c
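
A note for readers of the diff below: the new `utils::CeilLog2` helper in `aten/src/ATen/native/cpu/utils.h` returns the smallest `k` with `2^k >= x` (never less than 1), implemented with `llvm::findLastSet`. A behaviorally equivalent reference version, written with a plain loop purely for illustration:

```cpp
#include <cstdint>

// Reference restatement of CeilLog2's contract: smallest k with 2^k >= x,
// never less than 1. The patch uses an llvm::findLastSet bit scan instead
// of this loop.
int64_t ceil_log2_reference(uint64_t x) {
  if (x <= 2) {
    return 1;
  }
  int64_t k = 0;
  uint64_t v = x - 1;  // subtract 1 so exact powers of two do not round up
  while (v > 0) {
    v >>= 1;
    ++k;
  }
  return k;  // e.g. x = 8 -> 3, x = 9 -> 4
}
```

The loops below divide this value by the number of cascade levels (clamping `level_power` to at least 4) so that each level holds at least 16 terms before being flushed upward.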
diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp
index 378d364..76371b2 100644
--- a/aten/src/ATen/native/LossNLL.cpp
+++ b/aten/src/ATen/native/LossNLL.cpp
@@ -3,6 +3,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
 #include <ATen/TensorUtils.h>
+#include <ATen/native/cpu/utils.h>
 
 namespace at {
 namespace native {
@@ -23,6 +24,7 @@
   return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
 }
 
+
 template <typename scalar_t>
 static void nll_loss_out_frame(
     Tensor& output,
@@ -82,43 +84,67 @@
   const scalar_t* input_data = input_contiguous.data_ptr<scalar_t>();
   const int64_t* target_data = target_contiguous.data_ptr<int64_t>();
 
-  scalar_t output_val = 0;
-  scalar_t total_weight_val = 0;
+  const int64_t ndim = input.dim();
+  TORCH_CHECK(ndim <= 2);
+  const int64_t batch_size = ndim == 1 ? 1 : input.size(0);
+  TORCH_CHECK(target.size(0) == batch_size);
 
-  if (input.dim() == 1) {
-    const auto cur_target = target_data[0];
-    if (cur_target != ignore_index) {
-      TORCH_CHECK_INDEX(
-          cur_target >= 0 && cur_target < n_classes,
-          "Target ",
-          cur_target,
-          " is out of bounds.");
-      total_weight_val =
-          weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
-      output_val = -input_data[cur_target] * total_weight_val;
+  constexpr int64_t cascade_sum_num_levels = 8;
+  const int64_t level_power =
+      std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
+  const int64_t level_step = (1 << level_power);
+  const int64_t level_mask = level_step - 1;
+
+  int64_t num_ignored = 0;
+
+  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
+  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
+  for (int64_t b = 0; b < batch_size; b++) {
+    const int64_t cur_target = target_data[b];
+    if (cur_target == ignore_index) {
+      ++num_ignored;
+      continue;
     }
-  } else if (input.dim() == 2) {
-    const auto batch_size = input.size(0);
-    TORCH_CHECK(target.size(0) == batch_size);
-    const auto n_target = input.size(1);
 
-    for (int64_t i = 0; i < batch_size; i++) {
-      const auto cur_target = target_data[i];
-      if (cur_target != ignore_index) {
-        TORCH_CHECK_INDEX(
-            cur_target >= 0 && cur_target < n_classes,
-            "Target ",
-            cur_target,
-            " is out of bounds.");
+    TORCH_CHECK_INDEX(
+        cur_target >= 0 && cur_target < n_classes,
+        "Target ",
+        cur_target,
+        " is out of bounds.");
 
-        scalar_t cur_weight =
-            weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
-        total_weight_val += cur_weight;
-        output_val -= input_data[i * n_target + cur_target] * cur_weight;
+    const auto data = input_data[b * n_classes + cur_target];
+    if (weight_data) {
+      const scalar_t weight_val = weight_data[cur_target];
+      loss_partial_sums[0] -= data * weight_val;
+      weight_partial_sums[0] += weight_val;
+    } else {
+      loss_partial_sums[0] -= data;
+    }
+
+    for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
+      const auto mask = (level_mask << (j * level_power));
+      if (C10_LIKELY((b & mask) != 0)) {
+        break;
       }
+
+      weight_partial_sums[j + 1] += weight_partial_sums[j];
+      loss_partial_sums[j + 1] += loss_partial_sums[j];
+
+      weight_partial_sums[j] = 0;
+      loss_partial_sums[j] = 0;
     }
   }
 
+  const scalar_t total_weight_val = !weight_data ?
+    static_cast<scalar_t>(batch_size - num_ignored) :
+    std::accumulate(std::begin(weight_partial_sums),
+                    std::end(weight_partial_sums),
+                    scalar_t{0});
+
+  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
+                                        std::end(loss_partial_sums),
+                                        scalar_t{0});
+
   if (reduction == Reduction::Mean &&
       (total_weight_val != 0 || input.numel() == 0)) {
     // allow NaN result for total_weight_val == 0 case, see #15870
diff --git a/aten/src/ATen/native/LossNLL2d.cpp b/aten/src/ATen/native/LossNLL2d.cpp
index 6f98702..3587d02 100644
--- a/aten/src/ATen/native/LossNLL2d.cpp
+++ b/aten/src/ATen/native/LossNLL2d.cpp
@@ -3,6 +3,7 @@
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
 #include <ATen/TensorUtils.h>
+#include <ATen/native/cpu/utils.h>
 
 namespace at {
 namespace native {
@@ -78,6 +79,7 @@
       target.sizes());
 }
 
+
 template <typename scalar_t>
 static void nll_loss2d_forward_out_frame(
     Tensor& output,
@@ -147,14 +149,22 @@
   const int64_t batch_size = input.size(0);
   const int64_t map_size = input.size(2) * input.size(3);
   const int64_t sample_size = map_size * n_classes;
+  const int64_t numiter = batch_size * map_size;
 
-  scalar_t total_weight_val = 0;
-  scalar_t output_val = 0;
+  constexpr int64_t cascade_sum_num_levels = 8;
+  scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
+  scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
+  const int64_t level_power =
+      std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
+  const int64_t level_step = (1 << level_power);
+  const int64_t level_mask = level_step - 1;
+
+  int64_t num_ignored = 0;
   for (int64_t b = 0; b < batch_size; b++) {
     for (int64_t elem = 0; elem < map_size; elem++) {
       const int64_t cur_target = target_data[b * map_size + elem];
-
       if (cur_target == ignore_index) {
+        ++num_ignored;
         continue;
       }
 
@@ -164,14 +174,42 @@
           cur_target,
           " is out of bounds.");
 
-      const scalar_t weight_val =
-          weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
-      total_weight_val += weight_val;
-      output_val -= input_data[b * sample_size + cur_target * map_size + elem] *
-          weight_val;
+      const auto data = input_data[b * sample_size + cur_target * map_size + elem];
+      if (weight_data) {
+        const scalar_t weight_val = weight_data[cur_target];
+        loss_partial_sums[0] -= data * weight_val;
+        weight_partial_sums[0] += weight_val;
+      } else {
+        loss_partial_sums[0] -= data;
+      }
+
+      const int64_t linear_idx = b * map_size + elem;
+      for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
+        const auto mask = (level_mask << (j * level_power));
+        if (C10_LIKELY((linear_idx & mask) != 0)) {
+          break;
+        }
+
+        weight_partial_sums[j + 1] += weight_partial_sums[j];
+        loss_partial_sums[j + 1] += loss_partial_sums[j];
+
+        weight_partial_sums[j] = 0;
+        loss_partial_sums[j] = 0;
+      }
     }
   }
 
+
+  const scalar_t total_weight_val = !weight_data ?
+    static_cast<scalar_t>(numiter - num_ignored) :
+    std::accumulate(std::begin(weight_partial_sums),
+                    std::end(weight_partial_sums),
+                    scalar_t{0});
+
+  scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
+                                        std::end(loss_partial_sums),
+                                        scalar_t{0});
+
   if (reduction == Reduction::Mean &&
       (total_weight_val != 0 || input.numel() == 0)) {
     // allow NaN result for total_weight_val == 0 case, see #15870
diff --git a/aten/src/ATen/native/cpu/utils.h b/aten/src/ATen/native/cpu/utils.h
index 32d1de5..5cdc3eb 100644
--- a/aten/src/ATen/native/cpu/utils.h
+++ b/aten/src/ATen/native/cpu/utils.h
@@ -1,5 +1,8 @@
 #pragma once
 
+#include <ATen/cpu/vec256/vec256.h>
+#include <c10/util/llvmMathExtras.h>
+
 namespace at { namespace native { namespace {
 
 template <typename T>
@@ -27,4 +30,21 @@
   return false;
 }
 
-}}}  // namespace at::native::<anonymous>
+} // namespace
+
+namespace utils {
+
+template <typename T>
+T CeilLog2(const T& x) {
+  if (x <= 2) {
+    return 1;
+  }
+  // Last set bit is floor(log2(x)), floor + 1 is ceil
+  // except when x is an exact power of 2, so subtract 1 first
+  return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
+}
+
+} // namespace utils
+
+} // namespace native
+} // namespace at
diff --git a/test/test_nn.py b/test/test_nn.py
index 8414c48..d93c5cf 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10132,6 +10132,17 @@
         # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
         self.assertEqualIgnoreType(input.grad, inputf.grad, atol=1e-1, rtol=0)
 
+    def test_cross_entropy_loss_precision(self):
+        # Regression test for #55657
+        loss_cpu = nn.CrossEntropyLoss().cpu()
+        inputf = torch.randn(128, 2, 768, 768, device="cpu", dtype=torch.float)
+        inputd = inputf.double()
+        target = torch.randint(2, (128, 768, 768), dtype=torch.long)
+
+        outf = loss_cpu(inputf, target)
+        outd = loss_cpu(inputd, target)
+        self.assertEqual(outf, outd, exact_dtype=False)
+
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_convert_sync_batchnorm(self):
         module = torch.nn.Sequential(