Migrate nonzero from TH to ATen (CPU) (#58811)
Summary:
Closes gh-24745
The existing PR (gh-50655) stalled because `TensorIterator` doesn't guarantee iteration order the way `TH_TENSOR_APPLY` does. For contiguous test cases this isn't an issue, but it breaks down, for example, with channels-last tensors. I resolve this by adding a new `TensorIteratorConfig` parameter, `enforce_linear_iteration`, which disables dimension reordering. I've also added a test case for non-contiguous tensors to verify this works.
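For reference, a minimal sketch of how a kernel opts into the new flag (the helper name is illustrative; the real usage is in `nonzero_out_cpu` further down in this diff):

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>

// Minimal sketch: build an iterator that always traverses elements in
// row-major order (last dimension fastest), i.e. the same order
// TH_TENSOR_APPLY used, regardless of the input's memory format.
at::TensorIterator make_linear_iter(const at::Tensor& self) {
  return at::TensorIteratorConfig()
      .add_input(self)
      .enforce_linear_iteration()  // the new flag; defaults to false
      .build();
}
```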
This PR also significantly improves performance by adding multithreading support to the algorithm. As part of this, I wrote a custom `count_nonzero` that produces per-thread counts, which is needed so that each thread writes its outputs at the correct offset (see the condensed sketch after the benchmark table).
| Shape | Before | After (1 thread) | After (8 threads) |
|:----------:|--------:|-----------------:|------------------:|
| 256,128,32 | 2610 us | 2220 us | 496 us |
| 128,128,32 | 1250 us | 976 us | 175 us |
| 64,128,32 | 581 us | 486 us | 88 us |
| 32,128,32 | 292 us | 245 us | 80 us |
| 16,128,32 | 147 us | 120 us | 71 us |
| 8,128,32 | 75 us | 61 us | 61 us |
| 4,128,32 | 39 us | 32 us | 32 us |
| 2,128,32 | 20 us | 17 us | 17 us |
| 1,128,32 | 11 us | 9 us | 9 us |
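Below is a condensed sketch of the two-pass scheme, with placeholder logic standing in for the real counting and index-writing loops; the full implementation is `nonzero_out_cpu` in `TensorAdvancedIndexing.cpp` in this diff.

```cpp
#include <ATen/Parallel.h>
#include <cstdint>
#include <vector>

// Sketch of the two-pass parallel nonzero (illustrative only).
// Pass 1 counts nonzeros per thread; the prefix sum of those counts
// tells each thread where to start writing in pass 2.
void two_pass_nonzero_sketch(int64_t numel) {
  const int num_threads = at::get_num_threads();
  std::vector<int64_t> counts(num_threads + 1, 0);

  // Pass 1: per-thread counts (the chunk length stands in for the
  // real nonzero count over [begin, end)).
  at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
      [&](int64_t begin, int64_t end) {
    counts[at::get_thread_num() + 1] = end - begin;  // placeholder count
  });

  // Prefix sum: counts[tid] becomes the output row at which thread tid
  // starts writing; counts.back() is the total nonzero count, used to
  // size the output tensor.
  for (size_t i = 1; i < counts.size(); ++i) {
    counts[i] += counts[i - 1];
  }

  // Pass 2: the same range and grain size give the same work partition
  // (asserted in the real code), so each thread writes its indices
  // starting at output row counts[tid] without synchronization.
  at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
      [&](int64_t begin, int64_t end) {
    const int64_t out_row = counts[at::get_thread_num()];
    (void)begin; (void)end; (void)out_row;  // real code writes indices here
  });
}
```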
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58811
Reviewed By: anjali411
Differential Revision: D28700259
Pulled By: ngimel
fbshipit-source-id: 9b279ca7c36d8e348b7e5e4be0dd159e05aee159
diff --git a/BUILD.bazel b/BUILD.bazel
index 00408a7..a71406e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -329,9 +329,7 @@
"aten/src/TH/THLapack.cpp",
"aten/src/TH/THStorageFunctions.cpp",
"aten/src/TH/THTensor.cpp",
- "aten/src/TH/THTensorEvenMoreMath.cpp",
"aten/src/TH/THTensorLapack.cpp",
- "aten/src/TH/THTensorMath.cpp",
"aten/src/TH/THTensorMoreMath.cpp",
],
)
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
index f2ac0f3..8eeadd5 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp
@@ -35,159 +35,6 @@
}
}
-Tensor & _th_nonzero_out(const Tensor & self, Tensor & result) {
- // DeviceGuard omitted
- auto dispatch_scalar_type = infer_scalar_type(self);
-
- switch (dispatch_scalar_type) {
- case ScalarType::Bool: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THBoolTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Byte: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THByteTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Char: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THCharTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Double: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THDoubleTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Float: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THFloatTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Int: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THIntTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Long: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THLongTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Short: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THShortTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Half: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THHalfTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::BFloat16: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THBFloat16Tensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::ComplexDouble: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THComplexDoubleTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::ComplexFloat: {
- auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_nonzero_out", false, DeviceType::CPU, ScalarType::Long);
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero_out", false, DeviceType::CPU, dispatch_scalar_type);
- THComplexFloatTensor_nonzero(result_, self_);
- break;
- }
- default:
- AT_ERROR("_th_nonzero_out not supported on CPUType for ", dispatch_scalar_type);
- }
- return result;
-}
-Tensor _th_nonzero(const Tensor & self) {
- // DeviceGuard omitted
- auto dispatch_scalar_type = infer_scalar_type(self);
- auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(ScalarType::Long)).release();
- auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
- switch (dispatch_scalar_type) {
- case ScalarType::Bool: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THBoolTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Byte: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THByteTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Char: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THCharTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Double: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THDoubleTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Float: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THFloatTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Int: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THIntTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Long: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THLongTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Short: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THShortTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::Half: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THHalfTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::BFloat16: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THBFloat16Tensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::ComplexDouble: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THComplexDoubleTensor_nonzero(result_, self_);
- break;
- }
- case ScalarType::ComplexFloat: {
- auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_nonzero", false, DeviceType::CPU, dispatch_scalar_type);
- THComplexFloatTensor_nonzero(result_, self_);
- break;
- }
- default:
- AT_ERROR("_th_nonzero not supported on CPUType for ", dispatch_scalar_type);
- }
- return result;
-}
Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h
index 0898234..0d273ea 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.h
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.h
@@ -20,8 +20,6 @@
Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
-Tensor& _th_nonzero_out(const Tensor& self, Tensor& result);
-Tensor _th_nonzero(const Tensor & self);
Scalar _th_std_var(const Tensor& self, int64_t correction, bool take_sqrt);
Tensor & _th_renorm_out(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm, Tensor & result);
Tensor _th_renorm(const Tensor & self, const Scalar& p, int64_t dim, const Scalar& maxnorm);
diff --git a/aten/src/ATen/ParallelOpenMP.h b/aten/src/ATen/ParallelOpenMP.h
index 2b37f82..d79c5ec 100644
--- a/aten/src/ATen/ParallelOpenMP.h
+++ b/aten/src/ATen/ParallelOpenMP.h
@@ -17,16 +17,22 @@
const int64_t end,
const int64_t grain_size,
const F& f) {
- TORCH_CHECK(grain_size >= 0);
- at::internal::lazy_init_num_threads();
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0);
if (begin >= end) {
return;
}
- if (end - begin == 1) {
+
+#ifdef _OPENMP
+ at::internal::lazy_init_num_threads();
+ const auto numiter = end - begin;
+ const bool use_parallel = (
+ numiter > grain_size && numiter > 1 &&
+ omp_get_max_threads() > 1 && !omp_in_parallel());
+ if (!use_parallel) {
f(begin, end);
return;
}
-#ifdef _OPENMP
+
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
std::exception_ptr eptr;
// Work around memory leak when using 1 thread in nested "omp parallel"
@@ -34,7 +40,7 @@
// returns false when omp_get_max_threads() == 1 inside nested "omp parallel"
// See issue gh-32284
-#pragma omp parallel if (omp_get_max_threads() > 1 && !omp_in_parallel() && ((end - begin) > grain_size))
+#pragma omp parallel
{
// choose number of tasks based on grain size and number of threads
// can't use num_threads clause due to bugs in GOMP's thread pool (See #32008)
@@ -76,7 +82,8 @@
at::internal::lazy_init_num_threads();
if (begin >= end) {
return ident;
- } else if (in_parallel_region() || get_num_threads() == 1) {
+ } else if ((end - begin) <= grain_size || in_parallel_region() ||
+ get_num_threads() == 1) {
return f(begin, end, ident);
} else {
const int64_t num_results = divup((end - begin), grain_size);
@@ -84,7 +91,7 @@
scalar_t* results_data = results.data();
std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
std::exception_ptr eptr;
-#pragma omp parallel for if ((end - begin) >= grain_size)
+#pragma omp parallel for
for (int64_t id = 0; id < num_results; id++) {
int64_t i = begin + id * grain_size;
try {
diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp
index 58a2523..52b0ef4 100644
--- a/aten/src/ATen/TensorIterator.cpp
+++ b/aten/src/ATen/TensorIterator.cpp
@@ -129,6 +129,12 @@
// initialize perm with n-1, n-2, ..., 1, 0
std::iota(perm_.rbegin(), perm_.rend(), 0);
+ // Reordering dimensions changes iteration order
+ if (enforce_linear_iteration_) {
+ permute_dimensions(perm_);
+ return;
+ }
+
// returns 1 if the dim0 should come after dim1, -1 if dim0 should come
// before dim1, and 0 if the comparison is ambiguous.
auto should_swap = [&](size_t dim0, size_t dim1) {
@@ -1213,6 +1219,20 @@
return FastSetupType::NONE;
}
+ // For linear iteration, only contiguous tensors can be coalesced
+ // Fast setup of any other format requires changing iteration order
+ if (enforce_linear_iteration_) {
+ for (const auto& op : operands_) {
+ if (op.tensor->defined() && !op.will_resize) {
+ auto is_contiguous = op.tensor->is_contiguous(at::MemoryFormat::Contiguous);
+ if (!is_contiguous) {
+ return FastSetupType::NONE;
+ }
+ }
+ }
+ return FastSetupType::CONTIGUOUS;
+ }
+
bool is_contiguous = true;
bool is_channels_last = true;
bool is_non_overlapping_and_dense = true;
@@ -1265,6 +1285,7 @@
void TensorIteratorBase::build(TensorIteratorConfig& config) {
// populate some persistent configuration fields
is_reduction_ = config.is_reduction_;
+ enforce_linear_iteration_ = config.enforce_linear_iteration_;
// fill in operands_ based on configuration
populate_operands(config);
diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h
index 8f16404..431bfea 100644
--- a/aten/src/ATen/TensorIterator.h
+++ b/aten/src/ATen/TensorIterator.h
@@ -426,6 +426,10 @@
/// been called? This is SOLELY used to check validity of perm_.
bool has_coalesced_dimensions_ = false;
+ /// Whether iteration must be fixed. This disables dimension permuting and also
+ /// changes how for_each divides work among threads.
+ bool enforce_linear_iteration_ = false;
+
/// The index offsets into the original tensors for each dimension.
/// This is only non-zero when you narrow() a TensorIterator (e.g.,
/// when you make sub-TensorIterators).
@@ -583,6 +587,17 @@
return *this;
}
+ // Sets the enforce_linear_iteration_ flag, which is false by default.
+ // If true, iteration goes in the same order as a C-contiguous tensor
+ // is laid out in memory, i.e. the last dimension iterates fastest.
+ //
+ // This iteration order can be less efficient and may even prevent vectorization.
+ // So only use it if the correctness of your kernel depends on it.
+ TensorIteratorConfig& enforce_linear_iteration(const bool _enforce_linear_iteration = true) {
+ enforce_linear_iteration_ = _enforce_linear_iteration;
+ return *this;
+ }
+
// Sets the promote_inputs_to_common_dtype_ flag, which is false by default
// If true, the iterator's "common dtype" is always computed (see the
// [Common Dtype Computation] note) and, on the CPU, temporary copies of
@@ -664,6 +679,7 @@
bool check_all_same_dtype_ = true;
bool check_all_same_device_ = true;
bool enforce_safe_casting_to_output_ = false;
+ bool enforce_linear_iteration_ = false;
bool promote_inputs_to_common_dtype_ = false;
bool promote_integer_inputs_to_float_ = false;
bool cast_common_dtype_to_outputs_ = false;
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index 33046fc..446ab2a 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -1746,19 +1746,6 @@
return at::norm(self - other, p);
}
-Tensor count_nonzero(const Tensor& self, IntArrayRef dims){
- auto mask = (self != 0);
- return mask.sum(dims);
-}
-
-Tensor count_nonzero(const Tensor& self, c10::optional<int64_t> dim){
- if (dim){
- auto wrap_dim = maybe_wrap_dim(dim.value(), self.dim());
- return at::count_nonzero(self, IntArrayRef{wrap_dim});
- }
- return at::count_nonzero(self, IntArrayRef{});
-}
-
bool cpu_equal(const Tensor& self, const Tensor& other) {
if (!at::namedinference::are_names_equal(
self.unsafeGetTensorImpl(), other.unsafeGetTensorImpl())) {
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index e4cf53f..371c5e3 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -62,6 +62,7 @@
#include <ATen/Parallel.h>
#include <c10/util/irange.h>
+#include <c10/util/Unroll.h>
#include <algorithm>
#include <functional>
@@ -1369,6 +1370,201 @@
return at::_sparse_coo_tensor_unsafe(sparse_ind, grad.reshape(-1), self.sizes());
}
+template <typename scalar_t>
+int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
+ int64_t num_nonzero = 0;
+
+ auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+ constexpr int ilp_factor = 4;
+ const char* ptr = data[0];
+ const auto stride = strides[0];
+ int64_t nonzero[ilp_factor] = {0};
+
+ int64_t i = 0;
+ for (; i + (ilp_factor - 1) < n; i += ilp_factor) {
+ c10::ForcedUnroll<ilp_factor>{}([&](int k) {
+ const auto& val = *reinterpret_cast<const scalar_t*>(ptr + k * stride);
+ if (val != scalar_t(0)) {
+ ++nonzero[k];
+ }
+ });
+ ptr += ilp_factor * stride;
+ }
+ for (; i < n; ++i) {
+ const auto& val = *reinterpret_cast<const scalar_t*>(ptr);
+ if (val != scalar_t(0)) {
+ ++nonzero[0];
+ }
+ ptr += stride;
+ }
+ for (int64_t k = 1; k < ilp_factor; ++k) {
+ nonzero[0] += nonzero[k];
+ }
+ num_nonzero += nonzero[0];
+ };
+ iter.serial_for_each(loop, range);
+
+ return num_nonzero;
+}
+
+Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
+ return (self != 0).sum(dims);
+}
+
+Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
+ if (dims.size() > 0) {
+ return (self != 0).sum(dims);
+ }
+
+ // Optimized all-reduce
+ auto iter = TensorIteratorConfig()
+ .add_input(self)
+ .build();
+
+ const auto num_threads = at::get_num_threads();
+ DimVector thread_count_nonzero(num_threads);
+
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+ kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
+ at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
+ const auto tid = at::get_thread_num();
+ thread_count_nonzero[tid] = count_nonzero_impl<scalar_t>(iter, {begin, end});
+ });
+ });
+
+ for (int64_t i = 1; i < num_threads; ++i) {
+ thread_count_nonzero[0] += thread_count_nonzero[i];
+ }
+ auto out = at::empty({}, self.options().dtype(kLong));
+ *out.data_ptr<int64_t>() = thread_count_nonzero[0];
+ return out;
+}
+
+
+Tensor count_nonzero(const Tensor& self, c10::optional<int64_t> dim) {
+ if (dim) {
+ return at::count_nonzero(self, IntArrayRef{*dim});
+ }
+ return at::count_nonzero(self, IntArrayRef{});
+}
+
+
+Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
+ TORCH_CHECK(result.scalar_type() == kLong,
+ "nonzero: Expected out tensor to have scalar type Long "
+ "but got scalar type", result.scalar_type());
+ at::assert_no_internal_overlap(result);
+ at::assert_no_overlap(result, self);
+
+ auto iter = TensorIteratorConfig()
+ .add_input(self)
+ .enforce_linear_iteration()
+ .build();
+
+ const auto numel = iter.numel();
+ const auto num_threads = at::get_num_threads();
+ DimVector thread_begin(num_threads, -1);
+ DimVector thread_count_nonzero(num_threads + 1);
+
+ // Pass 1: Count nonzero elements per thread
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+ kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
+ at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
+ const auto tid = at::get_thread_num();
+ thread_begin[tid] = begin;
+ thread_count_nonzero[tid + 1] = count_nonzero_impl<scalar_t>(iter, {begin, end});
+ });
+ });
+
+ // Convert thread-local counts to cumulative sum
+ for (size_t i = 1; i < thread_count_nonzero.size(); ++i) {
+ thread_count_nonzero[i] += thread_count_nonzero[i - 1];
+ }
+
+ const auto self_sizes = self.sizes();
+ const auto total_nonzero = thread_count_nonzero.back();
+ const int64_t ndim = self_sizes.size();
+ if (resize_output(result, {total_nonzero, ndim})) {
+ // Default to fortran-contiguous output (see gh-46224)
+ result.as_strided_({total_nonzero, ndim}, {1, total_nonzero});
+ }
+
+ if (result.numel() == 0) {
+ return result;
+ }
+
+ // Pass 2: Write indexes
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+ kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] {
+ at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
+ auto tid = at::get_thread_num();
+ // Work needs to be distributed the same way on both passes
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(begin == thread_begin[tid]);
+
+ // +1 faster than additional condition check inside loop
+ c10::SmallVector<int64_t, 33> sizes(ndim + 1, -1);
+ std::copy(self_sizes.begin(), self_sizes.end(), sizes.begin() + 1);
+ c10::SmallVector<int64_t, 33> current_idx(ndim + 1);
+ if (begin > 0) {
+ auto idx = begin;
+ for (int64_t k = ndim; idx > 0 && k > 0; --k) {
+ current_idx[k] = idx % sizes[k];
+ idx /= sizes[k];
+ }
+ }
+
+ auto out_accessor = result.accessor<int64_t, 2>();
+ auto out_ptr = out_accessor[thread_count_nonzero[tid]].data();
+
+ auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) {
+ // Copy into local variables to improve compiler alias analysis
+ int64_t* C10_RESTRICT local_idx = current_idx.data() + 1;
+ const int64_t* C10_RESTRICT local_sizes = sizes.data() + 1;
+ const auto in_stride = strides[0];
+ const auto out_stride1 = out_accessor.stride(1);
+ const auto out_stride0 = out_accessor.stride(0) - ndim * out_stride1;
+ const auto ndim = out_accessor.size(1);
+ int64_t* out = out_ptr;
+
+ for (int64_t i = 0; i < n2; ++i) {
+ const char* ptr = data[0] + i * strides[1];
+ for (int64_t j = 0; j < n1; ++j) {
+ const auto& val = *reinterpret_cast<const scalar_t*>(ptr);
+ // If nonzero, write index
+ if (val != scalar_t(0)) {
+ for (int64_t k = 0; k < ndim; ++k) {
+ *out = local_idx[k];
+ out += out_stride1;
+ }
+ out += out_stride0;
+ }
+ ptr += in_stride;
+
+ // Advance current index
+ int64_t k = ndim - 1;
+ ++local_idx[k];
+ while (C10_UNLIKELY(local_idx[k] == local_sizes[k])) {
+ local_idx[k] = 0;
+ --k;
+ ++local_idx[k];
+ }
+ }
+ }
+ out_ptr = out;
+ };
+ iter.serial_for_each(loop, {begin, end});
+ TORCH_INTERNAL_ASSERT(out_ptr == out_accessor[thread_count_nonzero[tid + 1]].data());
+ });
+ });
+ return result;
+}
+
+Tensor nonzero_cpu(const Tensor& self) {
+ auto result = at::empty({0}, self.options().dtype(kLong));
+ nonzero_out_cpu(self, result);
+ return result;
+}
+
std::vector<Tensor> nonzero_numpy(const Tensor& self) {
// special case scalar for compatibility with numpy:
//
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3579fe9..5cf405b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1244,7 +1244,8 @@
- func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
variants: function, method
dispatch:
- CPU, CUDA: count_nonzero
+ CPU: count_nonzero_cpu
+ CUDA: count_nonzero_cuda
- func: count_nonzero(Tensor self, int? dim=None) -> Tensor
variants: function, method
@@ -6228,13 +6229,13 @@
- func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
- CPU: legacy::cpu::_th_nonzero_out
+ CPU: nonzero_out_cpu
CUDA: nonzero_out_cuda
- func: nonzero(Tensor self) -> Tensor
variants: method, function
dispatch:
- CPU: legacy::cpu::_th_nonzero
+ CPU: nonzero_cpu
CUDA: nonzero_cuda
- func: nonzero_numpy(Tensor self) -> Tensor[]
diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt
index 483d360..4db5d83 100644
--- a/aten/src/TH/CMakeLists.txt
+++ b/aten/src/TH/CMakeLists.txt
@@ -9,9 +9,7 @@
${CMAKE_CURRENT_SOURCE_DIR}/THAllocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THStorageFunctions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/THTensorMath.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THTensorMoreMath.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/THTensorEvenMoreMath.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THTensorLapack.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THBlas.cpp
${CMAKE_CURRENT_SOURCE_DIR}/THLapack.cpp
@@ -91,7 +89,6 @@
generic/THTensor.hpp
generic/THTensorLapack.cpp
generic/THTensorLapack.h
- generic/THTensorMath.cpp
generic/THTensorMath.h
generic/THVector.h
# See Note [TH abstraction violation]
diff --git a/aten/src/TH/THTensorEvenMoreMath.cpp b/aten/src/TH/THTensorEvenMoreMath.cpp
deleted file mode 100644
index 02662b5..0000000
--- a/aten/src/TH/THTensorEvenMoreMath.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-#include <TH/THTensor.hpp>
-#include <TH/THVector.h>
-#include <TH/THBlas.h>
-#include <TH/THTensorDimApply.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateAllTypes.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateBoolType.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateHalfType.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateBFloat16Type.h>
-
-#include <TH/generic/THTensorEvenMoreMath.cpp>
-#include <TH/THGenerateComplexTypes.h>
diff --git a/aten/src/TH/THTensorMath.cpp b/aten/src/TH/THTensorMath.cpp
deleted file mode 100644
index 1a81bd3..0000000
--- a/aten/src/TH/THTensorMath.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#include <TH/THTensor.hpp>
-#include <TH/THVector.h>
-#include <TH/THBlas.h>
-#include <TH/THTensorDimApply.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorMath.cpp>
-#include <TH/THGenerateAllTypes.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorMath.cpp>
-#include <TH/THGenerateBFloat16Type.h>
-
-// NOLINTNEXTLINE(bugprone-suspicious-include)
-#include <TH/generic/THTensorMath.cpp>
-#include <TH/THGenerateBoolType.h>
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
deleted file mode 100644
index 35f0e9f..0000000
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ /dev/null
@@ -1,76 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "TH/generic/THTensorEvenMoreMath.cpp"
-#else
-
-#include <TH/generic/THTensorApply.hpp>
-#include <ATen/NamedTensorUtils.h>
-#include <ATen/WrapDimUtils.h>
-#include <ATen/MemoryOverlap.h>
-
-// Finds non-zero elements of a tensor and returns their subscripts
-void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
-{
- ptrdiff_t numel = 0;
- int64_t *subscript_data;
- int64_t i = 0;
-#ifdef TH_REAL_IS_HALF
-#define IS_NONZERO(val) (c10::Half(0)!=val)
-#elif defined(TH_REAL_IS_BFLOAT16)
-#define IS_NONZERO(val) (c10::BFloat16(0)!=val)
-#else
-#define IS_NONZERO(val) ((val)!=scalar_t(0))
-#endif
-
- /* First Pass to determine size of subscripts */
- TH_TENSOR_APPLY(scalar_t, tensor,
- if IS_NONZERO(*tensor_data) {
- ++numel;
- });
-#ifdef DEBUG
- THAssert(numel <= LONG_MAX);
-#endif
- THLongTensor_resize2d(subscript, numel, tensor->dim());
- if (numel <= 0) {
- return;
- }
- int64_t dimensions = tensor->dim();
- // +1 faster than additional condition check inside loop
- int64_t *sizes = new int64_t[dimensions+1];
- int64_t *idx = new int64_t[dimensions+1];
- int64_t *ii;
- int64_t *ss;
- std::fill(idx, idx+dimensions+1, 0);
- for (i = 0; i < dimensions; ++i) {
- sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
- }
- sizes[dimensions] = 0;
- /* Second pass populates subscripts */
- subscript_data = THLongTensor_data(subscript);
- auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
- subscript_strides[0] -= subscript_strides[1] * tensor->dim();
- TH_TENSOR_APPLY(scalar_t, tensor,
- if IS_NONZERO(*tensor_data) {
- ii = idx + dimensions;
- for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
- --ii;
- *subscript_data = *ii;
- subscript_data += subscript_strides[1];
- }
- subscript_data += subscript_strides[0];
- }
- ii = idx;
- ss = sizes;
- ++(*ii);
- while (*ii == *ss) {
- *ii = 0;
- ++ii;
- ++ss;
- ++(*ii);
- }
- );
- delete [] sizes;
- delete [] idx;
-
-#undef IS_NONZERO
-}
-#endif /* TH_GENERIC_FILE */
diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp
deleted file mode 100644
index eb3b593..0000000
--- a/aten/src/TH/generic/THTensorMath.cpp
+++ /dev/null
@@ -1,25 +0,0 @@
-#ifndef TH_GENERIC_FILE
-#define TH_GENERIC_FILE "TH/generic/THTensorMath.cpp"
-#else
-
-#include <TH/generic/THTensorApply.hpp>
-#include <ATen/NamedTensorUtils.h>
-
-// HEY YOU!
-//
-// Looking for a function which used to be in THTensorMath.cpp, but
-// can't find it anymore? Check THTensorMoreMath.cpp and
-// THTensorEvenMoreMath.cpp. These source files have been split up
-// because they were getting too big (a whopping 4669 lines at time
-// of writing) and causing MSVC to run out of memory. Did you come
-// here because you saw:
-//
-// fatal error C1002: compiler is out of heap space in pass 2
-//
-// Try splitting up the file some more.
-//
-// At some point, we should reorganize these files in a way that makes
-// sense (rather than just having cut the file down the middle, which is
-// what I did when I split these up originally).
-
-#endif /* TH_GENERIC_FILE */
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 50c34da..b0b294c 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -4,7 +4,6 @@
#include <ATen/core/Generator.h>
-TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb);
#if !defined(TH_REAL_IS_HALF)
diff --git a/c10/util/Unroll.h b/c10/util/Unroll.h
new file mode 100644
index 0000000..f74f593
--- /dev/null
+++ b/c10/util/Unroll.h
@@ -0,0 +1,29 @@
+#pragma once
+#include <c10/macros/Macros.h>
+
+// Utility to guarantee complete unrolling of a loop where the bounds are known
+// at compile time. Various pragmas achieve similar effects, but are not as
+// portable across compilers.
+
+// Example: c10::ForcedUnroll<4>{}(f); is equivalent to f(0); f(1); f(2); f(3);
+
+namespace c10 {
+
+template <int n>
+struct ForcedUnroll {
+ template <typename Func>
+ C10_ALWAYS_INLINE void operator()(const Func& f) const {
+ ForcedUnroll<n - 1>{}(f);
+ f(n - 1);
+ }
+};
+
+template <>
+struct ForcedUnroll<1> {
+ template <typename Func>
+ C10_ALWAYS_INLINE void operator()(const Func& f) const {
+ f(0);
+ }
+};
+
+} // namespace c10
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 5b9b873..7066f14 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -1212,9 +1212,13 @@
# Test code from issue https://github.com/pytorch/pytorch/issues/45113
batch_size, input_size, hidden_size = 5, 3, 7
- # Create coalesced sparse tensor as in the issue
+ # Create coalesced sparse tensor with non-contiguous indices
weight = torch.randn(hidden_size, input_size, dtype=dtype, device=device).to_sparse()
self.assertTrue(weight.is_coalesced())
+ non_contig_indices = weight.indices().transpose(-1, -2).contiguous().transpose(-1, -2)
+ weight = torch.sparse_coo_tensor(
+ indices=non_contig_indices, values=weight.values(), size=weight.shape)
+ weight._coalesced_(True)
self.assertFalse(weight._indices().is_contiguous())
# Create un/coalesced sparse tensor
bias = torch.randn((hidden_size, 1), dtype=dtype, device=device).to_sparse()
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
index 29f62c2..929ba2a 100644
--- a/test/test_unary_ufuncs.py
+++ b/test/test_unary_ufuncs.py
@@ -1438,6 +1438,34 @@
self.assertEqual(1, len(z))
self.assertEqual(torch.empty(0, dtype=torch.long), z[0])
+ @dtypes(*torch.testing.get_all_dtypes())
+ def test_nonzero_noncontiguous(self, device, dtype):
+ x = make_tensor((10, 10, 10), dtype=dtype, device=device,
+ low=1, noncontiguous=False)
+ mask = make_tensor((10, 10, 10), dtype=torch.bool, device=device)
+ x[mask] = 0
+
+ def permute_storage(tensor, dims):
+ dest_dims = tuple(range(len(dims)))
+ return tensor.permute(dims).contiguous().movedim(dims, dest_dims)
+
+ # Assume contiguous case is correct
+ expect = x.nonzero()
+
+ # Dense, permuted
+ self.assertEqual(permute_storage(x, [0, 2, 1]).nonzero(), expect)
+ self.assertEqual(permute_storage(x, [2, 1, 0]).nonzero(), expect)
+
+ # Non-dense
+ nondense = torch.empty((40, 10, 20), dtype=dtype, device=device)[::4, :, ::2]
+ nondense[:] = x
+ self.assertEqual(nondense.nonzero(), expect)
+
+ # Non-dense, permuted
+ nondense = nondense.permute([0, 2, 1])
+ nondense[:] = x
+ self.assertEqual(nondense.nonzero(), expect)
+
# TODO: rationalize with exp OpInfo
@dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False) +
torch.testing.get_all_complex_dtypes()))
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 5168d21..6d6b094 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -1031,9 +1031,7 @@
"aten/src/TH/THLapack.cpp",
"aten/src/TH/THStorageFunctions.cpp",
"aten/src/TH/THTensor.cpp",
- "aten/src/TH/THTensorEvenMoreMath.cpp",
"aten/src/TH/THTensorLapack.cpp",
- "aten/src/TH/THTensorMath.cpp",
"aten/src/TH/THTensorMoreMath.cpp",
"aten/src/ATen/native/utils/Factory.cpp",
"aten/src/ATen/native/xnnpack/Activation.cpp",