blob: e6600850cc8d19c1c9a7698ce0143f9639bd5885 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cstddef>
namespace torch {
namespace executor {
namespace native {
namespace {
bool check_roll_args(
const Tensor& in,
IntArrayRef shifts,
IntArrayRef dims,
const Tensor& out) {
for (const auto& d : dims) {
if (in.dim() == 0) {
ET_LOG_AND_RETURN_IF_FALSE(d == 0 || d == -1);
} else {
ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(d, in.dim()));
}
}
ET_LOG_AND_RETURN_IF_FALSE(!shifts.empty());
ET_LOG_AND_RETURN_IF_FALSE(shifts.size() == dims.size());
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
return true;
}
/**
 * Maps a flat index of the rolled output back to the flat index of the input
 * element that belongs there.
 *
 * The flat index is expanded into per-dimension coordinates, each coordinate
 * is moved back by that dimension's shift (wrapping around the dimension's
 * extent), and the result is flattened again.
 *
 * @param ix         Flat index into a tensor shaped like `in`.
 * @param in         Tensor whose sizes define the coordinate space.
 * @param dim_shifts Shift per dimension, indexed by dimension; entries may be
 *                   negative or larger in magnitude than the extent.
 * @return Flat index into `in` of the source element.
 *
 * NOTE: every extent of `in` must be non-zero here (it is used as a modulus);
 * the visible caller only invokes this for `ix < out.numel()`.
 */
size_t unshift_flat_ix(size_t ix, const Tensor& in, IntArrayRef dim_shifts) {
  size_t ix_coord[kTensorDimensionLimit];
  indexToCoordinate(in, ix, ix_coord);

  size_t shifted_coord[kTensorDimensionLimit];
  const size_t rank = static_cast<size_t>(in.dim());
  for (size_t d = 0; d < rank; d++) {
    // Do the wrap-around in signed arithmetic. Subtracting the shift from an
    // unsigned coordinate wraps modulo 2^64 whenever the shift exceeds the
    // coordinate (or is negative), and 2^64 is generally not a multiple of
    // the extent, so a following `%` would yield the wrong coordinate.
    const int64_t extent = static_cast<int64_t>(in.size(d));
    // Reduce the shift into (-extent, extent) first.
    const int64_t shift = dim_shifts[d] % extent;
    // Coordinate is in [0, extent), so the difference is in
    // (-extent, 2 * extent); one correction step brings it into range.
    int64_t coord = static_cast<int64_t>(ix_coord[d]) - shift;
    if (coord < 0) {
      coord += extent;
    } else if (coord >= extent) {
      coord -= extent;
    }
    shifted_coord[d] = static_cast<size_t>(coord);
  }

  return coordinateToIndex(in, shifted_coord);
}
} // namespace
/**
 * Rolls the elements of `in` along the given dimensions and writes the result
 * into `out`.
 *
 * `shifts[i]` is applied to dimension `dims[i]`; negative dims are counted
 * from the end, and several entries naming the same dimension are summed.
 * Each output element is then copied from the input element found by
 * `unshift_flat_ix`.
 *
 * @param ctx    Kernel runtime context; receives InvalidArgument on failure.
 * @param in     Input tensor.
 * @param shifts Shift amount for each entry of `dims`.
 * @param dims   Dimensions to roll along.
 * @param out    Output tensor; resized to the sizes of `in`.
 * @return `out`.
 */
Tensor& roll_out(
    RuntimeContext& ctx,
    const Tensor& in,
    IntArrayRef shifts,
    IntArrayRef dims,
    Tensor& out) {
  (void)ctx;

  // Validate first so that a rejected call does not resize `out`.
  ET_KERNEL_CHECK(
      ctx, check_roll_args(in, shifts, dims, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  constexpr auto name = "roll.out";

  const size_t rank = static_cast<size_t>(in.dim());

  // Accumulated shift per dimension. Zero-initialized so every slot holds a
  // defined value before the `+=` below, whatever the rank is.
  int64_t dim_shift_array[kTensorDimensionLimit] = {};

  // A zero-dimensional tensor has no axis to roll along. check_roll_args
  // still accepts dims 0 and -1 for it, and -1 would normalize to index -1
  // here (an out-of-bounds write), so skip accumulation entirely in that case.
  if (rank > 0) {
    for (size_t i = 0; i < dims.size(); i++) {
      // Map a negative dim to its non-negative equivalent.
      const auto d = dims[i] < 0 ? dims[i] + in.dim() : dims[i];
      dim_shift_array[d] += shifts[i];
    }
  }

  IntArrayRef dim_shifts(dim_shift_array, rank);

  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
    const CTYPE* in_data = in.const_data_ptr<CTYPE>();
    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

    const size_t total = static_cast<size_t>(out.numel());
    for (size_t ix = 0; ix < total; ++ix) {
      // Pull each output element from its pre-roll position in the input.
      out_data[ix] = in_data[unshift_flat_ix(ix, in, dim_shifts)];
    }
  });

  return out;
}
} // namespace native
} // namespace executor
} // namespace torch