blob: 8abf6f97222bf67a601cd4013a7ad662d7e1209b [file] [log] [blame]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cstring>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
namespace torch {
namespace executor {
namespace native {
using Tensor = exec_aten::Tensor;
// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
// TODO: With proper functionalization this op should never be emitted, so it
// should eventually be deleted.
/**
 * Out-variant of copy: writes `src` into `out`.
 *
 * `out` is resized to `in.sizes()` and must have the same dtype as `in`.
 * `src` must be broadcastable to `in`. Each element of `out` is the
 * corresponding (broadcast) element of `src`, converted from src's dtype to
 * in's dtype. The element values of `in` are never read; only its shape and
 * dtype determine the result.
 *
 * Dtypes accepted for both `in` and `src` are the real types plus Bool, per
 * the ET_SWITCH_REAL_TYPES_AND dispatch below.
 *
 * @param ctx          Kernel runtime context; each failed check reports to it.
 * @param in           Supplies the output shape and dtype.
 * @param src          Source values, broadcast to `in`'s shape.
 * @param non_blocking Must be false; only blocking copies are supported.
 * @param out          Destination tensor (resized as needed).
 * @return `out`. Each ET_KERNEL_CHECK returns `out` early if its condition
 *         does not hold.
 */
Tensor& copy_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const Tensor& src,
    bool non_blocking,
    Tensor& out) {
  // Right now we only support blocking data transfer.
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  ScalarType in_type = in.scalar_type();
  ScalarType src_type = src.scalar_type();

  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "copy.out", CTYPE, [&]() {
    ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, "copy.out", CTYPE_SRC, [&]() {
      // `in` is passed only so the helper broadcasts `src` to its shape; the
      // value from `in` is deliberately ignored (parameter left unnamed to
      // avoid -Wunused-parameter).
      apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
          [](const CTYPE /*val_in*/, const CTYPE_SRC val_src) {
            return convert<CTYPE, CTYPE_SRC>(val_src);
          },
          in,
          src,
          out);
    });
  });

  return out;
}
/**
 * In-place copy: overwrites every element of `in` with `src`.
 *
 * `src` must be broadcastable to `in`. Each element of `in` is replaced by the
 * corresponding (broadcast) element of `src`, converted to in's dtype. `in` is
 * not resized, so its shape and dtype are unchanged.
 *
 * Dtypes accepted for both tensors are the real types plus Bool, per the
 * ET_SWITCH_REAL_TYPES_AND dispatch below.
 *
 * @param ctx          Kernel runtime context; each failed check reports to it.
 * @param in           Destination tensor, modified in place.
 * @param src          Source values, broadcast to `in`'s shape.
 * @param non_blocking Must be false; only blocking copies are supported.
 * @return `in`. Each ET_KERNEL_CHECK returns `in` early if its condition does
 *         not hold.
 */
Tensor&
copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) {
  // Right now we only support blocking data transfer.
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);

  ScalarType in_type = in.scalar_type();
  ScalarType src_type = src.scalar_type();

  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "copy_", CTYPE, [&]() {
    ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, "copy_", CTYPE_SRC, [&]() {
      // `in` serves as both the shape-defining input and the output; its
      // current value is ignored (parameter left unnamed to avoid
      // -Wunused-parameter) and replaced with the converted `src` value.
      apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
          [](const CTYPE /*val_in*/, const CTYPE_SRC val_src) {
            return convert<CTYPE, CTYPE_SRC>(val_src);
          },
          in,
          src,
          in);
    });
  });

  return in;
}
} // namespace native
} // namespace executor
} // namespace torch