Fix isin decomp and add python meta registration (#120821)
Fixes #119792
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120821
Approved by: https://github.com/malfet, https://github.com/peterbell10
diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_float32 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_float32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_float32
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_float64 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_float64
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_float64
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int16 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int16
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int16
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int32 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int32
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int32
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int64 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int64
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int64
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int8 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int8
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_int8
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_uint8 b/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_uint8
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortAndSelectCPU.test_isin_cpu_uint8
+++ /dev/null
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 8c89897..bcd6652 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -4479,6 +4479,12 @@
@register_decomposition(aten.isin)
@out_wrapper()
def isin(elements, test_elements, *, assume_unique=False, invert=False):
+ # handle when either elements or test_elements are Scalars (they can't both be)
+ if not isinstance(elements, torch.Tensor):
+ elements = torch.tensor(elements, device=test_elements.device)
+ if not isinstance(test_elements, torch.Tensor):
+ test_elements = torch.tensor(test_elements, device=elements.device)
+
if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
return isin_default(elements, test_elements, invert=invert)
else:
@@ -4488,6 +4494,9 @@
def isin_default(elements, test_elements, *, invert=False):
+ if elements.numel() == 0:
+ return torch.empty_like(elements, dtype=torch.bool)
+
x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
if not invert:
cmp = x == test_elements
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 5d20be2..7ee6355 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5961,6 +5961,31 @@
return torch.empty((), dtype=dtype, device=sorted_sequence.device)
+def _check_for_unsupported_isin_dtype(dtype):
+ torch._check(
+ dtype not in [torch.bool, torch.bfloat16, torch.complex128, torch.complex64],
+ lambda: f"Unsupported input type encountered for isin(): {dtype}",
+ )
+
+
+@register_meta(aten.isin)
+@out_wrapper()
+def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
+ torch._check(
+ isinstance(elements, Tensor) or isinstance(test_elements, Tensor),
+ lambda: "At least one of elements and test_elements must be a Tensor.",
+ )
+ if not isinstance(elements, Tensor):
+ elements = torch.tensor(elements, device=test_elements.device)
+
+ if not isinstance(test_elements, Tensor):
+ test_elements = torch.tensor(test_elements, device=elements.device)
+
+ _check_for_unsupported_isin_dtype(elements.dtype)
+ _check_for_unsupported_isin_dtype(test_elements.dtype)
+ return torch.empty_like(elements, dtype=torch.bool)
+
+
@register_meta(aten.polygamma)
@out_wrapper()
def meta_polygamma(n: int, self: Tensor) -> Tensor: