// blob: 6cf7fd81da6ad318f8b1f7e57d77c884722bc153 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cmath>
#include <cstring>
#include <executorch/kernels/portable/cpu/util/index_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
namespace torch {
namespace executor {
namespace native {
// Shorthand aliases for the exec_aten types referenced in this file.
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using SizesType = exec_aten::SizesType;
namespace {
/**
 * Advances a multi-dimensional index by one position, with the last dimension
 * varying fastest.
 *
 * Walking from the last dimension towards the first, the entry is bumped by
 * one. If it has not reached the extent of that dimension we are done;
 * otherwise it wraps back to zero and the carry moves on to the preceding
 * dimension. When every dimension wraps, the index ends up all zeros again.
 *
 * @param index Array holding at least sizes.size() entries; updated in place.
 * @param sizes Extent of each dimension.
 */
void increment_index(size_t* index, const ArrayRef<SizesType> sizes) {
  for (size_t dim = sizes.size(); dim-- > 0;) {
    if (++index[dim] != static_cast<size_t>(sizes[dim])) {
      // No carry needed: the index is now valid.
      return;
    }
    // This dimension overflowed; reset it and carry into the next one.
    index[dim] = 0;
  }
}
/**
 * Writes the coordinates of every non-zero element of `input` into `output`.
 *
 * The work is done in two passes over the input buffer:
 *   1. count the non-zero elements;
 *   2. resize `output` to {count, input.dim()} and walk the buffer again,
 *      emitting one row of coordinates per non-zero element.
 *
 * Coordinates are tracked alongside the flat position with increment_index(),
 * i.e. the flat buffer is interpreted with the last dimension varying fastest.
 *
 * If the resize is rejected, the ET_KERNEL_CHECK below fires with
 * InvalidArgument and nothing is written (the macro is expected to record the
 * failure on ctx and return early).
 *
 * @tparam CTYPE Element type of `input`.
 * @param ctx Kernel runtime context handed to ET_KERNEL_CHECK.
 * @param input Tensor to scan.
 * @param output Tensor receiving the indices; written through an int64_t
 *     pointer.
 */
template <typename CTYPE>
void nonzero(RuntimeContext& ctx, const Tensor& input, Tensor& output) {
  const CTYPE* in_data = input.const_data_ptr<CTYPE>();
  const size_t total = input.numel();

  // Pass 1: tally the non-zero elements so the output can be sized.
  int32_t hits = 0;
  for (size_t pos = 0; pos < total; ++pos) {
    hits += (in_data[pos] != 0) ? 1 : 0;
  }

  // One output row per hit, one column per input dimension.
  SizesType target_shape[2] = {
      static_cast<SizesType>(hits), static_cast<SizesType>(input.dim())};
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(output, ArrayRef<exec_aten::SizesType>(target_shape, 2)) ==
          Error::Ok,
      InvalidArgument, );

  // Coordinates of the element currently being visited; starts at the origin.
  size_t coord[kTensorDimensionLimit] = {};

  const size_t rank = static_cast<size_t>(input.dim());
  const auto extents = input.sizes();
  int64_t* cursor = output.mutable_data_ptr<int64_t>();

  // Pass 2: revisit every element, keeping `coord` in lock-step with `pos`
  // (the increment clause advances both), and append the coordinates of each
  // non-zero element to the output.
  for (size_t pos = 0; pos < total; ++pos, increment_index(coord, extents)) {
    if (in_data[pos] != 0) {
      for (size_t d = 0; d < rank; ++d) {
        *cursor++ = coord[d];
      }
    }
  }
}
} // namespace
/**
 * Computes the indices of the non-zero elements of `in`.
 *
 * `out` ends up as a 2-D tensor in which each row holds the index of one
 * non-zero element of the input.
 *
 * @param ctx Kernel runtime context used for argument checks and dtype
 *     dispatch.
 * @param in Tensor to inspect; dispatch covers the real dtypes plus Bool.
 * @param out Tensor that receives the indices.
 * @return `out`, also when the argument check fails.
 */
Tensor& nonzero_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
  // Validate the in/out pair up front; on failure `out` is returned as-is.
  ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out);

  const ScalarType in_dtype = in.scalar_type();

  // Instantiate the typed implementation that matches the input dtype.
  ET_SWITCH_REAL_TYPES_AND(Bool, in_dtype, ctx, "nonzero.out", CTYPE, [&] {
    nonzero<CTYPE>(ctx, in, out);
  });

  return out;
}
} // namespace native
} // namespace executor
} // namespace torch