blob: b1423636969221fc1bc80255ec11b2547a0c82e4 [file] [log] [blame] [edit]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <ATen/cpu/vec/vec.h>
namespace executorch::vec {
// This function implements a broadcasting binary operation on two tensors,
// where the lhs tensor is treated as having shape [outer_size, broadcast_size, inner_size]
// and the rhs tensor is treated as having shape [outer_size, 1, inner_size].
// The middle dimension (index 1) is the broadcasting dimension.
// This formula can map broadcasting along any dim = broadcast_dim
// for any two N-dimensional tensors, where 0 < broadcast_dim < N-1.
template <typename scalar_t, typename Op>
inline void broadcasting_map_3d_and_unsqueezed_3d(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* lhs,
const scalar_t* rhs,
int64_t outer_size,
int64_t broadcast_size,
int64_t inner_size) {
using Vec = at::vec::Vectorized<scalar_t>;
int64_t outer_stride_lhs = inner_size * broadcast_size;
int64_t outer_stride_rhs = inner_size;
int64_t broadcast_stride_lhs = inner_size;
for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
const scalar_t* rhs_outer = rhs + outer_idx * outer_stride_rhs;
for (int64_t broadcast_idx = 0; broadcast_idx < broadcast_size; ++broadcast_idx) {
const scalar_t* lhs_outer_2 = lhs_outer + broadcast_idx * broadcast_stride_lhs;
scalar_t* output_data_row_2 = output_data_row + broadcast_idx * broadcast_stride_lhs;
int64_t inner_idx = 0;
for (; inner_idx < inner_size - (inner_size % Vec::size()); inner_idx += Vec::size()) {
Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx);
Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row_2 + inner_idx);
}
if (inner_size - inner_idx > 0) {
Vec data_vec = Vec::loadu(lhs_outer_2 + inner_idx, inner_size - inner_idx);
Vec data_vec2 = Vec::loadu(rhs_outer + inner_idx, inner_size - inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row_2 + inner_idx, inner_size - inner_idx);
}
}
}
}
template <typename scalar_t, typename Op>
// Combines every row of a contiguous [size, size2] `input_data` with the
// length-`size2` vector `input_data2`, writing vec_fun(row, input_data2) into
// the matching row of `output_data`.
// This is the 3-D helper applied to a single outer slice: lhs is viewed as
// [1, size, size2] and rhs as [1, 1, size2].
inline void broadcasting_map_2d_by_1d(
    const Op& vec_fun,
    scalar_t* output_data,
    const scalar_t* input_data,
    const scalar_t* input_data2,
    int64_t size,
    int64_t size2) {
  constexpr int64_t kSingleOuterSlice = 1;
  const int64_t rows = size;
  const int64_t cols = size2;
  broadcasting_map_3d_and_unsqueezed_3d(
      vec_fun,
      output_data,
      input_data,
      input_data2,
      kSingleOuterSlice,
      rows,
      cols);
}
/*
The following function implements a broadcasting binary operation on two tensors,
where the lhs tensor is treated as having shape [outer_size, broadcast_size] and
the rhs tensor is treated as having shape [outer_size, 1].
Any two N-dimensional tensors can be mapped to this formula
when lhs size = [lhs0, lhs1, ..., lhsN-1] and rhs size = [rhs0, rhs1, ..., 1],
by viewing the two tensors as
lhs size = [lhs0 * lhs1 * ... * lhsN-2, lhsN-1]
rhs size = [rhs0 * rhs1 * ... * rhsN-2, 1]
*/
template <typename scalar_t, typename Op>
inline void broadcasting_map_broadcast_last_dim(
const Op& vec_fun,
scalar_t* output_data,
const scalar_t* lhs,
const scalar_t* rhs,
int64_t outer_size,
int64_t broadcast_size) {
using Vec = at::vec::Vectorized<scalar_t>;
int64_t outer_stride_lhs = broadcast_size;
for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
const scalar_t* lhs_outer = lhs + outer_idx * outer_stride_lhs;
scalar_t* output_data_row = output_data + outer_idx * outer_stride_lhs;
int64_t inner_idx = 0;
Vec data_vec2 = Vec(rhs[outer_idx]);
for (; inner_idx < broadcast_size - (broadcast_size % Vec::size()); inner_idx += Vec::size()) {
Vec data_vec = Vec::loadu(lhs_outer + inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx);
}
if (broadcast_size - inner_idx > 0) {
Vec data_vec = Vec::loadu(lhs_outer + inner_idx, broadcast_size - inner_idx);
Vec output_vec = vec_fun(data_vec, data_vec2);
output_vec.store(output_data_row + inner_idx, broadcast_size - inner_idx);
}
}
}
} // namespace executorch::vec