Migrate nonzero from TH to ATen (CPU) (#58811)

Summary:
Closes gh-24745

The existing PR (gh-50655) has stalled because `TensorIterator` doesn't guarantee iteration order in the same way that `TH_TENSOR_APPLY` does. For contiguous test cases this isn't an issue, but it breaks down, for example, with channels-last format. I resolve this by adding a new `TensorIteratorConfig` parameter, `enforce_linear_iteration`, which disables dimension reordering. I've also added a test case for non-contiguous tensors to verify this works.
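
For reference, this is roughly how the new flag is used when building the iterator in `nonzero_out_cpu` (a condensed excerpt; error checks omitted):

```cpp
// Build a TensorIterator that visits elements in C-contiguous (linear) order,
// matching the output ordering that TH_TENSOR_APPLY produced.
auto iter = at::TensorIteratorConfig()
    .add_input(self)
    .enforce_linear_iteration()  // new flag: disables dimension reordering
    .build();
```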

This PR also significantly improves performance by adding multithreading support to the algorithm. As part of this, I wrote a custom `count_nonzero` that gives per-thread counts, which are needed to write each thread's outputs to the correct location.
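
The parallel scheme runs in two passes: each thread first counts the nonzeros in its slice of the iteration range, then those counts are turned into a cumulative sum so every thread knows the output row at which to start writing in the second pass. A condensed sketch of that bookkeeping, simplified from the `nonzero_out_cpu` implementation below (dtype dispatch elided):

```cpp
// Pass 1: per-thread nonzero counts, stored at tid + 1 so index 0 stays zero.
const auto num_threads = at::get_num_threads();
std::vector<int64_t> thread_count_nonzero(num_threads + 1, 0);
at::parallel_for(0, numel, at::internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
  const auto tid = at::get_thread_num();
  thread_count_nonzero[tid + 1] = count_nonzero_impl<scalar_t>(iter, {begin, end});
});

// Convert to a cumulative sum: thread_count_nonzero[tid] is now the row offset
// where thread `tid` starts writing indices in pass 2.
for (size_t i = 1; i < thread_count_nonzero.size(); ++i) {
  thread_count_nonzero[i] += thread_count_nonzero[i - 1];
}
```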

|    Shape   |  Before | After (1 thread) | After (8 threads) |
|:----------:|--------:|-----------------:|------------------:|
| 256,128,32 | 2610 us |          2220 us |            496 us |
| 128,128,32 | 1250 us |           976 us |            175 us |
|  64,128,32 |  581 us |           486 us |             88 us |
|  32,128,32 |  292 us |           245 us |             80 us |
|  16,128,32 |  147 us |           120 us |             71 us |
|  8,128,32  |   75 us |            61 us |             61 us |
|  4,128,32  |   39 us |            32 us |             32 us |
|  2,128,32  |   20 us |            17 us |             17 us |
|  1,128,32  |   11 us |             9 us |              9 us |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58811

Reviewed By: anjali411

Differential Revision: D28700259

Pulled By: ngimel

fbshipit-source-id: 9b279ca7c36d8e348b7e5e4be0dd159e05aee159
diff --git a/BUILD.bazel b/BUILD.bazel
index 00408a7..a71406e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -329,9 +329,7 @@
         "aten/src/TH/THLapack.cpp",
         "aten/src/TH/THStorageFunctions.cpp",
         "aten/src/TH/THTensor.cpp",
-        "aten/src/TH/THTensorEvenMoreMath.cpp",
         "aten/src/TH/THTensorLapack.cpp",
-        "aten/src/TH/THTensorMath.cpp",
         "aten/src/TH/THTensorMoreMath.cpp",
     ],
 )
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
index f2ac0f3..8eeadd5 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
@@ -35,159 +35,6 @@
   }
 }
 
-Tensor & _th_nonzero_out(const Tensor & self, Tensor & result) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THBoolTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THByteTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Char: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THCharTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Double: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Int: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THIntTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Long: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THLongTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Short: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THShortTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Half: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THHalfTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::BFloat16: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THBFloat16Tensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexDouble: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexFloat: {
-            auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_nonzero_out not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return result;
-}
-Tensor _th_nonzero(const Tensor & self) {
-    // DeviceGuard omitted
-    auto dispatch_scalar_type = infer_scalar_type(self);
-    auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(ScalarType::Long)).release();
-    auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
-    switch (dispatch_scalar_type) {
-        case ScalarType::Bool: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THBoolTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Byte: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THByteTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Char: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THCharTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Double: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Float: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Int: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THIntTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Long: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THLongTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Short: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THShortTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::Half: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THHalfTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::BFloat16: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THBFloat16Tensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexDouble: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexDoubleTensor_nonzero(result_, self_);
-            break;
-        }
-        case ScalarType::ComplexFloat: {
-            auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
-            THComplexFloatTensor_nonzero(result_, self_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_nonzero not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return result;
-}
 Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h
index 0898234..0d273ea 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.h
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.h
@@ -20,8 +20,6 @@
 
 Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
 Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
-Tensor& _th_nonzero_out(const Tensor& self, Tensor& result);
-Tensor _th_nonzero(const Tensor & self);
 Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt);
 Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result);
 Tensor _th_renorm(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm);
diff --git a/aten/src/ATen/ParallelOpenMP.h b/aten/src/ATen/ParallelOpenMP.h
index 2b37f82..d79c5ec 100644
--- a/aten/src/ATen/ParallelOpenMP.h
+++ b/aten/src/ATen/ParallelOpenMP.h
@@ -17,16 +17,22 @@
     const int64_t end,
     const int64_t grain_size,
     const F& f) {
-  TORCH_CHECK(grain_size >= 0);
-  at::internal::lazy_init_num_threads();
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
   if (begin >= end) {
     return;
   }
-  if (end - begin == 1) {
+
+#ifdef _OPENMP
+  at::internal::lazy_init_num_threads();
+  const auto numiter = end - begin;
+  const bool use_parallel = (
+    numiter > grain_size && numiter > 1 &&
+    omp_get_max_threads() > 1 && !omp_in_parallel());
+  if (!use_parallel) {
     f(begin, end);
     return;
   }
-#ifdef _OPENMP
+
   std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
   std::exception_ptr eptr;
   // Work around memory leak when using 1 thread in nested "omp parallel"
@@ -34,7 +40,7 @@
   // returns false when omp_get_max_threads() == 1 inside nested "omp parallel"
   // See issue gh-32284
 
-#pragma omp parallel if (omp_get_max_threads() > 1 && !omp_in_parallel() && ((end - begin) > grain_size))
+#pragma omp parallel
   {
     // choose number of tasks based on grain size and number of threads
     // can't use num_threads clause due to bugs in GOMP's thread pool (See #32008)
@@ -76,7 +82,8 @@
   at::internal::lazy_init_num_threads();
   if (begin >= end) {
     return ident;
-  } else if (in_parallel_region() || get_num_threads() == 1) {
+  } else if ((end - begin) <= grain_size || in_parallel_region() ||
+             get_num_threads() == 1) {
     return f(begin, end, ident);
   } else {
     const int64_t num_results = divup((end - begin), grain_size);
@@ -84,7 +91,7 @@
     scalar_t* results_data = results.data();
     std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
     std::exception_ptr eptr;
-#pragma omp parallel for if ((end - begin) >= grain_size)
+#pragma omp parallel for
     for (int64_t id = 0; id < num_results; id++) {
       int64_t i = begin + id * grain_size;
       try {
diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp
index 58a2523..52b0ef4 100644
--- a/aten/src/ATen/TensorIterator.cpp
+++ b/aten/src/ATen/TensorIterator.cpp
@@ -129,6 +129,12 @@
   // initialize perm with n-1, n-2, ..., 1, 0
   std::iota(perm_.rbegin(), perm_.rend(), 0);
 
+  // Reordering dimensions changes iteration order
+  if (enforce_linear_iteration_) {
+    permute_dimensions(perm_);
+    return;
+  }
+
   // returns 1 if the dim0 should come after dim1, -1 if dim0 should come
   // before dim1, and 0 if the comparison is ambiguous.
   auto should_swap = [&](size_t dim0, size_t dim1) {
@@ -1213,6 +1219,20 @@
     return FastSetupType::NONE;
   }
 
+  // For linear iteration, only contiguous tensors can be coalesced
+  // Fast setup of any other format requires changing iteration order
+  if (enforce_linear_iteration_) {
+    for (const auto& op : operands_) {
+      if (op.tensor->defined() && !op.will_resize) {
+        auto is_contiguous = op.tensor->is_contiguous(at::MemoryFormat::Contiguous);
+        if (!is_contiguous) {
+          return FastSetupType::NONE;
+        }
+      }
+    }
+    return FastSetupType::CONTIGUOUS;
+  }
+
   bool is_contiguous = true;
   bool is_channels_last = true;
   bool is_non_overlapping_and_dense = true;
@@ -1265,6 +1285,7 @@
 void TensorIteratorBase::build(TensorIteratorConfig& config) {
   // populate some persistent configuration fields
   is_reduction_ = config.is_reduction_;
+  enforce_linear_iteration_ = config.enforce_linear_iteration_;
 
   // fill in operands_ based on configuration
   populate_operands(config);
diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h
index 8f16404..431bfea 100644
--- a/aten/src/ATen/TensorIterator.h
+++ b/aten/src/ATen/TensorIterator.h
@@ -426,6 +426,10 @@
   /// been called?  This is SOLELY used to check validity of perm_.
   bool has_coalesced_dimensions_ = false;
 
+  /// Whether iteration must follow a fixed linear order. This disables dimension
+  /// permuting and also changes how for_each divides work among threads.
+  bool enforce_linear_iteration_ = false;
+
   /// The index offsets into the original tensors for each dimension.
   /// This is only non-zero when you narrow() a TensorIterator (e.g.,
   /// when you make sub-TensorIterators).
@@ -583,6 +587,17 @@
     return *this;
   }
 
+  // Sets the enforce_linear_iteration_ flag, which is false by default.
+  // If true, iteration goes in the same order as a C-contiguous tensor
+  // is laid out in memory, i.e. the last dimension iterates fastest.
+  //
+  // This iteration order can be less efficient and may even prevent vectorization.
+  // So only use it if the correctness of your kernel depends on it.
+  TensorIteratorConfig& enforce_linear_iteration(const bool _enforce_linear_iteration = true) {
+    enforce_linear_iteration_ = _enforce_linear_iteration;
+    return *this;
+  }
+
   // Sets the promote_inputs_to_common_dtype_ flag, which is false by default
   // If true, the iterator's "common dtype" is always computed (see the
   //   [Common Dtype Computation] note) and, on the CPU, temporary copies of
@@ -664,6 +679,7 @@
   bool check_all_same_dtype_ = true;
   bool check_all_same_device_ = true;
   bool enforce_safe_casting_to_output_ = false;
+  bool enforce_linear_iteration_ = false;
   bool promote_inputs_to_common_dtype_ = false;
   bool promote_integer_inputs_to_float_ = false;
   bool cast_common_dtype_to_outputs_ = false;
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 33046fc..446ab2a 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1746,19 +1746,6 @@
   return at::norm(self - other, p);
 }
 
-Tensor count_nonzero(const Tensor& self, IntArrayRef dims){
-  auto mask = (self != 0);
-  return mask.sum(dims);
-}
-
-Tensor count_nonzero(const Tensor& self, c10::optional<int64_t> dim){
-  if (dim){
-    auto wrap_dim = maybe_wrap_dim(dim.value(), self.dim());
-    return at::count_nonzero(self, IntArrayRef{wrap_dim});
-  }
-  return at::count_nonzero(self, IntArrayRef{});
-}
-
 bool cpu_equal(const Tensor& self, const Tensor& other) {
   if (!at::namedinference::are_names_equal(
         self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index e4cf53f..371c5e3 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -62,6 +62,7 @@
 #include <ATen/Parallel.h>
 
 #include <c10/util/irange.h>
+#include <c10/util/Unroll.h>
 
 #include <algorithm>
 #include <functional>
@@ -1369,6 +1370,201 @@
     return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes());
 }
 
+template <typename scalar_t>
+int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
+  int64_t num_nonzero = 0;
+
+  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+    constexpr int ilp_factor = 4;
+    const char* ptr = data[0];
+    const auto stride = strides[0];
+    int64_t nonzero[ilp_factor] = {0};
+
+    int64_t i = 0;
+    for (; i + (ilp_factor - 1) < n; i += ilp_factor) {
+      c10::ForcedUnroll<ilp_factor>{}([&](int k) {
+        const auto& val = *reinterpret_cast<const scalar_t*>(ptr + k * stride);
+        if (val != scalar_t(0)) {
+          ++nonzero[k];
+        }
+      });
+      ptr += ilp_factor * stride;
+    }
+    for (; i < n; ++i) {
+      const auto& val = *reinterpret_cast<const scalar_t*>(ptr);
+      if (val != scalar_t(0)) {
+        ++nonzero[0];
+      }
+      ptr += stride;
+    }
+    for (int64_t k = 1; k < ilp_factor; ++k) {
+      nonzero[0] += nonzero[k];
+    }
+    num_nonzero += nonzero[0];
+  };
+  iter.serial_for_each(loop, range);
+
+  return num_nonzero;
+}
+
+Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
+  return (self != 0).sum(dims);
+}
+
+Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
+  if (dims.size() > 0) {
+    return (self != 0).sum(dims);
+  }
+
+  // Optimized all-reduce
+  auto iter = TensorIteratorConfig()
+      .add_input(self)
+      .build();
+
+  const auto num_threads = at::get_num_threads();
+  DimVector thread_count_nonzero(num_threads);
+
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+      kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
+    at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
+      const auto tid = at::get_thread_num();
+      thread_count_nonzero[tid] = count_nonzero_impl<scalar_t>(iter, {begin, end});
+    });
+  });
+
+  for (int64_t i = 1; i < num_threads; ++i) {
+    thread_count_nonzero[0] += thread_count_nonzero[i];
+  }
+  auto out = at::empty({}, self.options().dtype(kLong));
+  *out.data_ptr<int64_t>() = thread_count_nonzero[0];
+  return out;
+}
+
+
+Tensor count_nonzero(const Tensor& self, c10::optional<int64_t> dim) {
+  if (dim) {
+    return at::count_nonzero(self, IntArrayRef{*dim});
+  }
+  return at::count_nonzero(self, IntArrayRef{});
+}
+
+
+Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
+  TORCH_CHECK(result.scalar_type() == kLong,
+              "nonzero: Expected out tensor to have scalar type Long "
+              "but got scalar type", result.scalar_type());
+  at::assert_no_internal_overlap(result);
+  at::assert_no_overlap(result, self);
+
+  auto iter = TensorIteratorConfig()
+    .add_input(self)
+    .enforce_linear_iteration()
+    .build();
+
+  const auto numel = iter.numel();
+  const auto num_threads = at::get_num_threads();
+  DimVector thread_begin(num_threads, -1);
+  DimVector thread_count_nonzero(num_threads + 1);
+
+  // Pass 1: Count nonzero elements per thread
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+      kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
+    at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
+      const auto tid = at::get_thread_num();
+      thread_begin[tid] = begin;
+      thread_count_nonzero[tid + 1] = count_nonzero_impl<scalar_t>(iter, {begin, end});
+    });
+  });
+
+  // Convert thread-local counts to cumulative sum
+  for (size_t i = 1; i < thread_count_nonzero.size(); ++i) {
+    thread_count_nonzero[i] += thread_count_nonzero[i - 1];
+  }
+
+  const auto self_sizes = self.sizes();
+  const auto total_nonzero = thread_count_nonzero.back();
+  const int64_t ndim = self_sizes.size();
+  if (resize_output(result, {total_nonzero, ndim})) {
+    // Default to Fortran-contiguous output (see gh-46224)
+    result.as_strided_({total_nonzero, ndim}, {1, total_nonzero});
+  }
+
+  if (result.numel() == 0) {
+    return result;
+  }
+
+  // Pass 2: Write indices
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+      kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] {
+    at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
+      auto tid = at::get_thread_num();
+      // Work must be distributed the same way on both passes
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(begin == thread_begin[tid]);
+
+      // Allocating ndim + 1 entries is faster than an additional condition check inside the loop
+      c10::SmallVector<int64_t, 33> sizes(ndim + 1, -1);
+      std::copy(self_sizes.begin(), self_sizes.end(), sizes.begin() + 1);
+      c10::SmallVector<int64_t, 33> current_idx(ndim + 1);
+      if (begin > 0) {
+        auto idx = begin;
+        for (int64_t k = ndim; idx > 0 && k > 0; --k) {
+          current_idx[k] = idx % sizes[k];
+          idx /= sizes[k];
+        }
+      }
+
+      auto out_accessor = result.accessor<int64_t, 2>();
+      auto out_ptr = out_accessor[thread_count_nonzero[tid]].data();
+
+      auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) {
+        // Copy into local variables to improve compiler alias analysis
+        int64_t* C10_RESTRICT local_idx = current_idx.data() + 1;
+        const int64_t* C10_RESTRICT local_sizes = sizes.data() + 1;
+        const auto in_stride = strides[0];
+        const auto out_stride1 = out_accessor.stride(1);
+        const auto out_stride0 = out_accessor.stride(0) - ndim * out_stride1;
+        const auto ndim = out_accessor.size(1);
+        int64_t* out = out_ptr;
+
+        for (int64_t i = 0; i < n2; ++i) {
+          const char* ptr = data[0] + i * strides[1];
+          for (int64_t j = 0; j < n1; ++j) {
+            const auto& val = *reinterpret_cast<const scalar_t*>(ptr);
+            // If nonzero, write index
+            if (val != scalar_t(0)) {
+              for (int64_t k = 0; k < ndim; ++k) {
+                *out = local_idx[k];
+                out += out_stride1;
+              }
+              out += out_stride0;
+            }
+            ptr += in_stride;
+
+            // Advance current index
+            int64_t k = ndim - 1;
+            ++local_idx[k];
+            while (C10_UNLIKELY(local_idx[k] == local_sizes[k])) {
+              local_idx[k] = 0;
+              --k;
+              ++local_idx[k];
+            }
+          }
+        }
+        out_ptr = out;
+      };
+      iter.serial_for_each(loop, {begin, end});
+      TORCH_INTERNAL_ASSERT(out_ptr == out_accessor[thread_count_nonzero[tid + 1]].data());
+    });
+  });
+  return result;
+}
+
+Tensor nonzero_cpu(const Tensor& self) {
+  auto result = at::empty({0}, self.options().dtype(kLong));
+  nonzero_out_cpu(self, result);
+  return result;
+}
+
 std::vector<Tensor> nonzero_numpy(const Tensor& self) {
   // special case scalar for compatibility with numpy:
   //
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3579fe9..5cf405b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1244,7 +1244,8 @@
 - func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
   variants: function, method
   dispatch:
-    CPU, CUDA: count_nonzero
+    CPU: count_nonzero_cpu
+    CUDA: count_nonzero_cuda
 
 - func: count_nonzero(Tensor self, int? dim=None) -> Tensor
   variants: function, method
@@ -6228,13 +6229,13 @@
 
 - func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
-    CPU: legacy::cpu::_th_nonzero_out
+    CPU: nonzero_out_cpu
     CUDA: nonzero_out_cuda
 
 - func: nonzero(Tensor self) -> Tensor
   variants: method, function
   dispatch:
-    CPU: legacy::cpu::_th_nonzero
+    CPU: nonzero_cpu
     CUDA: nonzero_cuda
 
 - func: nonzero_numpy(Tensor self) -> Tensor[]
diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt
index 483d360..4db5d83 100644
--- a/aten/src/TH/CMakeLists.txt
+++ b/aten/src/TH/CMakeLists.txt
@@ -9,9 +9,7 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/THAllocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THStorageFunctions.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/THTensorMath.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THTensorMoreMath.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/THTensorEvenMoreMath.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THTensorLapack.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THBlas.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THLapack.cpp
@@ -91,7 +89,6 @@
   generic/THTensor.hpp
   generic/THTensorLapack.cpp
   generic/THTensorLapack.h
-  generic/THTensorMath.cpp
   generic/THTensorMath.h
   generic/THVector.h
   # See Note [TH abstraction violation]
diff --git a/aten/src/TH/THTensorEvenMoreMath.cpp b/aten/src/TH/THTensorEvenMoreMath.cpp
deleted file mode 100644
index 02662b5..0000000
--- a/aten/src/TH/THTensorEvenMoreMath.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-#include <TH/THTensor.hpp>
-#include <TH/THVector.h>
-#include <TH/THBlas.h>
-#include <TH/THTensorDimApply.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateAllTypes.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateBoolType.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateHalfType.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateBFloat16Type.h>
-
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateComplexTypes.h>
diff --git a/aten/src/TH/THTensorMath.cpp b/aten/src/TH/THTensorMath.cpp
deleted file mode 100644
index 1a81bd3..0000000
--- a/aten/src/TH/THTensorMath.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#include <TH/THTensor.hpp>
-#include <TH/THVector.h>
-#include <TH/THBlas.h>
-#include <TH/THTensorDimApply.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorMath.cpp>
-#include <TH/THGenerateAllTypes.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorMath.cpp>
-#include <TH/THGenerateBFloat16Type.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorMath.cpp>
-#include <TH/THGenerateBoolType.h>
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
deleted file mode 100644
index 35f0e9f..0000000
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ /dev/null
@@ -1,76 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "TH/generic/THTensorEvenMoreMath.cpp"
-#else
-
-#include <TH/generic/THTensorApply.hpp>
-#include <ATen/NamedTensorUtils.h>
-#include <ATen/WrapDimUtils.h>
-#include <ATen/MemoryOverlap.h>
-
-// Finds non-zero elements of a tensor and returns their subscripts
-void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
-{
-  ptrdiff_t numel = 0;
-  int64_t *subscript_data;
-  int64_t i = 0;
-#ifdef TH_REAL_IS_HALF
-#define IS_NONZERO(val) (c10::Half(0)!=val)
-#elif defined(TH_REAL_IS_BFLOAT16)
-#define IS_NONZERO(val) (c10::BFloat16(0)!=val)
-#else
-#define IS_NONZERO(val) ((val)!=scalar_t(0))
-#endif
-
-  /* First Pass to determine size of subscripts */
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ++numel;
-                  });
-#ifdef DEBUG
-  THAssert(numel <= LONG_MAX);
-#endif
-  THLongTensor_resize2d(subscript, numel, tensor->dim());
-  if (numel <= 0) {
-    return;
-  }
-  int64_t dimensions = tensor->dim();
-  // +1 faster than additional condition check inside loop
-  int64_t *sizes = new int64_t[dimensions+1];
-  int64_t *idx = new int64_t[dimensions+1];
-  int64_t *ii;
-  int64_t *ss;
-  std::fill(idx, idx+dimensions+1, 0);
-  for (i = 0; i < dimensions; ++i) {
-    sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
-  }
-  sizes[dimensions] = 0;
-  /* Second pass populates subscripts */
-  subscript_data = THLongTensor_data(subscript);
-  auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
-  subscript_strides[0] -= subscript_strides[1] * tensor->dim();
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ii = idx + dimensions;
-                    for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
-                      --ii;
-                      *subscript_data = *ii;
-                      subscript_data += subscript_strides[1];
-                    }
-                    subscript_data += subscript_strides[0];
-                  }
-                  ii = idx;
-                  ss = sizes;
-                  ++(*ii);
-                  while (*ii == *ss) {
-                    *ii = 0;
-                    ++ii;
-                    ++ss;
-                    ++(*ii);
-                  }
-                );
-  delete [] sizes;
-  delete [] idx;
-
-#undef IS_NONZERO
-}
-#endif /* TH_GENERIC_FILE */
diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp
deleted file mode 100644
index eb3b593..0000000
--- a/aten/src/TH/generic/THTensorMath.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "TH/generic/THTensorMath.cpp"
-#else
-
-#include <TH/generic/THTensorApply.hpp>
-#include <ATen/NamedTensorUtils.h>
-
-// HEY YOU!
-//
-// Looking for a function which used to be in THTensorMath.cpp, but
-// can't find it anymore?  Check THTensorMoreMath.cpp and
-// THTensorEvenMoreMath.cpp.  These source files have been split up
-// because they were getting too big (a whopping 4669 lines at time
-// of writing) and causing MSVC to run out of memory.  Did you come
-// here because you saw:
-//
-//    fatal error C1002: compiler is out of heap space in pass 2
-//
-// Try splitting up the file some more.
-//
-// At some point, we should reorganize these files in a way that makes
-// sense (rather than just having cut the file down the middle, which is
-// what I did when I split these up originally).
-
-#endif /* TH_GENERIC_FILE */
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 50c34da..b0b294c 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -4,7 +4,6 @@
 
 #include <ATen/core/Generator.h>
 
-TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
 TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb);
 
 #if !defined(TH_REAL_IS_HALF)
diff --git a/c10/util/Unroll.h b/c10/util/Unroll.h
new file mode 100644
index 0000000..f74f593
--- /dev/null
+++ b/c10/util/Unroll.h
@@ -0,0 +1,29 @@
+#pragma once
+#include <c10/macros/Macros.h>
+
+// Utility to guarantee complete unrolling of a loop where the bounds are known
+// at compile time. Various pragmas achieve similar effects, but are not as
+// portable across compilers.
+
+// Example: c10::ForcedUnroll<4>{}(f); is equivalent to f(0); f(1); f(2); f(3);
+
+namespace c10 {
+
+template <int n>
+struct ForcedUnroll {
+  template <typename Func>
+  C10_ALWAYS_INLINE void operator()(const Func& f) const {
+    ForcedUnroll<n - 1>{}(f);
+    f(n - 1);
+  }
+};
+
+template <>
+struct ForcedUnroll<1> {
+  template <typename Func>
+  C10_ALWAYS_INLINE void operator()(const Func& f) const {
+    f(0);
+  }
+};
+
+} // namespace c10
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 5b9b873..7066f14 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -1212,9 +1212,13 @@
         # Test code from issue https://github.com/pytorch/pytorch/issues/45113
         batch_size, input_size, hidden_size = 5, 3, 7
 
-        # Create coalesced sparse tensor as in the issue
+        # Create coalesced sparse tensor with non-contiguous indices
         weight = torch.randn(hidden_size, input_size, dtype=dtype, device=device).to_sparse()
         self.assertTrue(weight.is_coalesced())
+        non_contig_indices = weight.indices().transpose(-1, -2).contiguous().transpose(-1, -2)
+        weight = torch.sparse_coo_tensor(
+            indices=non_contig_indices, values=weight.values(), size=weight.shape)
+        weight._coalesced_(True)
         self.assertFalse(weight._indices().is_contiguous())
         # Create un/coalesced sparse tensor
         bias = torch.randn((hidden_size, 1), dtype=dtype, device=device).to_sparse()
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 29f62c2..929ba2a 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -1438,6 +1438,34 @@
         self.assertEqual(1, len(z))
         self.assertEqual(torch.empty(0, dtype=torch.long), z[0])
 
+    @dtypes(*torch.testing.get_all_dtypes())
+    def test_nonzero_noncontiguous(self, device, dtype):
+        x = make_tensor((10, 10, 10), dtype=dtype, device=device,
+                        low=1, noncontiguous=False)
+        mask = make_tensor((10, 10, 10), dtype=torch.bool, device=device)
+        x[mask] = 0
+
+        def permute_storage(tensor, dims):
+            dest_dims = tuple(range(len(dims)))
+            return tensor.permute(dims).contiguous().movedim(dims, dest_dims)
+
+        # Assume contiguous case is correct
+        expect = x.nonzero()
+
+        # Dense, permuted
+        self.assertEqual(permute_storage(x, [0, 2, 1]).nonzero(), expect)
+        self.assertEqual(permute_storage(x, [2, 1, 0]).nonzero(), expect)
+
+        # Non-dense
+        nondense = torch.empty((40, 10, 20), dtype=dtype, device=device)[::4, :, ::2]
+        nondense[:] = x
+        self.assertEqual(nondense.nonzero(), expect)
+
+        # Non-dense, permuted
+        nondense = nondense.permute([0, 2, 1])
+        nondense[:] = x
+        self.assertEqual(nondense.nonzero(), expect)
+
     # TODO: rationalize with exp OpInfo
     @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False) +
               torch.testing.get_all_complex_dtypes()))
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 5168d21..6d6b094 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -1031,9 +1031,7 @@
     "aten/src/TH/THLapack.cpp",
     "aten/src/TH/THStorageFunctions.cpp",
     "aten/src/TH/THTensor.cpp",
-    "aten/src/TH/THTensorEvenMoreMath.cpp",
     "aten/src/TH/THTensorLapack.cpp",
-    "aten/src/TH/THTensorMath.cpp",
     "aten/src/TH/THTensorMoreMath.cpp",
     "aten/src/ATen/native/utils/Factory.cpp",
     "aten/src/ATen/native/xnnpack/Activation.cpp",