blob: c88585f88a5a6b812c522d280bd27003c4b101b3 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
namespace torch {
namespace executor {
namespace native {
namespace {
bool check_flip_args(const Tensor& in, IntArrayRef dims, const Tensor& out) {
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
return check_dim_list_is_valid(in, dims);
}
size_t unflip_flat_ix(size_t ix, const Tensor& in, ArrayRef<bool> flip_dim) {
size_t ix_coord[kTensorDimensionLimit];
indexToCoordinate(in, ix, ix_coord);
size_t unflip_coord[kTensorDimensionLimit];
for (size_t d = 0; d < in.dim(); d++) {
if (flip_dim[d]) {
unflip_coord[d] = in.size(d) - ix_coord[d] - 1;
} else {
unflip_coord[d] = ix_coord[d];
}
}
return coordinateToIndex(in, unflip_coord);
}
} // namespace
Tensor&
flip_out(RuntimeContext& ctx, const Tensor& in, IntArrayRef dims, Tensor& out) {
(void)ctx;
ET_KERNEL_CHECK(
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
ET_KERNEL_CHECK(ctx, check_flip_args(in, dims, out), InvalidArgument, out);
bool flip_dim_data[kTensorDimensionLimit];
for (size_t i = 0; i < in.dim(); i++) {
flip_dim_data[i] = false;
}
for (size_t i = 0; i < dims.size(); i++) {
const auto d = dims[i] < 0 ? dims[i] + nonzero_dim(in) : dims[i];
flip_dim_data[d] = true;
}
size_t flip_dim_length = static_cast<size_t>(in.dim()); // NOLINT
ArrayRef<bool> flip_dim(flip_dim_data, flip_dim_length);
constexpr auto name = "flip.out";
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
for (size_t ix = 0; ix < out.numel(); ++ix) {
out_data[ix] = in_data[unflip_flat_ix(ix, in, flip_dim)];
}
});
return out;
}
} // namespace native
} // namespace executor
} // namespace torch