/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cinttypes>
#include <cstdint>
#include <cstring>
#include <tuple>

#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;
using TensorOptList = exec_aten::ArrayRef<exec_aten::optional<Tensor>>;
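
// index.Tensor_out: out = in[indices...], i.e. ATen advanced (tensor)
// indexing. A null entry in `indices` acts as a full slice (`:`) over its
// dim. For example, indexing a (3, 4) input with indices = {tensor([0, 2])}
// yields a (2, 4) output holding rows 0 and 2 of the input.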
Tensor& index_Tensor_out(
    RuntimeContext& ctx,
    const Tensor& in,
    TensorOptList indices,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, check_index_args(in, indices, out), InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
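
  // A "block" is a contiguous run of non-null entries in the indices list; the
  // count drives both the all-null early exit below and the adjacency decision
  // that determines the output shape.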
  size_t block_count = count_index_blocks(indices);

  // If the indices list is empty or all indices are null, just copy the input
  // to the output and return early.
  if (block_count == 0) {
    ET_KERNEL_CHECK(
        ctx,
        resize_tensor(out, in.sizes()) == Error::Ok,
        InvalidArgument,
        out);
    ET_SWITCH_REAL_TYPES_AND(
        Bool, in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
          const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
          CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
          memcpy(out_data, in_data, in.nbytes());
        });
    return out;
  }

  // The output shape depends on whether all the non-null indices are adjacent
  // or not.
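  // E.g. with non-null indices at dims 1 and 2 (adjacent), the dims they
  // produce stay in place in the output; with non-null indices at dims 0 and 2
  // (non-adjacent), those dims are moved to the front, following ATen's
  // advanced-indexing semantics.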
  bool adjacent = (block_count == 1);

  Tensor::SizesType expected_size[kTensorDimensionLimit];
  size_t expected_ndim = 0;

  ET_KERNEL_CHECK(
      ctx,
      get_index_out_target_size(
          in, indices, adjacent, expected_size, &expected_ndim),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
      InvalidArgument,
      out);
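
  // Nothing to copy if the resized output is empty.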
  if (out.numel() == 0) {
    return out;
  }
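
  // Scratch state consumed by get_in_ix when translating each output element
  // back to its source: dim_map and ix_map relate dims of the output/input to
  // entries of the indices list, start counts leading null indices (relevant
  // only in the adjacent case), and xdim is the number of dims in the
  // broadcasted index shape.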
  int32_t dim_map[kTensorDimensionLimit];
  int32_t ix_map[kTensorDimensionLimit];
  size_t start = 0;
  size_t xdim = 0;

  if (adjacent) {
    start = get_num_leading_null_indices(indices);
  }
  xdim = get_indices_broadcast_ndim(indices);
  compute_dim_map(in, indices, dim_map, adjacent);
  compute_index_map(in, indices, ix_map);
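
  // Gather loop: for each output element, translate its flat index into the
  // corresponding flat index of the input and copy the value across.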
  ET_SWITCH_REAL_TYPES_AND(
      Bool, in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
        const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
        CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
        for (ssize_t out_ix = 0; out_ix < out.numel(); out_ix++) {
          size_t in_ix = 0;
          bool success = true;
          std::tie(in_ix, success) =
              get_in_ix(in, indices, out, out_ix, start, xdim, dim_map, ix_map);
          ET_KERNEL_CHECK(ctx, success, InvalidArgument, );
          out_data[out_ix] = in_data[in_ix];
        }
      });

  return out;
}
} // namespace native
} // namespace executor
} // namespace torch