blob: f799a968cf0eb5b67d814e39aab072368e53054b [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/extension/runner_util/inputs.h>
#include <algorithm>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/method_meta.h>
#include <executorch/runtime/platform/log.h>
namespace torch {
namespace executor {
namespace util {
namespace internal {
namespace {
/**
 * Sets every element of `tensor` to one.
 *
 * The element types handled are those enumerated by
 * ET_FORALL_REAL_TYPES_AND(Bool, ...). Each element is assigned the value 1
 * converted to the tensor's element type.
 *
 * @param[in] tensor Tensor whose data buffer is overwritten in place.
 *
 * @retval Error::Ok Every element was written (a tensor with no elements is
 *     left untouched and also reports success).
 * @retval Error::InvalidArgument The tensor's scalar type is not one of the
 *     handled types, or the tensor has elements but a null data pointer.
 */
Error fill_ones(torch::executor::Tensor tensor) {
// Expands to one `case` per handled dtype. The null check is only enforced
// when there is something to write, so an empty tensor with no buffer still
// succeeds instead of being rejected. No comments live inside the macro body
// because a `//` comment would swallow the line continuation.
#define FILL_CASE(T, n)                                              \
  case (torch::executor::ScalarType::n): {                           \
    T* const data = tensor.mutable_data_ptr<T>();                    \
    const auto count = tensor.numel();                               \
    if (data == nullptr && count > 0) {                              \
      ET_LOG(Error, "Cannot fill tensor: data pointer is null");     \
      return Error::InvalidArgument;                                 \
    }                                                                \
    std::fill(data, data + count, static_cast<T>(1));                \
    break;                                                           \
  }

  switch (tensor.scalar_type()) {
    ET_FORALL_REAL_TYPES_AND(Bool, FILL_CASE)
    default:
      ET_LOG(
          Error,
          "Unsupported scalar type %d",
          static_cast<int>(tensor.scalar_type()));
      return Error::InvalidArgument;
  }

#undef FILL_CASE

  return Error::Ok;
}
} // namespace
/**
 * Fills the buffer at `data_ptr` with ones and hands it to `method` as the
 * input at `input_index`.
 *
 * A temporary tensor is described using the scalar type, sizes and dim order
 * reported by `tensor_meta`, with `data_ptr` as its storage.
 *
 * @param[in] method Method whose input is being set.
 * @param[in] tensor_meta Describes the dtype, sizes and dim order of the input.
 * @param[in] input_index Index of the input to set on `method`.
 * @param[in] data_ptr Storage that backs the temporary tensor and is filled
 *     with ones. NOTE(review): presumably the caller must size this buffer to
 *     hold the whole tensor; nothing here checks it.
 *
 * @returns The error from fill_ones() if filling fails, otherwise whatever
 *     Method::set_input() returns.
 */
Error fill_and_set_input(
    Method& method,
    TensorInfo& tensor_meta,
    size_t input_index,
    void* data_ptr) {
  // Casting away const is acceptable here: the TensorImpl below is short-lived
  // and never resized, so the arrays behind these pointers are not modified.
  // It exists only so set_input() can verify that the shape is correct; the
  // Method manages its own sizes and dim_order arrays for the input.
  auto* const sizes_ptr =
      const_cast<TensorImpl::SizesType*>(tensor_meta.sizes().data());
  auto* const dim_order_ptr =
      const_cast<TensorImpl::DimOrderType*>(tensor_meta.dim_order().data());

  TensorImpl scratch_impl(
      tensor_meta.scalar_type(),
      /*dim=*/tensor_meta.sizes().size(),
      sizes_ptr,
      data_ptr,
      dim_order_ptr);
  Tensor input(&scratch_impl);

  // Bail out before touching the method if the buffer could not be filled.
  ET_CHECK_OK_OR_RETURN_ERROR(fill_ones(input));
  return method.set_input(input, input_index);
}
} // namespace internal
} // namespace util
} // namespace executor
} // namespace torch