Add op: masked_select

Differential Revision: D65497030

Pull Request resolved: https://github.com/pytorch/executorch/pull/6670
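
The portable kernel mirrors ATen semantics: it selects the elements of `in` at
the positions where `mask` is true, flattened in row-major order into a 1-D
`out`, and supports broadcasting between `in` and `mask`. As a minimal sketch
of the contract (using the TensorFactory helpers from the new test file below;
illustrative only):

    TensorFactory<ScalarType::Int> tf;
    TensorFactory<ScalarType::Bool> tfBool;

    Tensor in = tf.make({2, 2}, {1, 2, 3, 4});
    Tensor mask = tfBool.make({2, 2}, {true, false, false, true});
    Tensor out = tf.zeros({2});

    // ctx is a KernelRuntimeContext; see the kernel signature below.
    torch::executor::native::masked_select_out(ctx, in, mask, out);
    // out now holds {1, 4}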
diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index db59af1..f2f18f5 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -243,6 +243,8 @@
 
 - op: masked_scatter.out
 
+- op: masked_select.out
+
 - op: max_pool2d_with_indices.out
 
 - op: max.dim_max
diff --git a/kernels/portable/cpu/op_masked_select.cpp b/kernels/portable/cpu/op_masked_select.cpp
new file mode 100644
index 0000000..b176000
--- /dev/null
+++ b/kernels/portable/cpu/op_masked_select.cpp
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cstring>
+
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& masked_select_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& mask,
+    Tensor& out) {
+  ScalarType in_type = in.scalar_type();
+
+  ET_KERNEL_CHECK(
+      ctx,
+      executorch::runtime::tensor_is_realhbbf16_type(in),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, mask.scalar_type() == ScalarType::Bool, InvalidArgument, out);
+  ET_KERNEL_CHECK(ctx, out.scalar_type() == in_type, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_are_broadcastable_between(in, mask), InvalidArgument, out);
+
+  // If input or mask is empty, the output should be empty
+  if (in.numel() == 0 || mask.numel() == 0) {
+    ET_KERNEL_CHECK(
+        ctx, resize_tensor(out, {0}) == Error::Ok, InvalidArgument, out);
+    return out;
+  }
+
+  // Compute the shape resulting from broadcasting the mask against the input
+  size_t broadcast_ndim = 0;
+  Tensor::SizesType broadcast_sizes[kTensorDimensionLimit];
+  Error err = get_broadcast_target_size(
+      in, mask, broadcast_sizes, kTensorDimensionLimit, &broadcast_ndim);
+  if (err != Error::Ok) {
+    ET_KERNEL_CHECK_MSG(
+        ctx, false, InvalidArgument, out, "Failed to broadcast input and mask");
+  }
+  size_t broadcast_numel = 1;
+  for (size_t i = 0; i < broadcast_ndim; i++) {
+    broadcast_numel *= broadcast_sizes[i];
+  }
+
+  // Count the number of true elements in the mask
+  size_t mask_true_count = 0;
+  const bool* const mask_data = mask.const_data_ptr<bool>();
+  for (size_t i = 0; i < mask.numel(); ++i) {
+    if (mask_data[i]) {
+      mask_true_count++;
+    }
+  }
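+  // When the mask is broadcast, each of its elements is replicated
+  // (broadcast_numel / mask.numel()) times; scale the true count accordingly.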
+  Tensor::SizesType out_numel =
+      mask_true_count * (broadcast_numel / mask.numel());
+
+  // Resize the out tensor
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, {out_numel}) == Error::Ok, InvalidArgument, out);
+
+  const char* const in_data =
+      reinterpret_cast<const char*>(in.const_data_ptr());
+  char* const out_data = reinterpret_cast<char*>(out.mutable_data_ptr());
+  const auto elem_size = in.element_size();
+
+  // Figure out if `in` is broadcasted
+  bool in_is_broadcasted = false;
+  if (in.dim() != broadcast_ndim) {
+    in_is_broadcasted = true;
+  } else {
+    for (size_t i = 0; i < in.dim(); ++i) {
+      if (in.size(i) != broadcast_sizes[i]) {
+        in_is_broadcasted = true;
+      }
+    }
+  }
+
+  // Figure out if `mask` is broadcasted
+  bool mask_is_broadcasted = false;
+  if (mask.dim() != broadcast_ndim) {
+    mask_is_broadcasted = true;
+  } else {
+    for (size_t i = 0; i < mask.dim(); ++i) {
+      if (mask.size(i) != broadcast_sizes[i]) {
+        mask_is_broadcasted = true;
+      }
+    }
+  }
+
+  // Figure out if either `in` or `mask` is broadcasted
+  bool any_is_broadcasted = (in_is_broadcasted || mask_is_broadcasted);
+
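+  // Walk the broadcast space linearly, mapping each position back to linear
+  // indexes into `in` and `mask`, and gather the elements the mask selects.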
+  size_t out_ix = 0;
+  for (size_t i = 0; i < broadcast_numel; ++i) {
+    size_t in_linear_index = i;
+    size_t mask_linear_index = i;
+
+    // If either `in` or `mask` is broadcasted, we need to compute the indexes
+    // in the broadcasted space.
+    if (any_is_broadcasted) {
+      size_t broadcast_indexes[kTensorDimensionLimit];
+      delinearize_index(
+          i,
+          {broadcast_sizes, broadcast_ndim},
+          broadcast_indexes,
+          kTensorDimensionLimit);
+
+      if (in_is_broadcasted) {
+        in_linear_index =
+            linearize_access_indexes(broadcast_indexes, broadcast_ndim, in);
+      }
+      if (mask_is_broadcasted) {
+        mask_linear_index =
+            linearize_access_indexes(broadcast_indexes, broadcast_ndim, mask);
+      }
+    }
+
+    // If the mask is true, copy the value from `in` to `out` and increment
+    // `out_ix`
+    if (mask_data[mask_linear_index]) {
+      memcpy(
+          out_data + out_ix * elem_size,
+          in_data + in_linear_index * elem_size,
+          elem_size);
+      out_ix++;
+    }
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index d1eb8b8..a5d60eb 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -547,6 +547,11 @@
     - arg_meta: null
       kernel_name: torch::executor::masked_scatter_out
 
+- op: masked_select.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::masked_select_out
+
 - op: max.dim_max
   kernels:
     - arg_meta: null
diff --git a/kernels/test/op_masked_select_test.cpp b/kernels/test/op_masked_select_test.cpp
new file mode 100644
index 0000000..2a7791e
--- /dev/null
+++ b/kernels/test/op_masked_select_test.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::testing::SupportedFeatures;
+using torch::executor::testing::TensorFactory;
+
+class OpMaskedSelectOutTest : public OperatorTest {
+ protected:
+  Tensor&
+  op_masked_select_out(const Tensor& in, const Tensor& mask, Tensor& out) {
+    return torch::executor::aten::masked_select_outf(context_, in, mask, out);
+  }
+};
+
+TEST_F(OpMaskedSelectOutTest, SmokeTest) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+  Tensor out = tf.zeros({3});
+
+  op_masked_select_out(in, mask, out);
+  EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 4, 6}));
+}
+
+TEST_F(OpMaskedSelectOutTest, BroadcastInput) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({3}, {1, 2, 3});
+  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
+  Tensor out = tf.zeros({3});
+
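+  // `in` broadcasts to {{1, 2, 3}, {1, 2, 3}}, so the mask selects {1, 1, 3}.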
+  op_masked_select_out(in, mask, out);
+  EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 1, 3}));
+}
+
+TEST_F(OpMaskedSelectOutTest, BroadcastMask) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
+  Tensor mask = tfBool.make({3}, {false, true, false});
+
+  Tensor out = tf.zeros({2});
+
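+  // `mask` broadcasts across both rows, selecting column 1 from each: {2, 5}.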
+  op_masked_select_out(in, mask, out);
+  EXPECT_TENSOR_EQ(out, tf.make({2}, {2, 5}));
+}
+
+TEST_F(OpMaskedSelectOutTest, BroadcastInputAndMask) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.ones({2, 3, 4, 1});
+  Tensor mask = tfBool.ones({2, 1, 1, 5});
+  Tensor out = tf.zeros({120});
+
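+  // Broadcast shape is {2, 3, 4, 5}; the all-true mask keeps all 120 ones.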
+  op_masked_select_out(in, mask, out);
+  EXPECT_TENSOR_EQ(out, tf.ones({120}));
+}
+
+TEST_F(OpMaskedSelectOutTest, EmptyInput) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 0}, {});
+  Tensor mask = tfBool.make({2, 1}, {true, true});
+  Tensor out = tf.zeros({0});
+
+  op_masked_select_out(in, mask, out);
+  EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
+}
+
+TEST_F(OpMaskedSelectOutTest, EmptyMask) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 1}, {100, 200});
+  Tensor mask = tfBool.make({2, 0}, {});
+  Tensor out = tf.zeros({0});
+
+  op_masked_select_out(in, mask, out);
+  EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
+}
+
+TEST_F(OpMaskedSelectOutTest, EmptyInputAndMask) {
+  TensorFactory<ScalarType::Int> tf;
+  TensorFactory<ScalarType::Bool> tfBool;
+
+  Tensor in = tf.make({2, 0}, {});
+  Tensor mask = tfBool.make({0}, {});
+  Tensor out = tf.zeros({0});
+
+  op_masked_select_out(in, mask, out);
+  EXPECT_TENSOR_EQ(out, tf.make({0}, {}));
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 91b3ba8..ce15a57 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -255,6 +255,7 @@
     _common_op_test("op_lt_test", ["aten", "portable"])
     _common_op_test("op_masked_fill_test", ["aten", "portable"])
     _common_op_test("op_masked_scatter_test", ["aten", "portable"])
+    _common_op_test("op_masked_select_test", ["aten", "portable"])
     _common_op_test("op_max_test", ["aten", "portable"])
     _common_op_test("op_max_pool2d_with_indices_test", ["aten", "portable"])
     _common_op_test("op_maximum_test", ["aten", "portable"])
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index f63932d..ab8fc63 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -790,6 +790,12 @@
         ],
     ),
     op_target(
+        name = "op_masked_select",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
+        ],
+    ),
+    op_target(
         name = "op_max",
         deps = [
             "//executorch/kernels/portable/cpu/util:reduce_util",