Add op: masked_select
Differential Revision: D65497030
Pull Request resolved: https://github.com/pytorch/executorch/pull/6670
diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index db59af1..f2f18f5 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -243,6 +243,8 @@
- op: masked_scatter.out
+- op: masked_select.out
+
- op: max_pool2d_with_indices.out
- op: max.dim_max
diff --git a/kernels/portable/cpu/op_masked_select.cpp b/kernels/portable/cpu/op_masked_select.cpp
new file mode 100644
index 0000000..b176000
--- /dev/null
+++ b/kernels/portable/cpu/op_masked_select.cpp
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& masked_select_out(
+ KernelRuntimeContext& ctx,
+ const Tensor& in,
+ const Tensor& mask,
+ Tensor& out) {
+ ScalarType in_type = in.scalar_type();
+
+ ET_KERNEL_CHECK(
+ ctx,
+ executorch::runtime::tensor_is_realhbbf16_type(in),
+ InvalidArgument,
+ out);
+
+ ET_KERNEL_CHECK(
+ ctx, mask.scalar_type() == ScalarType::Bool, InvalidArgument, out);
+ ET_KERNEL_CHECK(ctx, out.scalar_type() == in_type, InvalidArgument, out);
+
+ ET_KERNEL_CHECK(
+ ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out);
+
+ ET_KERNEL_CHECK(
+ ctx, tensors_are_broadcastable_between(in, mask), InvalidArgument, out);
+
+ // If input or mask is empty, the output should be empty
+ if (in.numel() == 0 || mask.numel() == 0) {
+ ET_KERNEL_CHECK(
+ ctx, resize_tensor(out, {0}) == Error::Ok, InvalidArgument, out);
+ return out;
+ }
+
+ // Compute the shape resulting from broadcasting the mask against the input
+ size_t broadcast_ndim = 0;
+ Tensor::SizesType broadcast_sizes[kTensorDimensionLimit];
+ Error err = get_broadcast_target_size(
+ in, mask, broadcast_sizes, kTensorDimensionLimit, &broadcast_ndim);
+ if (err != Error::Ok) {
+ ET_KERNEL_CHECK_MSG(
+ ctx, false, InvalidArgument, out, "Failed to broadcast input and mask");
+ }
+ size_t broadcast_numel = 1;
+ for (size_t i = 0; i < broadcast_ndim; i++) {
+ broadcast_numel *= broadcast_sizes[i];
+ }
+
+ // Compute the number of out elements
+ size_t mask_true_count = 0;
+ const bool* const mask_data = mask.const_data_ptr<bool>();
+ for (size_t i = 0; i < mask.numel(); ++i) {
+ if (mask_data[i]) {
+ mask_true_count++;
+ }
+ }
+ Tensor::SizesType out_numel =
+ mask_true_count * (broadcast_numel / mask.numel());
+
+ // Resize the out tensor
+ ET_KERNEL_CHECK(
+ ctx, resize_tensor(out, {out_numel}) == Error::Ok, InvalidArgument, out);
+
+ const char* const in_data =
+ reinterpret_cast<const char*>(in.const_data_ptr());
+ char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
+ const auto elem_size = in.element_size();
+
+ // Figure out if `in` is broadcasted
+ bool in_is_broadcasted = false;
+ if (in.dim() != broadcast_ndim) {
+ in_is_broadcasted = true;
+ } else {
+ for (size_t i = 0; i < in.dim(); ++i) {
+ if (in.size(i) != broadcast_sizes[i]) {
+ in_is_broadcasted = true;
+ }
+ }
+ }
+
+ // Figure out if `mask` is broadcasted
+ bool mask_is_broadcasted = false;
+ if (mask.dim() != broadcast_ndim) {
+ mask_is_broadcasted = true;
+ } else {
+ for (size_t i = 0; i < mask.dim(); ++i) {
+ if (mask.size(i) != broadcast_sizes[i]) {
+ mask_is_broadcasted = true;
+ }
+ }
+ }
+
+ // Figure out if either `in` or `mask` is broadcasted
+ bool any_is_broadcasted = (in_is_broadcasted || mask_is_broadcasted);
+
+ size_t out_ix = 0;
+ for (size_t i = 0; i < broadcast_numel; ++i) {
+ size_t in_linear_index = i;
+ size_t mask_linear_index = i;
+
+ // If either `in` or `mask` is broadcasted, we need to compute the indexes
+ // in the broadcasted space.
+ if (any_is_broadcasted) {
+ size_t broadcast_indexes[kTensorDimensionLimit];
+ delinearize_index(
+ i,
+ {broadcast_sizes, broadcast_ndim},
+ broadcast_indexes,
+ kTensorDimensionLimit);
+
+ if (in_is_broadcasted) {
+ in_linear_index =
+ linearize_access_indexes(broadcast_indexes, broadcast_ndim, in);
+ }
+ if (mask_is_broadcasted) {
+ mask_linear_index =
+ linearize_access_indexes(broadcast_indexes, broadcast_ndim, mask);
+ }
+ }
+
+ // If the mask is true, copy the value from `in` to `out` and increment the
+ // `out_ix`
+ if (mask_data[mask_linear_index]) {
+ memcpy(
+ out_data + out_ix * elem_size,
+ in_data + in_linear_index * elem_size,
+ elem_size);
+ out_ix++;
+ }
+ }
+
+ return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index d1eb8b8..a5d60eb 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -547,6 +547,11 @@
- arg_meta: null
kernel_name: torch::executor::masked_scatter_out
+- op: masked_select.out
+ kernels:
+ - arg_meta: null
+ kernel_name: torch::executor::masked_select_out
+
- op: max.dim_max
kernels:
- arg_meta: null
diff --git a/kernels/test/op_masked_select_test.cpp b/kernels/test/op_masked_select_test.cpp
new file mode 100644
index 0000000..2a7791e
--- /dev/null
+++ b/kernels/test/op_masked_select_test.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::testing::SupportedFeatures;
+using torch::executor::testing::TensorFactory;
+
+class OpMaskedSelectOutTest : public OperatorTest {
+ protected:
+ Tensor&
+ op_masked_select_out(const Tensor& in, const Tensor& mask, Tensor& out) {
+ return torch::executor::aten::masked_select_outf(context_, in, mask, out);
+ }
+};
+
+TEST_F(OpMaskedSelectOutTest, SmokeTest) {
+ TensorFactory<ScalarType::Int> tf;
+ TensorFactory<ScalarType::Bool> tfBool;
+
+ Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+ Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+ Tensor out = tf.zeros({3});
+
+ op_masked_select_out(in, mask, out);
+ EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 4, 6}));
+}
+
+TEST_F(OpMaskedSelectOutTest, BroadcastInput) {
+ TensorFactory<ScalarType::Int> tf;
+ TensorFactory<ScalarType::Bool> tfBool;
+
+ Tensor in = tf.make({3}, {1, 2, 3});
+ Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+ Tensor out = tf.zeros({3});
+
+ op_masked_select_out(in, mask, out);
+ EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 1, 3}));
+}
+
+TEST_F(OpMaskedSelectOutTest, BroadcastMask) {
+ TensorFactory<ScalarType::Int> tf;
+ TensorFactory<ScalarType::Bool> tfBool;
+
+ Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+ Tensor mask = tfBool.make({3}, {false, true, false});
+
+ Tensor out = tf.zeros({2});
+
+ op_masked_select_out(in, mask, out);
+ EXPECT_TENSOR_EQ(out, tf.make({2}, {2, 5}));
+}
+
+TEST_F(OpMaskedSelectOutTest, BroadcastInputAndMask) {
+ TensorFactory<ScalarType::Int> tf;
+ TensorFactory<ScalarType::Bool> tfBool;
+
+ Tensor in = tf.ones({2, 3, 4, 1});
+ Tensor mask = tfBool.ones({2, 1, 1, 5});
+ Tensor out = tf.zeros({120});
+
+ op_masked_select_out(in, mask, out);
+ EXPECT_TENSOR_EQ(out, tf.ones({120}));
+}
+
+TEST_F(OpMaskedSelectOutTest, EmptyInput) {
+ TensorFactory<ScalarType::Int> tf;
+ TensorFactory<ScalarType::Bool> tfBool;
+
+ Tensor in = tf.make({2, 0}, {});
+ Tensor mask = tfBool.make({2, 1}, {true, true});
+ Tensor out = tf.zeros({0});
+
+ op_masked_select_out(in, mask, out);
+ EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
+}
+
+TEST_F(OpMaskedSelectOutTest, EmptyMask) {
+ TensorFactory<ScalarType::Int> tf;
+ TensorFactory<ScalarType::Bool> tfBool;
+
+ Tensor in = tf.make({2, 1}, {100, 200});
+ Tensor mask = tfBool.make({2, 0}, {});
+ Tensor out = tf.zeros({0});
+
+ op_masked_select_out(in, mask, out);
+ EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
+}
+
+TEST_F(OpMaskedSelectOutTest, EmptyInputAndMask) {
+ TensorFactory<ScalarType::Int> tf;
+ TensorFactory<ScalarType::Bool> tfBool;
+
+ Tensor in = tf.make({2, 0}, {});
+ Tensor mask = tfBool.make({0}, {});
+ Tensor out = tf.zeros({0});
+
+ op_masked_select_out(in, mask, out);
+ EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 91b3ba8..ce15a57 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -255,6 +255,7 @@
_common_op_test("op_lt_test", ["aten", "portable"])
_common_op_test("op_masked_fill_test", ["aten", "portable"])
_common_op_test("op_masked_scatter_test", ["aten", "portable"])
+ _common_op_test("op_masked_select_test", ["aten", "portable"])
_common_op_test("op_max_test", ["aten", "portable"])
_common_op_test("op_max_pool2d_with_indices_test", ["aten", "portable"])
_common_op_test("op_maximum_test", ["aten", "portable"])
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index f63932d..ab8fc63 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -790,6 +790,12 @@
],
),
op_target(
+ name = "op_masked_select",
+ deps = [
+ "//executorch/kernels/portable/cpu/util:broadcast_util",
+ ],
+ ),
+ op_target(
name = "op_max",
deps = [
"//executorch/kernels/portable/cpu/util:reduce_util",