blob: 644ebc9842032e1fe0b11d4d2a1c5fcbb010fe11 [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstring>
#include <executorch/kernels/portable/cpu/util/repeat_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
namespace torch {
namespace executor {
namespace native {
namespace {
bool calculate_output_size(
const exec_aten::ArrayRef<exec_aten::SizesType>& self_sizes,
const exec_aten::ArrayRef<int64_t>& repeats,
Tensor::SizesType* out_sizes_ptr) {
ET_LOG_AND_RETURN_IF_FALSE(repeats.size() < kTensorDimensionLimit);
ET_LOG_MSG_AND_RETURN_IF_FALSE(
repeats.size() >= self_sizes.size(),
"Repeats vector size is %zu must be >= self_sizes %zu.",
repeats.size(),
self_sizes.size());
int32_t i = 0;
for (; i < (repeats.size() - self_sizes.size()); ++i) {
out_sizes_ptr[i] = static_cast<exec_aten::SizesType>(repeats[i]);
}
int32_t j = 0;
for (; i < repeats.size(); ++i) {
out_sizes_ptr[i] =
static_cast<exec_aten::SizesType>(repeats[i]) * self_sizes[j];
j++;
}
return true;
}
} // namespace
using Tensor = exec_aten::Tensor;
// repeat.out(Tensor self, int[] repeats, *, Tensor(a!) out) -> Tensor(a!)
Tensor& repeat_out(
RuntimeContext& ctx,
const Tensor& self,
exec_aten::ArrayRef<int64_t> repeats,
Tensor& out) {
(void)ctx;
Tensor::SizesType expected_output_size[kTensorDimensionLimit];
ET_KERNEL_CHECK(
ctx,
calculate_output_size(self.sizes(), repeats, expected_output_size),
InvalidArgument,
out);
// Resize for dynamic shape
ET_KERNEL_CHECK_MSG(
ctx,
resize_tensor(out, {expected_output_size, repeats.size()}) == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
ET_KERNEL_CHECK(
ctx,
repeat_tensor(self, repeats, out) == Error::Ok,
InvalidArgument,
out);
return out;
}
} // namespace native
} // namespace executor
} // namespace torch