Add lowerings for special functions (#119187)

This PR adds Inductor lowerings for the pointwise operators in torch.special (Bessel and modified Bessel functions, erfcx, igamma/igammac, ndtr/ndtri, zeta, polygamma, digamma, and the Chebyshev/Hermite/Laguerre/Legendre polynomial families, among others), removing the corresponding make_fallback registrations from torch/_inductor/lowering.py.

In addition, the PR introduces infrastructure for registering lowerings of pointwise functions whose backend implementations are declared in a single table (pointwise_overrides_data in torch/_inductor/codegen/common.py). Each entry may provide cpp, cppvec, and/or triton implementation strings; operators without a triton implementation fall back to ATen when generating Triton code.
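
For orientation, here is a condensed sketch of how the table is turned into codegen overrides, simplified to the unary single-argument case (ExampleOverrides below is a stand-in for the real CppOverrides/CppVecOverrides/TritonOverrides classes; the real _initialize_pointwise_overrides also handles two-argument templates via a "{y}" placeholder):

```python
import dataclasses
from typing import Dict, Optional


@dataclasses.dataclass
class OverridesData:
    name: str
    cpp: str
    triton: Optional[str] = None  # None: no triton impl -> ATen fallback on the triton path
    cppvec: Optional[str] = None  # None: no vectorized cpp impl


# One entry per special function; {x} is the placeholder the codegen
# fills in with the kernel's input expression.
pointwise_overrides_data: Dict[str, OverridesData] = dict(
    bessel_j0=OverridesData(
        cpp="bessel_j0_forward({x})",
        triton="tl.math.j0({x})",
        name="special_bessel_j0",
    ),
    erfcx=OverridesData(
        cpp="calc_erfcx({x})",
        triton="tl.math.erfcx({x})",
        name="special_erfcx",
    ),
)


class ExampleOverrides:
    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        assert target in {"triton", "cpp", "cppvec"}, target
        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if isinstance(impl, str):
                # Install ops.<funcname>(x) as a formatter over the template string.
                setattr(
                    cls,
                    funcname,
                    staticmethod(lambda x, impl=impl: impl.format(x=x)),
                )


ExampleOverrides._initialize_pointwise_overrides("triton")
print(ExampleOverrides.bessel_j0("tmp0"))  # -> tl.math.j0(tmp0)
```

On the lowering side, register_pointwise is called for every aten/prims overload found for each table entry, passing a fallback_handler as triton_fallback whenever data.triton is None.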

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119187
Approved by: https://github.com/peterbell10
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 359acea..925acb1 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -267,6 +267,8 @@
 aten::hypot
 aten::hypot.out
 aten::hypot_
+aten::i0
+aten::i0.out
 aten::i0_
 aten::igamma
 aten::igamma.out
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 8cbae2f..c0f9255 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -816,8 +816,6 @@
 aten::histogram.bins_tensor_out
 aten::hspmm
 aten::hspmm.out
-aten::i0
-aten::i0.out
 aten::index.Tensor
 aten::index.Tensor_out
 aten::index_put
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index a2c9d85..9cb59d6 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -4349,7 +4349,6 @@
 symbolic_aot_autograd_failures = {
     xfail('combinations', ''),  # aten.masked_select.default
     xfail('frexp', ''),  # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
-    xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_fill', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('kthvalue', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('linalg.eigvals', ''),  # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition
@@ -4369,7 +4368,6 @@
     xfail('nn.functional.nll_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('_segment_reduce', 'lengths'),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
     xfail('_segment_reduce', 'offsets'),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
-    xfail('special.i1', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('trace', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('_upsample_bilinear2d_aa'),  # RuntimeError: isIntList() INTERNAL ASSERT FAILED  Expected IntList but got GenericList
     decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')),
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 8c50699..d3f2014 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1559,9 +1559,42 @@
                 cpp_op_list.append(k)
 
         diff = [
+            "airy_ai",
             "bessel_j0",
             "bessel_j1",
+            "bessel_y0",
+            "bessel_y1",
             "modified_bessel_i0",
+            "modified_bessel_i1",
+            "modified_bessel_k0",
+            "modified_bessel_k1",
+            "scaled_modified_bessel_k0",
+            "scaled_modified_bessel_k1",
+            "spherical_bessel_j0",
+            "i1",
+            "i1e",
+            "ndtr",
+            "ndtri",
+            "log_ndtr",
+            "erfcx",
+            "gammainc",
+            "gammaincc",
+            "igamma",
+            "igammac",
+            "polygamma",
+            "zeta",
+            "shifted_chebyshev_polynomial_u",
+            "chebyshev_polynomial_u",
+            "chebyshev_polynomial_t",
+            "shifted_chebyshev_polynomial_w",
+            "chebyshev_polynomial_w",
+            "shifted_chebyshev_polynomial_t",
+            "chebyshev_polynomial_v",
+            "shifted_chebyshev_polynomial_v",
+            "hermite_polynomial_he",
+            "laguerre_polynomial_l",
+            "hermite_polynomial_h",
+            "legendre_polynomial_p",
             "constant",
             "index_expr",
             "signbit",
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 28daaeb..94d61a9 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -62,12 +62,15 @@
 from torch.testing._internal.common_dtype import all_types, get_all_dtypes
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
+    instantiate_parametrized_tests,
     IS_CI,
     IS_FBCODE,
     IS_MACOS,
     IS_WINDOWS,
     IS_X86,
+    parametrize,
     skipIfRocm,
+    subtest,
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
     TestCase as TorchTestCase,
@@ -676,6 +679,7 @@
                 cls.gen_template(name1, name2)
 
 
+@instantiate_parametrized_tests
 class CommonTemplate:
     def test_bool(self):
         def fn(a, b):
@@ -8864,23 +8868,107 @@
             # should_pad_bench always returns False if has_triton returns False
             self.assertFalse(should_pad)
 
-    def test_bessel_j0(self):
-        def fn(x):
-            return torch.special.bessel_j0(x)
+    @parametrize(
+        "name, op",
+        [
+            subtest((name, getattr(torch.special, name)), name=name)
+            for name in torch.special.__all__
+            if name not in {"softmax", "log_softmax", "logsumexp"}
+        ],
+    )
+    def test_pointwise(self, name, op):
+        dtype = torch.float32
+        check_lowp = True
+        if self.device == "cuda" and name in {
+            "airy_ai",
+            "bessel_i0",
+            "bessel_i1",
+            "bessel_j0",
+            "bessel_j1",
+            "bessel_y0",
+            "bessel_y1",
+            "erfcx",
+            "gammainc",
+            "gammaincc",
+            "i1",
+            "i1e",
+            "modified_bessel_i0",
+            "modified_bessel_i1",
+            "modified_bessel_k0",
+            "modified_bessel_k1",
+            "ndtri",
+            "scaled_modified_bessel_k0",
+            "scaled_modified_bessel_k1",
+            "spherical_bessel_j0",
+            "zeta",
+            "chebyshev_polynomial_t",
+            "chebyshev_polynomial_v",
+            "chebyshev_polynomial_u",
+            "chebyshev_polynomial_w",
+            "legendre_polynomial_p",
+            "shifted_chebyshev_polynomial_t",
+            "shifted_chebyshev_polynomial_u",
+            "shifted_chebyshev_polynomial_v",
+            "shifted_chebyshev_polynomial_w",
+            "hermite_polynomial_h",
+            "hermite_polynomial_he",
+            "laguerre_polynomial_l",
+        }:
+            # <func>_cuda not implemented for Half
+            check_lowp = False
 
-        self.common(fn, (torch.randn(8, 8),))
+        if name in {"gammainc", "gammaincc"}:
+            args = (
+                torch.randn(8, 8, dtype=dtype, device=self.device),
+                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
+            )
 
-    def test_bessel_j1(self):
-        def fn(x):
-            return torch.special.bessel_j1(x)
+            def fn(x, y):
+                return op(x, y)
 
-        self.common(fn, (torch.randn(8, 8),))
+        elif name in {"xlog1py", "xlogy", "zeta"}:
+            args = (
+                torch.randn(8, 8, dtype=dtype, device=self.device),
+                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
+            )
 
-    def test_modified_bessel_i0(self):
-        def fn(x):
-            return torch.special.modified_bessel_i0(x)
+            def fn(x, y):
+                return op(x, y)
 
-        self.common(fn, (torch.randn(8, 8),))
+        elif name == "multigammaln":
+            args = (
+                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
+                2,
+            )
+
+            def fn(x, p):
+                return op(x, p)
+
+        elif name == "polygamma":
+            args = (
+                1,
+                torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 10),
+            )
+
+            def fn(n, x):
+                return op(n, x)
+
+        elif "_polynomial_" in name:
+            args = (
+                torch.randn(8, 8, dtype=dtype, device=self.device),
+                2,
+            )
+
+            def fn(x, n):
+                return op(x, n)
+
+        else:
+            args = (torch.randn(8, 8, dtype=dtype, device=self.device),)
+
+            def fn(x):
+                return op(x)
+
+        self.common(fn, args, check_lowp=check_lowp)
 
 
 @dataclasses.dataclass
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index 0149387..815e215 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -183,6 +183,41 @@
     "test_new_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda")),
     "test_new_ones_dynamic_shapes": TestFailure(("cpu",)),
     "test_permute2_dynamic_shapes": TestFailure(("cpu", "cuda")),
+    "test_pointwise_airy_ai_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_digamma_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_gammainc_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_gammaincc_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_i0e_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_i1e_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_modified_bessel_k0_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_modified_bessel_k1_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_ndtri_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_polygamma_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_psi_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_scaled_modified_bessel_k0_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_scaled_modified_bessel_k1_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_spherical_bessel_j0_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_zeta_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_chebyshev_polynomial_t_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_chebyshev_polynomial_u_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_chebyshev_polynomial_v_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_chebyshev_polynomial_w_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_shifted_chebyshev_polynomial_t_dynamic_shapes": TestFailure(
+        ("cuda",)
+    ),
+    "test_pointwise_shifted_chebyshev_polynomial_u_dynamic_shapes": TestFailure(
+        ("cuda",)
+    ),
+    "test_pointwise_shifted_chebyshev_polynomial_v_dynamic_shapes": TestFailure(
+        ("cuda",)
+    ),
+    "test_pointwise_shifted_chebyshev_polynomial_w_dynamic_shapes": TestFailure(
+        ("cuda",)
+    ),
+    "test_pointwise_hermite_polynomial_h_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_hermite_polynomial_he_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda",)),
+    "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda",)),
     "test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
     "test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda")),
     "test_single_elem_dynamic_shapes": TestFailure(("cpu",)),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 11173fa..f1c9ba5 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -200,6 +200,7 @@
 if TEST_WITH_ROCM:
     # Tensors are not alike
     inductor_skips["cuda"]["logcumsumexp"] = {f32}
+    inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
 
 inductor_expected_failures_single_sample = defaultdict(dict)
 
@@ -349,6 +350,15 @@
     ("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
     ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
     ("special.log_ndtr", "cuda", f64): {"atol": 1e-6, "rtol": 1e-5},
+    ("polygamma.polygamma_n_0", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+    ("polygamma.polygamma_n_1", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+    ("polygamma.polygamma_n_2", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+    ("polygamma.polygamma_n_3", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+    ("polygamma.polygamma_n_4", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+    ("special.polygamma.special_polygamma_n_0", "cpu", f32): {
+        "atol": 1e-3,
+        "rtol": 1e-4,
+    },
     ("std_mean.unbiased", "cuda", f16): {"reference_in_float": True},
     ("uniform", "cuda"): {"reference_in_float": True},
     # Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors
@@ -497,7 +507,6 @@
             overridden_kwargs = inductor_override_kwargs[(op_name, device_type)]
         elif (op_name, device_type, dtype) in inductor_override_kwargs:
             overridden_kwargs = inductor_override_kwargs[(op_name, device_type, dtype)]
-
         func = op.get_op()
 
         def fn(*args, **kwargs):
diff --git a/test/test_mps.py b/test/test_mps.py
index 3dc3f26..62d8562 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -552,13 +552,13 @@
         # - MPS output: tensor([102.6681, inf])
         # In the latter case, inf is probably correct (this is what scipy does).
         'polygamma': [torch.float32, torch.uint8],
-        'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
+        'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
+        'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
+        'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],
 
         # Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
         'tan': [torch.float32],
@@ -615,13 +615,13 @@
         # - MPS output: tensor([102.6681, inf])
         # In the latter case, inf is probably correct (this is what scipy does).
         'polygamma': [torch.float32, torch.uint8],
-        'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
-        'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
+        'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
+        'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
+        'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
+        'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],
     }
 
     # Those ops are not expected to work
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index a5e923e..5e816cb 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1931,8 +1931,6 @@
 symbolic_tensor_failures.update(symbolic_tensor_segfaults)
 
 outplace_symbolic_tensor_failures = {
-    xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
-
     xfail('linalg.norm', ''),
 }
 
@@ -1957,7 +1955,6 @@
     xfail('fft.ifft2', ''),
     xfail('fft.ifftn', ''),
     xfail('gather', ''),
-    xfail('i0', ''),
     xfail('linalg.cholesky', ''),
     xfail('linalg.cholesky_ex', ''),
     xfail('linalg.det', ''),
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 038a77b..38eedb5 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -26,6 +26,7 @@
 
 import torch
 import torch.fx
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.utils._sympy.value_ranges import ValueRanges
 
 from .. import config, metrics
@@ -523,6 +524,262 @@
     def load_seed(name, offset):
         return ops.load(name, sympy.Integer(offset))
 
+    @classmethod
+    def _initialize_pointwise_overrides(cls, target):
+        assert target in {"triton", "cpp", "cppvec"}, target
+
+        def pointwise_factory_1(impl):
+            def func(x):
+                return impl.format(x=x)
+
+            return func
+
+        def pointwise_factory_2(impl):
+            def func(x, y):
+                return impl.format(x=x, y=y)
+
+            return func
+
+        for funcname, data in pointwise_overrides_data.items():
+            impl = getattr(data, target)
+            if isinstance(impl, str):
+                nof_args = 2 if "{y}" in impl else 1
+                # extend the following dictionary with factory
+                # functions for a specific number of arguments as
+                # needed:
+                factory = {1: pointwise_factory_1, 2: pointwise_factory_2}[nof_args]
+                setattr(cls, funcname, staticmethod(factory(impl)))
+
+
+@dataclasses.dataclass
+class OverridesData:
+    name: str
+    cpp: str
+    triton: Optional[str] = None  # None when not impl in libdevice/triton
+    cppvec: Optional[str] = None  # None when not impl in aten/.../vec
+    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
+        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+    )
+
+
+pointwise_overrides_data: Dict[str, OverridesData] = dict(
+    airy_ai=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="airy_ai_forward({x})",
+        name="special_airy_ai",
+    ),
+    bessel_j0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="bessel_j0_forward({x})",
+        triton="tl.math.j0({x})",
+        name="special_bessel_j0",
+    ),
+    bessel_j1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="bessel_j1_forward({x})",
+        triton="tl.math.j1({x})",
+        name="special_bessel_j1",
+    ),
+    bessel_y0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="bessel_y0_forward({x})",
+        triton="tl.math.y0({x})",
+        name="special_bessel_y0",
+    ),
+    bessel_y1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="bessel_y1_forward({x})",
+        triton="tl.math.y1({x})",
+        name="special_bessel_y1",
+    ),
+    digamma=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_digamma({x})",
+        cppvec="{x}.digamma()",
+        name="digamma",
+    ),
+    # entr has neither a cpp nor a triton implementation; it is defined as a decomposition
+    # erf, erfc
+    erfcx=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_erfcx({x})",
+        triton="tl.math.erfcx({x})",
+        name="special_erfcx",
+    ),
+    # erfinv, exp2, expit, gammaln
+    igamma=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_igamma({x}, {y})",
+        name="igamma",
+    ),
+    igammac=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_igammac({x}, {y})",
+        name="igammac",
+    ),
+    gammainc=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_igamma({x}, {y})",
+        name="special_gammainc",
+    ),
+    gammaincc=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_igammac({x}, {y})",
+        name="special_gammaincc",
+    ),
+    i0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_i0({x})",
+        triton="tl.math.cyl_bessel_i0({x})",
+        cppvec="{x}.i0()",
+        name="i0",
+    ),
+    i0e=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_i0e({x})",
+        cppvec="{x}.i0e()",
+        name="special_i0e",
+    ),
+    i1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_i1({x})",
+        triton="tl.math.cyl_bessel_i1({x})",
+        name="special_i1",
+    ),
+    i1e=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_i1e({x})",
+        name="special_i1e",
+    ),
+    log_ndtr=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_log_ndtr({x})",
+        name="special_log_ndtr",
+    ),
+    # logit
+    modified_bessel_i0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="modified_bessel_i0_forward({x})",
+        triton="tl.math.cyl_bessel_i0({x})",
+        name="special_modified_bessel_i0",
+    ),
+    modified_bessel_i1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="modified_bessel_i1_forward({x})",
+        triton="tl.math.cyl_bessel_i1({x})",
+        name="special_modified_bessel_i1",
+    ),
+    modified_bessel_k0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="modified_bessel_k0_forward({x})",
+        name="special_modified_bessel_k0",
+    ),
+    modified_bessel_k1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="modified_bessel_k1_forward({x})",
+        name="special_modified_bessel_k1",
+    ),
+    # multigamma
+    ndtr=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_ndtr({x})",
+        name="special_ndtr",
+    ),
+    ndtri=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_ndtri({x})",
+        name="special_ndtri",
+    ),
+    polygamma=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="calc_polygamma({y}, {x})",
+        name="polygamma",
+    ),
+    # psi - alias to digamma
+    # round
+    scaled_modified_bessel_k0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="scaled_modified_bessel_k0_forward({x})",
+        name="special_scaled_modified_bessel_k0",
+    ),
+    scaled_modified_bessel_k1=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="scaled_modified_bessel_k1_forward({x})",
+        name="special_scaled_modified_bessel_k1",
+    ),
+    # sinc
+    spherical_bessel_j0=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="spherical_bessel_j0_forward({x})",
+        name="special_spherical_bessel_j0",
+    ),
+    zeta=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="zeta({x}, {y})",
+        name="special_zeta",
+    ),
+    chebyshev_polynomial_t=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="chebyshev_polynomial_t_forward({x}, {y})",
+        name="special_chebyshev_polynomial_t",
+    ),
+    chebyshev_polynomial_u=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="chebyshev_polynomial_u_forward({x}, {y})",
+        name="special_chebyshev_polynomial_u",
+    ),
+    chebyshev_polynomial_v=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="chebyshev_polynomial_v_forward({x}, {y})",
+        name="special_chebyshev_polynomial_v",
+    ),
+    chebyshev_polynomial_w=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="chebyshev_polynomial_w_forward({x}, {y})",
+        name="special_chebyshev_polynomial_w",
+    ),
+    legendre_polynomial_p=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="legendre_polynomial_p_forward({x}, {y})",
+        name="special_legendre_polynomial_p",
+    ),
+    shifted_chebyshev_polynomial_t=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_t",
+    ),
+    shifted_chebyshev_polynomial_u=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_u",
+    ),
+    shifted_chebyshev_polynomial_v=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_v",
+    ),
+    shifted_chebyshev_polynomial_w=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})",
+        name="special_shifted_chebyshev_polynomial_w",
+    ),
+    hermite_polynomial_h=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="hermite_polynomial_h_forward({x}, {y})",
+        name="special_hermite_polynomial_h",
+    ),
+    hermite_polynomial_he=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="hermite_polynomial_he_forward({x}, {y})",
+        name="special_hermite_polynomial_he",
+    ),
+    laguerre_polynomial_l=OverridesData(
+        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        cpp="laguerre_polynomial_l_forward({x}, {y})",
+        name="special_laguerre_polynomial_l",
+    ),
+)
+
 
 # Use mypy to check protocol implemented correctly
 def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 87051cb..3a4cdd1 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -936,17 +936,8 @@
         V.kernel.compute.splice(code)
         return result
 
-    @staticmethod
-    def bessel_j0(x):
-        return f"bessel_j0_forward({x})"
 
-    @staticmethod
-    def bessel_j1(x):
-        return f"bessel_j1_forward({x})"
-
-    @staticmethod
-    def modified_bessel_i0(x):
-        return f"modified_bessel_i0_forward({x})"
+CppOverrides._initialize_pointwise_overrides("cpp")
 
 
 class CppVecOverrides(CppOverrides):
@@ -1454,6 +1445,9 @@
         return csevar
 
 
+CppVecOverrides._initialize_pointwise_overrides("cppvec")
+
+
 class CppTile2DOverrides(CppVecOverrides):
     @staticmethod
     def index_expr(expr, dtype):
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 321218f..ce77e85 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -846,17 +846,8 @@
     def ceil(x):
         return f"tl.math.ceil({x})"
 
-    @staticmethod
-    def bessel_j0(x):
-        return f"tl.math.j0({x})"
 
-    @staticmethod
-    def bessel_j1(x):
-        return f"tl.math.j1({x})"
-
-    @staticmethod
-    def modified_bessel_i0(x):
-        return f"tl.math.cyl_bessel_i0({x})"
+TritonOverrides._initialize_pointwise_overrides("triton")
 
 
 # Use mypy to check protocol implemented correctly
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 09d2188..54e1f1a 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -396,8 +396,13 @@
     override_fn_when_input_bool=None,
     override_fn_when_cuda_float64=None,
     allow_alpha=False,
+    triton_fallback=None,
 ):
     def inner(*inputs: List[TensorBox], alpha=None):
+        if triton_fallback is not None and any(map(is_triton, inputs)):
+            assert not allow_alpha  # not implemented
+            return triton_fallback(*inputs)
+
         inputs = promote_constants(inputs, override_return_dtype)
         if allow_alpha:
             if alpha is not None and alpha != 1:
@@ -605,6 +610,7 @@
     override_fn_when_input_bool=None,
     allow_alpha=False,
     use_libdevice_for_f64=False,
+    triton_fallback=None,
 ):
     """A pointwise function that maps ops.{name} to inputs"""
     name = name or aten_fn.__name__
@@ -620,6 +626,7 @@
         override_fn_when_input_bool=override_fn_when_input_bool,
         override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None,  # type: ignore[possibly-undefined]
         allow_alpha=allow_alpha,
+        triton_fallback=triton_fallback,
     )
     fn = register_lowering(
         aten_fn,
@@ -2233,7 +2240,6 @@
 make_fallback(aten._cdist_forward)
 make_fallback(aten.cummax)
 make_fallback(aten.cummin)
-make_fallback(aten.digamma, warn=False)
 make_fallback(aten._efficientzerotensor)
 make_fallback(aten._embedding_bag_per_sample_weights_backward)
 make_fallback(aten._efficientzerotensor)
@@ -2243,9 +2249,6 @@
 make_fallback(aten.frexp)
 make_fallback(aten.geqrf)
 make_fallback(aten.histc)
-make_fallback(aten.i0)
-make_fallback(aten.igamma, warn=False)
-make_fallback(aten.igammac, warn=False)
 make_fallback(aten.isin)
 make_fallback(aten.kthvalue)
 make_fallback(aten.linalg_cholesky_ex)
@@ -2274,33 +2277,12 @@
 make_fallback(aten.nanmedian)
 make_fallback(aten.ormqr)
 make_fallback(aten._pdist_forward)
-make_fallback(aten.polygamma)
 make_fallback(aten.put)
 make_fallback(aten.resize)
 make_fallback(aten.resize_)
 make_fallback(aten.resize_as)
 make_fallback(aten.resize_as_)
 make_fallback(aten.searchsorted)
-make_fallback(aten.special_airy_ai)
-make_fallback(aten.special_bessel_y0, warn=False)
-make_fallback(aten.special_bessel_y1)
-make_fallback(aten.special_chebyshev_polynomial_t)
-make_fallback(aten.special_chebyshev_polynomial_u)
-make_fallback(aten.special_erfcx, warn=False)
-make_fallback(aten.special_hermite_polynomial_h)
-make_fallback(aten.special_hermite_polynomial_he)
-make_fallback(aten.special_i0e, warn=False)
-make_fallback(aten.special_i1, warn=False)
-make_fallback(aten.special_i1e, warn=False)
-make_fallback(aten.special_laguerre_polynomial_l)
-make_fallback(aten.special_modified_bessel_i1)
-make_fallback(aten.special_modified_bessel_k0)
-make_fallback(aten.special_modified_bessel_k1)
-make_fallback(aten.special_ndtri, warn=False)
-make_fallback(aten.special_scaled_modified_bessel_k0)
-make_fallback(aten.special_scaled_modified_bessel_k1)
-make_fallback(aten.special_spherical_bessel_j0, warn=False)
-make_fallback(aten.special_zeta, warn=False)
 make_fallback(aten._trilinear)
 make_fallback(aten.uniform, warn=False)
 make_fallback(aten._adaptive_avg_pool3d_backward)
@@ -5097,11 +5079,12 @@
 )
 
 
-def register_pointwise_numeric(op, name=None):
+def register_pointwise_numeric(op, name=None, triton_fallback=None):
     return register_pointwise(
         op,
         name=name,
         type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+        triton_fallback=triton_fallback,
     )
 
 
@@ -5203,11 +5186,48 @@
 register_pointwise_numeric(aten.log10)
 register_pointwise_numeric(aten.nextafter)
 
-register_pointwise_numeric(aten.special_bessel_j0, name="bessel_j0")
-register_pointwise_numeric(prims.bessel_j0, name="bessel_j0")
-register_pointwise_numeric(aten.special_bessel_j1, name="bessel_j1")
-register_pointwise_numeric(prims.bessel_j1, name="bessel_j1")
-register_pointwise_numeric(aten.special_modified_bessel_i0, name="modified_bessel_i0")
+from .codegen.common import pointwise_overrides_data
+
+
+def _get_pointwise_overrides(ns, name):
+    data = pointwise_overrides_data[name]
+    op = getattr(ns, data.name, None)
+    if op is None:
+        return
+
+    def make_triton_fallback(op):
+        if data.triton is None:
+            return fallback_handler(op)
+
+    if isinstance(op, torch._ops.OpOverloadPacket):
+        for olname in op.overloads():
+            ol = getattr(op, olname)
+            yield ol, data.type_promotion_kind, make_triton_fallback(ol)
+    else:
+        yield op, data.type_promotion_kind, make_triton_fallback(op)
+
+
+for name in pointwise_overrides_data:
+    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
+        aten, name
+    ):
+        register_pointwise(
+            op,
+            name=name,
+            type_promotion_kind=type_promotion_kind,
+            triton_fallback=triton_fallback,
+        )
+
+    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
+        prims, name
+    ):
+        register_pointwise(
+            op,
+            name=name,
+            type_promotion_kind=type_promotion_kind,
+            triton_fallback=triton_fallback,
+        )
+
 
 foreach_add_list = register_foreach_pointwise(
     aten._foreach_add.List, add, allow_alpha=True
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index a837161..62cde04c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -6206,9 +6206,16 @@
 
 _create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
 _create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
+_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
+_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
 _create_binary_float_meta_func(aten.special_hermite_polynomial_h)
 _create_binary_float_meta_func(aten.special_hermite_polynomial_he)
 _create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
+_create_binary_float_meta_func(aten.special_legendre_polynomial_p)
 
 
 # We must also trigger meta registrations from PrimTorch ref
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index aaa3ffe..d18e24b 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -739,7 +739,7 @@
 
 # TODO: if this is special maybe it should be defined there and imported here?
 @_make_elementwise_unary_reference(
-    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.special_i0
+    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0
 )
 def i0(a):
     return prims.bessel_i0(a)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8d3a721..acff31b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -16170,98 +16170,43 @@
                    skips=(
                        DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
                    ),
-                   sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
-    UnaryUfuncInfo('polygamma',
-                   op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
-                   variant_test_name='polygamma_n_1',
-                   ref=reference_polygamma if TEST_SCIPY else None,
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
-                   supports_forward_ad=True,
-                   supports_fwgrad_bwgrad=True,
-                   promotes_int_to_float=True,
-                   sample_inputs_func=sample_inputs_polygamma,
-                   skips=(
-                       # Redundant tests
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
-                       # Mismatch: https://github.com/pytorch/pytorch/issues/55357
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'),
-                   ),
-                   sample_kwargs=lambda device, dtype, input: ({'n': 1}, {'n': 1}),
-                   # polygamma functions have multiple singularities at x <= 0
-                   reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
-    UnaryUfuncInfo('polygamma',
-                   op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
-                   variant_test_name='polygamma_n_2',
-                   ref=reference_polygamma if TEST_SCIPY else None,
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
-                   supports_forward_ad=True,
-                   supports_fwgrad_bwgrad=True,
-                   promotes_int_to_float=True,
-                   sample_inputs_func=sample_inputs_polygamma,
-                   skips=(
-                       # Redundant tests
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
-                       # Mismatch: https://github.com/pytorch/pytorch/issues/55357
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
-                   sample_kwargs=lambda device, dtype, input: ({'n': 2}, {'n': 2}),
-                   # polygamma functions have multiple singularities at x <= 0
-                   reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
-    UnaryUfuncInfo('polygamma',
-                   op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
-                   variant_test_name='polygamma_n_3',
-                   ref=reference_polygamma if TEST_SCIPY else None,
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
-                   supports_forward_ad=True,
-                   supports_fwgrad_bwgrad=True,
-                   promotes_int_to_float=True,
-                   sample_inputs_func=sample_inputs_polygamma,
-                   skips=(
-                       # Redundant tests
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
-                       # Mismatch: https://github.com/pytorch/pytorch/issues/55357
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
-                   sample_kwargs=lambda device, dtype, input: ({'n': 3}, {'n': 3}),
-                   # polygamma functions have multiple singularities at x <= 0
-                   reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
-    UnaryUfuncInfo('polygamma',
-                   op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
-                   variant_test_name='polygamma_n_4',
-                   ref=reference_polygamma if TEST_SCIPY else None,
-                   decorators=(precisionOverride({torch.float16: 5e-4, torch.float32: 5e-4}),),
-                   dtypes=all_types_and(torch.bool, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
-                   supports_forward_ad=True,
-                   supports_fwgrad_bwgrad=True,
-                   promotes_int_to_float=True,
-                   sample_inputs_func=sample_inputs_polygamma,
-                   skips=(
-                       # Redundant tests
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
-                       # Mismatch: https://github.com/pytorch/pytorch/issues/55357
-                       DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
-                   sample_kwargs=lambda device, dtype, input: ({'n': 4}, {'n': 4}),
-                   # polygamma functions have multiple singularities at x <= 0
-                   reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
+                   sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}),
+                   # polygamma functions have multiple singularities at non-positive integer values of x
+                   reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
+                                                            safe_val=1)),
+    *(UnaryUfuncInfo('polygamma',
+                     op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
+                     variant_test_name=f'polygamma_n_{n_}',
+                     ref=reference_polygamma if TEST_SCIPY else None,
+                     dtypes=all_types_and(torch.bool, torch.bfloat16),
+                     dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+                     supports_forward_ad=True,
+                     supports_fwgrad_bwgrad=True,
+                     promotes_int_to_float=True,
+                     sample_inputs_func=sample_inputs_polygamma,
+                     decorators=(
+                         DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'),
+                         DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1),
+                                                         torch.float32: tol(atol=1e-4, rtol=1e-2)}),
+                                      'TestUnaryUfuncs', 'test_reference_numerics_normal',
+                                      active_if=IS_WINDOWS),
+                     ),
+                     skips=(
+                         # Redundant tests
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
+                         # Mismatch: https://github.com/pytorch/pytorch/issues/55357
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
+                         DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'),
+                     ),
+                     sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}),
+                     # polygamma functions have multiple singularities at non-positive integer values of x
+                     reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
+                                                              safe_val=1))
+      for n_ in (1, 2, 3, 4)),
     OpInfo('ravel',
            ref=np.ravel,
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 545d28b..0357e9f 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -4138,6 +4138,10 @@
                ),
     ModuleInfo(torch.nn.Embedding,
                module_inputs_func=module_inputs_torch_nn_Embedding,
+               decorators=[
+                   DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+                                'TestModule', 'test_non_contiguous_tensors',
+                                device_type='mps')],
                skips=(
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
                ),
diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py
index 177eb4c..5bea6db 100644
--- a/torch/testing/_internal/dynamo_test_failures.py
+++ b/torch/testing/_internal/dynamo_test_failures.py
@@ -1488,7 +1488,6 @@
     "TestTorchDeviceTypeCPU.test_broadcast_fn_div_cpu",  # test_torch
     "TestTorchDeviceTypeCPU.test_nondeterministic_resize_quantized_cpu_quint8",  # test_torch
     "TestTorchDeviceTypeCPU.test_broadcast_fn_lt_cpu",  # test_torch
-    "TestTorchDeviceTypeCPU.test_memory_format_operators_cpu",  # test_torch
     "TestTorch.test_pin_memory",  # test_torch
     "TestTorchDeviceTypeCPU.test_broadcast_fn_masked_fill_cpu",  # test_torch
     "TestTorchDeviceTypeCPU.test_nondeterministic_alert_MaxUnpool2d_cpu_float64",  # test_torch
diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py
index cc1e41d..426eb7e 100644
--- a/torch/testing/_internal/opinfo/definitions/special.py
+++ b/torch/testing/_internal/opinfo/definitions/special.py
@@ -65,7 +65,12 @@
 
 def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(
-        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+        make_tensor,
+        device=device,
+        # TODO: eliminate low after gh-106692 is fixed:
+        low=(1 if dtype in {torch.int32, torch.int64} else None),
+        dtype=dtype,
+        requires_grad=requires_grad,
     )
     tensor_shapes = ((S, S), ())
     ns = (1, 2, 3, 4, 5)
@@ -101,6 +106,19 @@
     yield SampleInput(make_arg(()))
 
 
+def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
+    for shape in ((L,), (1, 0, 3), ()):
+        yield SampleInput(
+            make_tensor(
+                shape,
+                device=device,
+                dtype=dtype,
+                low=-5,
+                requires_grad=requires_grad,
+            ),
+        )
+
+
 op_db: List[OpInfo] = [
     UnaryUfuncInfo(
         "special.i0e",
@@ -195,9 +213,9 @@
             ),
         ),
         sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
-        # polygamma functions have multiple singularities at x <= 0
+        # polygamma functions have multiple singularities at non-positive integer values of x
         reference_numerics_filter=NumericsFilter(
-            condition=lambda x: x < 0.1, safe_val=1
+            condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1
         ),
     ),
     BinaryUfuncInfo(
@@ -293,6 +311,7 @@
         dtypes=all_types_and(torch.bool),
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
+        sample_inputs_func=sample_inputs_erfcx,
     ),
     UnaryUfuncInfo(
         "special.airy_ai",