/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/kernels/prim_ops/et_view.h>
#include <cstring>
#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/assert.h>
using exec_aten::SizesType;
using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::resize_tensor;
namespace torch {
namespace executor {
namespace function {
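// Maximum tensor rank supported here; bounds the stack-allocated buffer used
// to hold the computed output sizes.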
constexpr size_t kTensorDimensionLimit = 16;
namespace {
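// Computes the target sizes for viewing `self` with the requested `size`
// list, writing `dim` entries into `out_size`. At most one entry of `size`
// may be -1; that dimension is inferred so the view covers exactly
// self.numel() elements. Returns false (after logging) if the request is
// invalid.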
bool get_view_target_size(
    const exec_aten::Tensor self,
    exec_aten::ArrayRef<int64_t> size,
    int64_t dim,
    exec_aten::SizesType* out_size) {
  ET_LOG_AND_RETURN_IF_FALSE(size.size() == dim);
  int minus1_dim = -1;
  int n_zero = 0;
  int64_t numel_without_minus_1 = 1;
  for (int i = 0; i < dim; i++) {
    if (size[i] == -1) {
      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          minus1_dim == -1, "At most one view dim can be -1.");
      minus1_dim = i;
    } else {
      // size[i] should be non-negative at this point (it is not -1), but
      // check size[i] >= -1 so the guard stays valid if this code is
      // reordered in the future.
      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          size[i] >= -1, "Negative sizes are not allowed.");
      numel_without_minus_1 *= size[i];
      out_size[i] = static_cast<exec_aten::SizesType>(size[i]);
      if (size[i] == 0) {
        n_zero++;
      }
    }
  }
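
  // If one dimension was requested as -1, infer it so that the total element
  // count matches self.numel(). Inference is rejected when another dimension
  // is 0, since the division below would be ill-defined.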
  if (minus1_dim >= 0) {
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        n_zero == 0, "Cannot infer dimension size if there is a zero dim.");
    out_size[minus1_dim] = self.numel() / numel_without_minus_1;
  }
  return true;
}

} // namespace
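
// Implements the view prim op: resizes `out` to the requested size and points
// it at `self`'s data, so the result aliases `self` rather than copying it.
// Expects stack[0] = self (Tensor), stack[1] = size (int list), and
// stack[2] = out (Tensor).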
void et_view(RuntimeContext& context, EValue** stack) {
  (void)context;

  auto self = (*stack[0]).toTensor();
  auto size = (*stack[1]).toIntList();
  auto out = (*stack[2]).toTensor();

  ET_CHECK(tensors_have_same_dtype(self, out));

  // Compute the expected output size, inferring a -1 dimension if present.
  SizesType expected_output_size[kTensorDimensionLimit];
  ET_CHECK(get_view_target_size(self, size, out.dim(), expected_output_size));

  // Resize `out` to the computed size (supports dynamic shapes).
  ET_CHECK_MSG(
      resize_tensor(
          out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
          Error::Ok,
      "Failed to resize output tensor.");

  // Sanity check: the view must cover exactly the same number of elements.
  ET_CHECK(self.numel() == out.numel());

  // Point `out` at self's storage so the view aliases the data instead of
  // copying it.
  ET_CHECK_MSG(
      internal::set_tensor_data(
          out,
          /*buffer=*/self.mutable_data_ptr(),
          /*buffer_size=*/out.nbytes()) == Error::Ok,
      "Failed to set data_ptr for out to self.");
}

} // namespace function
} // namespace executor
} // namespace torch