Implementation of torch.isin() (#53125)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/3025

## Background

This PR implements a function similar to numpy's [`isin()`](https://numpy.org/doc/stable/reference/generated/numpy.isin.html#numpy.isin).

The op supports integral and floating point types on CPU and CUDA (plus half for CUDA; bool, bfloat16, and complex inputs are rejected up front by the dtype check). Inputs can be one of the following (see the usage sketch after this list):
* (Tensor, Tensor)
* (Tensor, Scalar)
* (Scalar, Tensor)
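
As a quick illustration of the proposed API (keyword arguments taken from the signature added in `native_functions.yaml`; the first expected output matches the docstring example in the diff, the others follow from the semantics):

```python
import torch

# (Tensor, Tensor)
torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3]))
# tensor([[False,  True],
#         [ True, False]])

# (Tensor, Scalar): each element is compared against the single test value.
torch.isin(torch.tensor([1, 2, 3]), 2)
# tensor([False,  True, False])

# invert=True flips the result: True for elements NOT in test_elements.
torch.isin(torch.tensor([1, 2, 3]), torch.tensor([2]), invert=True)
# tensor([ True, False,  True])
```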

Internally, one of two algorithms is selected based on the number of elements relative to the number of test elements. The heuristic for deciding between them is taken from [numpy's implementation](https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/arraysetops.py#L575): if `len(test_elements) < 10 * len(elements) ** 0.145`, a naive brute-force check is used; otherwise, a stable-sort-based algorithm is used.
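
In Python terms, the selection logic amounts to roughly the following (a sketch of the dispatch decision, not the actual C++; the function name is made up):

```python
def _choose_isin_algorithm(elements, test_elements):
    # numpy's in1d heuristic: brute force wins when there are few test
    # elements relative to the number of input elements.
    if test_elements.numel() < 10.0 * elements.numel() ** 0.145:
        return "brute_force"  # O(N * M): compare every element to every test element
    return "stable_sort"      # O((N + M) log (N + M)): sort-based duplicate check
```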

I've done some preliminary benchmarking on a devgpu to verify this heuristic, and for a limited set of tests a power value of `0.407` instead of `0.145` turned out to be a better inflection point. For now, the heuristic has been left to match numpy's, but input is welcome on how best to tune it, or on whether it should simply stay the same as numpy's.
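
For reference, here is a rough Python sketch of what the sort-based path computes (mirroring the `isin_sorting` helper in the diff for the `assume_unique=False` case; the function name is mine):

```python
import torch

def isin_sorting_sketch(elements, test_elements, invert=False):
    # Deduplicate both inputs; `inverse` maps each original element back
    # to its position in the unique list.
    e, inverse = torch.unique(elements, return_inverse=True)
    t = torch.unique(test_elements)

    # Stable-sort the concatenation. Stability keeps each element's copy
    # ahead of an equal test-element copy.
    combined = torch.cat([e, t])
    sorted_vals, sorted_idx = combined.sort(stable=True)

    # A value present in both e and t shows up as adjacent duplicates.
    dup = torch.empty_like(sorted_vals, dtype=torch.bool)
    if invert:
        dup[:-1] = sorted_vals[1:].ne(sorted_vals[:-1])
    else:
        dup[:-1] = sorted_vals[1:].eq(sorted_vals[:-1])
    dup[-1] = invert

    # Undo the sort, then map back through the unique inverse indices
    # (entries for e occupy positions 0..len(e)-1 of the unsorted mask).
    mask = torch.empty_like(dup)
    mask[sorted_idx] = dup
    return mask[inverse]
```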

Tests are adapted from numpy's [isin and in1d tests](https://github.com/numpy/numpy/blob/7dcd29aaafe1ab8be4be04d3c793e5bcaf17459f/numpy/lib/tests/test_arraysetops.py).

Note: my locally generated docs render badly for some reason, so I'm holding off on a screenshot of them until I figure out why.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53125

Test Plan:
```
python test/test_ops.py   # Ex: python test/test_ops.py TestOpInfoCPU.test_supported_dtypes_isin_cpu_int32
python test/test_sort_and_select.py   # Ex: python test/test_sort_and_select.py TestSortAndSelectCPU.test_isin_cpu_int32
```

Reviewed By: soulitzer

Differential Revision: D29101165

Pulled By: jbschlosser

fbshipit-source-id: 2dcc38d497b1e843f73f332d837081e819454b4e
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index f12e16a..33bcaeb 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -5,10 +5,50 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/ReduceOpsUtils.h>
 #include <c10/util/Exception.h>
+#include <ATen/native/Resize.h>
 #include <ATen/native/TensorCompare.h>
 #include <ATen/NamedTensorUtils.h>
+#include <ATen/TensorIndexing.h>
 
-namespace at { namespace native {
+namespace at {
+namespace meta {
+
+static inline void check_for_unsupported_isin_dtype(const ScalarType type) {
+  // Bail out for dtypes unsupported by the sorting algorithm to keep the interface consistent.
+  TORCH_CHECK(type != ScalarType::Bool &&
+      type != ScalarType::BFloat16 &&
+      type != ScalarType::ComplexFloat &&
+      type != ScalarType::ComplexDouble,
+      "Unsupported input type encountered for isin(): ", type);
+}
+
+TORCH_META_FUNC2(isin, Tensor_Tensor) (
+  const Tensor& elements, const Tensor& test_elements, bool assume_unique, bool invert
+) {
+  check_for_unsupported_isin_dtype(elements.scalar_type());
+  check_for_unsupported_isin_dtype(test_elements.scalar_type());
+  set_output(elements.sizes(), TensorOptions(elements.device()).dtype(ScalarType::Bool));
+}
+
+TORCH_META_FUNC2(isin, Tensor_Scalar) (
+  const Tensor& elements, const c10::Scalar& test_elements, bool assume_unique, bool invert
+) {
+  check_for_unsupported_isin_dtype(elements.scalar_type());
+  check_for_unsupported_isin_dtype(test_elements.type());
+  set_output(elements.sizes(), TensorOptions(elements.device()).dtype(ScalarType::Bool));
+}
+
+TORCH_META_FUNC2(isin, Scalar_Tensor) (
+  const c10::Scalar& elements, const Tensor& test_elements, bool assume_unique, bool invert
+) {
+  check_for_unsupported_isin_dtype(elements.type());
+  check_for_unsupported_isin_dtype(test_elements.scalar_type());
+  set_output({0}, TensorOptions(test_elements.device()).dtype(ScalarType::Bool));
+}
+
+} // namespace meta
+
+namespace native {
 
 DEFINE_DISPATCH(where_kernel); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(max_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
@@ -23,6 +63,7 @@
 DEFINE_DISPATCH(clamp_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(clamp_min_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 DEFINE_DISPATCH(clamp_max_scalar_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(isin_default_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 
 bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
   return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
@@ -245,6 +286,56 @@
 
 } // anonymous namespace
 
+// Sorting-based algorithm for isin(); used when the number of test elements is large.
+static void isin_sorting(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert,
+    const Tensor& out) {
+  // 1. Concatenate unique elements with unique test elements in 1D form. If
+  //    assume_unique is true, skip calls to unique().
+  Tensor elements_flat, test_elements_flat, unique_order;
+  if (assume_unique) {
+    elements_flat = elements.ravel();
+    test_elements_flat = test_elements.ravel();
+  } else {
+    std::tie (elements_flat, unique_order) = at::_unique(
+        elements, /*sorted=*/ false, /*return_inverse=*/ true);
+    std::tie (test_elements_flat, std::ignore) = at::_unique(test_elements, /*sorted=*/ false);
+  }
+
+  // 2. Stable sort all elements, maintaining order indices to reverse the
+  //    operation. Stable sort is necessary to keep elements before test
+  //    elements within the sorted list.
+  Tensor all_elements = at::_cat({elements_flat, test_elements_flat});
+  Tensor sorted_elements, sorted_order;
+  std::tie (sorted_elements, sorted_order) = all_elements.sort(
+      /*stable=*/ true, /*dim=*/ 0, /*descending=*/ false);
+
+  // 3. Create a mask for locations of adjacent duplicate values within the
+  //    sorted list. A duplicate means the value is present in both elements and test elements.
+  Tensor duplicate_mask = at::empty_like(sorted_elements, TensorOptions(ScalarType::Bool));
+  Tensor sorted_except_first = sorted_elements.slice(0, 1, at::indexing::None);
+  Tensor sorted_except_last = sorted_elements.slice(0, 0, -1);
+  duplicate_mask.slice(0, 0, -1).copy_(
+    invert ? sorted_except_first.ne(sorted_except_last) : sorted_except_first.eq(sorted_except_last));
+  duplicate_mask.index_put_({-1}, invert);
+
+  // 4. Reorder the mask to match the pre-sorted element order.
+  Tensor mask = at::empty_like(duplicate_mask);
+  mask.index_copy_(0, sorted_order, duplicate_mask);
+
+  // 5. Index the mask to match the pre-unique element order. If
+  //    assume_unique is true, just take the first N items of the mask,
+  //    where N is the original number of elements.
+  if (assume_unique) {
+    out.copy_(mask.slice(0, 0, elements.numel()).view_as(out));
+  } else {
+    out.copy_(at::index(mask, {c10::optional<Tensor>(unique_order)}));
+  }
+}
+
 Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
   TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
               "Expected condition, x and y to be on the same device, but condition is on ",
@@ -659,4 +750,42 @@
   return at::mode_out(values, indices, self, dimname_to_position(self, dim), keepdim);
 }
 
-}} // namespace at::native
+TORCH_IMPL_FUNC(isin_Tensor_Tensor_out) (
+  const Tensor& elements, const Tensor& test_elements, bool assume_unique, bool invert, const Tensor& out
+) {
+  if (elements.numel() == 0) {
+    return;
+  }
+
+  // Heuristic taken from numpy's implementation.
+  // See https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/arraysetops.py#L575
+  if (test_elements.numel() < static_cast<int64_t>(
+        10.0f * std::pow(static_cast<double>(elements.numel()), 0.145))) {
+    out.fill_(invert);
+    isin_default_stub(elements.device().type(), elements, test_elements, invert, out);
+  } else {
+    isin_sorting(elements, test_elements, assume_unique, invert, out);
+  }
+}
+
+TORCH_IMPL_FUNC(isin_Tensor_Scalar_out) (
+  const Tensor& elements, const c10::Scalar& test_elements, bool assume_unique, bool invert, const Tensor& out
+) {
+  // redispatch to eq / ne
+  if (invert) {
+    at::ne_out(const_cast<Tensor&>(out), elements, test_elements);
+  } else {
+    at::eq_out(const_cast<Tensor&>(out), elements, test_elements);
+  }
+}
+
+TORCH_IMPL_FUNC(isin_Scalar_Tensor_out) (
+  const c10::Scalar& elements, const Tensor& test_elements, bool assume_unique, bool invert, const Tensor& out
+) {
+  // redispatch
+  at::isin_out(const_cast<Tensor&>(out), wrapped_scalar_tensor(elements, test_elements.device()),
+    test_elements, assume_unique, invert);
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/TensorCompare.h b/aten/src/ATen/native/TensorCompare.h
index 9ffbfe7..6bcd386 100644
--- a/aten/src/ATen/native/TensorCompare.h
+++ b/aten/src/ATen/native/TensorCompare.h
@@ -33,4 +33,6 @@
 DECLARE_DISPATCH(void (*)(TensorIterator &, Scalar), clamp_min_scalar_stub);
 DECLARE_DISPATCH(void (*)(TensorIterator &, Scalar), clamp_max_scalar_stub);
 
+using isin_default_fn = void (*)(const Tensor&, const Tensor&, bool, const Tensor&);
+DECLARE_DISPATCH(isin_default_fn, isin_default_stub);
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 48eacf6..d774ccd 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -301,6 +301,37 @@
       });
 }
 
+// Default brute-force implementation of isin(). Used when the number of test elements is small.
+// Iterates through each element and checks it against each test element.
+static void isin_default_kernel_cpu(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool invert,
+    const Tensor& out) {
+  // Since test_elements is not an input to the TensorIterator, type promotion
+  // must be done manually.
+  ScalarType common_type = at::result_type(elements, test_elements);
+  Tensor test_elements_flat = test_elements.to(common_type).ravel();
+  Tensor promoted_elements = elements.to(common_type);
+  auto iter = TensorIteratorConfig()
+    .add_output(out)
+    .add_input(promoted_elements)
+    .check_all_same_dtype(false)
+    .build();
+  // Dispatch based on promoted type.
+  AT_DISPATCH_ALL_TYPES(iter.dtype(1), "isin_default_cpu", [&]() {
+    cpu_kernel(iter, [&](scalar_t element_val) -> bool {
+      const auto* test_element_data = reinterpret_cast<scalar_t*>(test_elements_flat.data_ptr());
+      for (auto j = 0; j < test_elements_flat.numel(); ++j) {
+        if (element_val == test_element_data[j]) {
+          return !invert;
+        }
+      }
+      return invert;
+    });
+  });
+}
+
 static void clamp_kernel_impl(TensorIterator& iter) {
   AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.common_dtype(), "clamp_cpu", [&]() {
     cpu_kernel_vec(iter,
@@ -403,5 +434,6 @@
 REGISTER_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl);
 REGISTER_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl);
 REGISTER_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl);
+REGISTER_DISPATCH(isin_default_stub, &isin_default_kernel_cpu);
 
 }} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu
index 0b4be60..66f6c21 100644
--- a/aten/src/ATen/native/cuda/TensorCompare.cu
+++ b/aten/src/ATen/native/cuda/TensorCompare.cu
@@ -129,6 +129,15 @@
   });
 }
 
+// Composite op implementation for simplicity. This materializes the cross product of elements and test elements,
+// so it is not very memory efficient, but it is fast on CUDA.
+void isin_default_kernel_gpu(const Tensor& elements, const Tensor& test_elements, bool invert, const Tensor& out) {
+  std::vector<int64_t> bc_shape(elements.dim(), 1);
+  bc_shape.push_back(-1);
+  out.copy_(invert ? elements.unsqueeze(-1).ne(test_elements.view(bc_shape)).all(-1)
+    : elements.unsqueeze(-1).eq(test_elements.view(bc_shape)).any(-1));
+}
+
 } // anonymous namespace
 
 
@@ -141,6 +150,7 @@
 REGISTER_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl);
 REGISTER_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl);
 REGISTER_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl);
+REGISTER_DISPATCH(isin_default_stub, &isin_default_kernel_gpu);
 
 template <typename scalar_t>
 __global__ void _assert_async_cuda_kernel(scalar_t* input) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index f0762df..d5711da 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2248,6 +2248,36 @@
 - func: isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor
   variants: function, method
 
+- func: isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: isin_Tensor_Tensor_out
+
+- func: isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor
+  variants: function
+  structured_delegate: isin.Tensor_Tensor_out
+
+- func: isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: isin_Tensor_Scalar_out
+
+- func: isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor
+  variants: function
+  structured_delegate: isin.Tensor_Scalar_out
+
+- func: isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: isin_Scalar_Tensor_out
+
+- func: isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor
+  variants: function
+  structured_delegate: isin.Scalar_Tensor_out
+
 - func: isnan(Tensor self) -> Tensor
   variants: function, method
   device_check: NoCheck
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 94b2889..620b2ac 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -421,6 +421,7 @@
     greater
     isclose
     isfinite
+    isin
     isinf
     isposinf
     isneginf
diff --git a/test/test_sort_and_select.py b/test/test_sort_and_select.py
index a5cf5ec..526aaa6 100644
--- a/test/test_sort_and_select.py
+++ b/test/test_sort_and_select.py
@@ -5,6 +5,7 @@
 from torch._six import nan
 from itertools import permutations, product
 
+from torch.testing import all_types, all_types_and
 from torch.testing._internal.common_utils import \
     (TEST_WITH_ROCM, TestCase, run_tests, make_tensor, slowTest)
 from torch.testing._internal.common_device_type import \
@@ -843,6 +844,127 @@
         self.assertEqual(res[0], ref[0].squeeze())
         self.assertEqual(res[1], ref[1].squeeze())
 
+    @dtypes(*all_types())
+    @dtypesIfCUDA(*all_types_and(torch.half))
+    def test_isin(self, device, dtype):
+        def assert_isin_equal(a, b):
+            # Compare to the numpy reference implementation.
+            x = torch.isin(a, b)
+            a = a.cpu().numpy() if torch.is_tensor(a) else np.array(a)
+            b = b.cpu().numpy() if torch.is_tensor(b) else np.array(b)
+            y = np.isin(a, b)
+            self.assertEqual(x, y)
+
+        # multi-dim tensor, multi-dim tensor
+        a = torch.arange(24, device=device, dtype=dtype).reshape([2, 3, 4])
+        b = torch.tensor([[10, 20, 30], [0, 1, 3], [11, 22, 33]], device=device, dtype=dtype)
+        assert_isin_equal(a, b)
+
+        # zero-dim tensor
+        zero_d = torch.tensor(3, device=device, dtype=dtype)
+        assert_isin_equal(zero_d, b)
+        assert_isin_equal(a, zero_d)
+        assert_isin_equal(zero_d, zero_d)
+
+        # empty tensor
+        empty = torch.tensor([], device=device, dtype=dtype)
+        assert_isin_equal(empty, b)
+        assert_isin_equal(a, empty)
+        assert_isin_equal(empty, empty)
+
+        # scalar
+        assert_isin_equal(a, 6)
+        assert_isin_equal(5, b)
+
+        def define_expected(lst, invert=False):
+            expected = torch.tensor(lst, device=device)
+            if invert:
+                expected = expected.logical_not()
+            return expected
+
+        # Adapted from numpy's in1d tests
+        for mult in [1, 10]:
+            for invert in [False, True]:
+                a = torch.tensor([5, 7, 1, 2], device=device, dtype=dtype)
+                b = torch.tensor([2, 4, 3, 1, 5] * mult, device=device, dtype=dtype)
+                ec = define_expected([True, False, True, True], invert=invert)
+                c = torch.isin(a, b, assume_unique=True, invert=invert)
+                self.assertEqual(c, ec)
+
+                a[0] = 8
+                ec = define_expected([False, False, True, True], invert=invert)
+                c = torch.isin(a, b, assume_unique=True, invert=invert)
+                self.assertEqual(c, ec)
+
+                a[0], a[3] = 4, 8
+                ec = define_expected([True, False, True, False], invert=invert)
+                c = torch.isin(a, b, assume_unique=True, invert=invert)
+                self.assertEqual(c, ec)
+
+                a = torch.tensor([5, 4, 5, 3, 4, 4, 3, 4, 3, 5, 2, 1, 5, 5], device=device, dtype=dtype)
+                b = torch.tensor([2, 3, 4] * mult, device=device, dtype=dtype)
+                ec = define_expected([False, True, False, True, True, True, True, True, True,
+                                      False, True, False, False, False], invert=invert)
+                c = torch.isin(a, b, invert=invert)
+                self.assertEqual(c, ec)
+
+                b = torch.tensor([2, 3, 4] * mult + [5, 5, 4] * mult, device=device, dtype=dtype)
+                ec = define_expected([True, True, True, True, True, True, True, True, True, True,
+                                      True, False, True, True], invert=invert)
+                c = torch.isin(a, b, invert=invert)
+                self.assertEqual(c, ec)
+
+                a = torch.tensor([5, 7, 1, 2], device=device, dtype=dtype)
+                b = torch.tensor([2, 4, 3, 1, 5] * mult, device=device, dtype=dtype)
+                ec = define_expected([True, False, True, True], invert=invert)
+                c = torch.isin(a, b, invert=invert)
+                self.assertEqual(c, ec)
+
+                a = torch.tensor([5, 7, 1, 1, 2], device=device, dtype=dtype)
+                b = torch.tensor([2, 4, 3, 3, 1, 5] * mult, device=device, dtype=dtype)
+                ec = define_expected([True, False, True, True, True], invert=invert)
+                c = torch.isin(a, b, invert=invert)
+                self.assertEqual(c, ec)
+
+                a = torch.tensor([5, 5], device=device, dtype=dtype)
+                b = torch.tensor([2, 2] * mult, device=device, dtype=dtype)
+                ec = define_expected([False, False], invert=invert)
+                c = torch.isin(a, b, invert=invert)
+                self.assertEqual(c, ec)
+
+                # multi-dimensional input case using sort-based algo
+                for assume_unique in [False, True]:
+                    a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3])
+                    b = torch.arange(3, 30, device=device, dtype=dtype)
+                    ec = define_expected([[False, False, False], [True, True, True]], invert=invert)
+                    c = torch.isin(a, b, invert=invert, assume_unique=assume_unique)
+                    self.assertEqual(c, ec)
+
+    def test_isin_different_dtypes(self, device):
+        supported_types = all_types() if device == 'cpu' else all_types_and(torch.half)
+        for mult in [1, 10]:
+            for assume_unique in [False, True]:
+                for dtype1, dtype2 in product(supported_types, supported_types):
+                    a = torch.tensor([1, 2, 3], device=device, dtype=dtype1)
+                    b = torch.tensor([3, 4, 5] * mult, device=device, dtype=dtype2)
+                    ec = torch.tensor([False, False, True], device=device)
+                    c = torch.isin(a, b, assume_unique=assume_unique)
+                    self.assertEqual(c, ec)
+
+    @onlyCUDA
+    @dtypes(*all_types())
+    def test_isin_different_devices(self, device, dtype):
+        a = torch.arange(6, device=device, dtype=dtype).reshape([2, 3])
+        b = torch.arange(3, 30, device='cpu', dtype=dtype)
+        with self.assertRaises(RuntimeError):
+            torch.isin(a, b)
+
+        c = torch.arange(6, device='cpu', dtype=dtype).reshape([2, 3])
+        d = torch.arange(3, 30, device=device, dtype=dtype)
+        with self.assertRaises(RuntimeError):
+            torch.isin(c, d)
+
+
 instantiate_device_type_tests(TestSortAndSelect, globals())
 
 if __name__ == '__main__':
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 81fe8c0..772aef7 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4042,6 +4042,35 @@
 Alias for :func:`torch.linalg.inv`
 """.format(**common_args))
 
+add_docstr(torch.isin, r"""
+isin(elements, test_elements, *, assume_unique=False, invert=False) -> Tensor
+
+Tests if each element of :attr:`elements` is in :attr:`test_elements`. Returns
+a boolean tensor of the same shape as :attr:`elements` that is True for elements
+in :attr:`test_elements` and False otherwise.
+
+.. note::
+    One of :attr:`elements` or :attr:`test_elements` can be a scalar, but not both.
+
+Args:
+    elements (Tensor or Scalar): Input elements
+    test_elements (Tensor or Scalar): Values against which to test for each input element
+    assume_unique (bool, optional): If True, assumes both :attr:`elements` and
+        :attr:`test_elements` contain unique elements, which can speed up the
+        calculation. Default: False
+    invert (bool, optional): If True, inverts the boolean return tensor, resulting in True
+        values for elements *not* in :attr:`test_elements`. Default: False
+
+Returns:
+    A boolean tensor of the same shape as :attr:`elements` that is True for elements in
+    :attr:`test_elements` and False otherwise
+
+Example:
+    >>> torch.isin(torch.tensor([[1, 2], [3, 4]]), torch.tensor([2, 3]))
+    tensor([[False,  True],
+            [ True, False]])
+""")
+
 add_docstr(torch.isinf, r"""
 isinf(input) -> Tensor
 
diff --git a/torch/overrides.py b/torch/overrides.py
index aa6876d..a0353fd 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -493,6 +493,7 @@
         torch.index_select: lambda input, dim, index, out=None: -1,
         torch.index_fill: lambda input, dim, index, value: -1,
         torch.isfinite: lambda tensor: -1,
+        torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
         torch.isinf: lambda tensor: -1,
         torch.isreal: lambda tensor: -1,
         torch.isposinf: lambda input, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 08e35af..eb60f1f 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -3283,6 +3283,13 @@
         SampleInput(lhs, args=(3.14,)),
     ]
 
+def sample_inputs_isin(op_info, device, dtype, requires_grad):
+    element = make_tensor((L,), device, dtype, low=None, high=None, requires_grad=requires_grad)
+    indices = torch.randint(0, L, size=[S])
+    test_elements = element[indices].clone()
+    return [
+        SampleInput(element, args=(test_elements,))
+    ]
 
 def sample_inputs_masked_scatter(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -5336,6 +5343,11 @@
            gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
            sample_inputs_func=sample_inputs_linalg_invertible,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
+    OpInfo('isin',
+           dtypesIfCPU=all_types(),
+           dtypesIfCUDA=all_types_and(torch.half),
+           supports_autograd=False,
+           sample_inputs_func=sample_inputs_isin),
     OpInfo('kthvalue',
            dtypes=all_types(),
            dtypesIfCUDA=all_types_and(torch.float16),