Add lowerings to special functions (#119187)
This PR adds Inductor lowerings for the pointwise functions in `torch.special` that previously fell back to ATen (the Bessel family, `erfcx`, `ndtr`/`ndtri`/`log_ndtr`, `polygamma`, `zeta`, `gammainc`/`gammaincc`, and the orthogonal-polynomial helpers, among others).
In addition, the PR introduces infrastructure for registering lowerings of pointwise functions from a single table (`pointwise_overrides_data`): each entry carries the cpp implementation and, where available, the cpp-vectorized and triton implementations, and ops without a triton implementation keep an ATen fallback on CUDA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119187
Approved by: https://github.com/peterbell10
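
A minimal usage illustration (not part of the patch; assumes a PyTorch build with Inductor and a working host compiler): after this change, compiling a function that calls these ops generates pointwise kernels instead of hitting `make_fallback`.

```python
# Illustrative only: these ops previously went through make_fallback and now
# get pointwise lowerings (cpp codegen on CPU, triton or an ATen fallback on CUDA).
import torch

@torch.compile
def fn(x):
    return torch.special.erfcx(x) + torch.special.spherical_bessel_j0(x)

print(fn(torch.rand(8, 8) + 1.0))
```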
diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect
index 359acea..925acb1 100644
--- a/test/expect/HasDecompTest.test_aten_core_operators.expect
+++ b/test/expect/HasDecompTest.test_aten_core_operators.expect
@@ -267,6 +267,8 @@
aten::hypot
aten::hypot.out
aten::hypot_
+aten::i0
+aten::i0.out
aten::i0_
aten::igamma
aten::igamma.out
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 8cbae2f..c0f9255 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -816,8 +816,6 @@
aten::histogram.bins_tensor_out
aten::hspmm
aten::hspmm.out
-aten::i0
-aten::i0.out
aten::index.Tensor
aten::index.Tensor_out
aten::index_put
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index a2c9d85..9cb59d6 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -4349,7 +4349,6 @@
symbolic_aot_autograd_failures = {
xfail('combinations', ''), # aten.masked_select.default
xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
- xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.eigvals', ''), # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition
@@ -4369,7 +4368,6 @@
xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('_segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
- xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('_upsample_bilinear2d_aa'), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
decorate('linalg.householder_product', decorator=unittest.skipIf(IS_MACOS and IS_X86, 'flaky')),
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 8c50699..d3f2014 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1559,9 +1559,42 @@
cpp_op_list.append(k)
diff = [
+ "airy_ai",
"bessel_j0",
"bessel_j1",
+ "bessel_y0",
+ "bessel_y1",
"modified_bessel_i0",
+ "modified_bessel_i1",
+ "modified_bessel_k0",
+ "modified_bessel_k1",
+ "scaled_modified_bessel_k0",
+ "scaled_modified_bessel_k1",
+ "spherical_bessel_j0",
+ "i1",
+ "i1e",
+ "ndtr",
+ "ndtri",
+ "log_ndtr",
+ "erfcx",
+ "gammainc",
+ "gammaincc",
+ "igamma",
+ "igammac",
+ "polygamma",
+ "zeta",
+ "shifted_chebyshev_polynomial_u",
+ "chebyshev_polynomial_u",
+ "chebyshev_polynomial_t",
+ "shifted_chebyshev_polynomial_w",
+ "chebyshev_polynomial_w",
+ "shifted_chebyshev_polynomial_t",
+ "chebyshev_polynomial_v",
+ "shifted_chebyshev_polynomial_v",
+ "hermite_polynomial_he",
+ "laguerre_polynomial_l",
+ "hermite_polynomial_h",
+ "legendre_polynomial_p",
"constant",
"index_expr",
"signbit",
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 28daaeb..94d61a9 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -62,12 +62,15 @@
from torch.testing._internal.common_dtype import all_types, get_all_dtypes
from torch.testing._internal.common_utils import (
DeterministicGuard,
+ instantiate_parametrized_tests,
IS_CI,
IS_FBCODE,
IS_MACOS,
IS_WINDOWS,
IS_X86,
+ parametrize,
skipIfRocm,
+ subtest,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TestCase as TorchTestCase,
@@ -676,6 +679,7 @@
cls.gen_template(name1, name2)
+@instantiate_parametrized_tests
class CommonTemplate:
def test_bool(self):
def fn(a, b):
@@ -8864,23 +8868,107 @@
# should_pad_bench always returns False if has_triton returns False
self.assertFalse(should_pad)
- def test_bessel_j0(self):
- def fn(x):
- return torch.special.bessel_j0(x)
+ @parametrize(
+ "name, op",
+ [
+ subtest((name, getattr(torch.special, name)), name=name)
+ for name in torch.special.__all__
+ if name not in {"softmax", "log_softmax", "logsumexp"}
+ ],
+ )
+ def test_pointwise(self, name, op):
+ dtype = torch.float32
+ check_lowp = True
+ if self.device == "cuda" and name in {
+ "airy_ai",
+ "bessel_i0",
+ "bessel_i1",
+ "bessel_j0",
+ "bessel_j1",
+ "bessel_y0",
+ "bessel_y1",
+ "erfcx",
+ "gammainc",
+ "gammaincc",
+ "i1",
+ "i1e",
+ "modified_bessel_i0",
+ "modified_bessel_i1",
+ "modified_bessel_k0",
+ "modified_bessel_k1",
+ "ndtri",
+ "scaled_modified_bessel_k0",
+ "scaled_modified_bessel_k1",
+ "spherical_bessel_j0",
+ "zeta",
+ "chebyshev_polynomial_t",
+ "chebyshev_polynomial_v",
+ "chebyshev_polynomial_u",
+ "chebyshev_polynomial_w",
+ "legendre_polynomial_p",
+ "shifted_chebyshev_polynomial_t",
+ "shifted_chebyshev_polynomial_u",
+ "shifted_chebyshev_polynomial_v",
+ "shifted_chebyshev_polynomial_w",
+ "hermite_polynomial_h",
+ "hermite_polynomial_he",
+ "laguerre_polynomial_l",
+ }:
+ # <func>_cuda not implemented for Half
+ check_lowp = False
- self.common(fn, (torch.randn(8, 8),))
+ if name in {"gammainc", "gammaincc"}:
+ args = (
+ torch.randn(8, 8, dtype=dtype, device=self.device),
+ torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
+ )
- def test_bessel_j1(self):
- def fn(x):
- return torch.special.bessel_j1(x)
+ def fn(x, y):
+ return op(x, y)
- self.common(fn, (torch.randn(8, 8),))
+ elif name in {"xlog1py", "xlogy", "zeta"}:
+ args = (
+ torch.randn(8, 8, dtype=dtype, device=self.device),
+ torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
+ )
- def test_modified_bessel_i0(self):
- def fn(x):
- return torch.special.modified_bessel_i0(x)
+ def fn(x, y):
+ return op(x, y)
- self.common(fn, (torch.randn(8, 8),))
+ elif name == "multigammaln":
+ args = (
+ torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2),
+ 2,
+ )
+
+ def fn(x, p):
+ return op(x, p)
+
+ elif name == "polygamma":
+ args = (
+ 1,
+ torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 10),
+ )
+
+ def fn(n, x):
+ return op(n, x)
+
+ elif "_polynomial_" in name:
+ args = (
+ torch.randn(8, 8, dtype=dtype, device=self.device),
+ 2,
+ )
+
+ def fn(x, n):
+ return op(x, n)
+
+ else:
+ args = (torch.randn(8, 8, dtype=dtype, device=self.device),)
+
+ def fn(x):
+ return op(x)
+
+ self.common(fn, args, check_lowp=check_lowp)
@dataclasses.dataclass
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index 0149387..815e215 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -183,6 +183,41 @@
"test_new_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_new_ones_dynamic_shapes": TestFailure(("cpu",)),
"test_permute2_dynamic_shapes": TestFailure(("cpu", "cuda")),
+ "test_pointwise_airy_ai_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_digamma_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_gammainc_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_gammaincc_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_i0e_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_i1e_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_modified_bessel_k0_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_modified_bessel_k1_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_ndtri_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_polygamma_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_psi_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_scaled_modified_bessel_k0_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_scaled_modified_bessel_k1_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_spherical_bessel_j0_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_zeta_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_chebyshev_polynomial_t_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_chebyshev_polynomial_u_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_chebyshev_polynomial_v_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_chebyshev_polynomial_w_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_shifted_chebyshev_polynomial_t_dynamic_shapes": TestFailure(
+ ("cuda",)
+ ),
+ "test_pointwise_shifted_chebyshev_polynomial_u_dynamic_shapes": TestFailure(
+ ("cuda",)
+ ),
+ "test_pointwise_shifted_chebyshev_polynomial_v_dynamic_shapes": TestFailure(
+ ("cuda",)
+ ),
+ "test_pointwise_shifted_chebyshev_polynomial_w_dynamic_shapes": TestFailure(
+ ("cuda",)
+ ),
+ "test_pointwise_hermite_polynomial_h_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_hermite_polynomial_he_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_laguerre_polynomial_l_dynamic_shapes": TestFailure(("cuda",)),
+ "test_pointwise_legendre_polynomial_p_dynamic_shapes": TestFailure(("cuda",)),
"test_randn_generator_dynamic_shapes": TestFailure(("cpu",)),
"test_randn_like_empty_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_single_elem_dynamic_shapes": TestFailure(("cpu",)),
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index 11173fa..f1c9ba5 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -200,6 +200,7 @@
if TEST_WITH_ROCM:
# Tensors are not alike
inductor_skips["cuda"]["logcumsumexp"] = {f32}
+ inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
inductor_expected_failures_single_sample = defaultdict(dict)
@@ -349,6 +350,15 @@
("softmax", "cuda", f16): {"atol": 1e-4, "rtol": 0.02},
("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002},
("special.log_ndtr", "cuda", f64): {"atol": 1e-6, "rtol": 1e-5},
+ ("polygamma.polygamma_n_0", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+ ("polygamma.polygamma_n_1", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+ ("polygamma.polygamma_n_2", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+ ("polygamma.polygamma_n_3", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+ ("polygamma.polygamma_n_4", "cpu", f32): {"atol": 1e-3, "rtol": 1e-4},
+ ("special.polygamma.special_polygamma_n_0", "cpu", f32): {
+ "atol": 1e-3,
+ "rtol": 1e-4,
+ },
("std_mean.unbiased", "cuda", f16): {"reference_in_float": True},
("uniform", "cuda"): {"reference_in_float": True},
# Following tests are failing with strict comparision but atol=1 is acceptable due roundings errors
@@ -497,7 +507,6 @@
overridden_kwargs = inductor_override_kwargs[(op_name, device_type)]
elif (op_name, device_type, dtype) in inductor_override_kwargs:
overridden_kwargs = inductor_override_kwargs[(op_name, device_type, dtype)]
-
func = op.get_op()
def fn(*args, **kwargs):
diff --git a/test/test_mps.py b/test/test_mps.py
index 3dc3f26..62d8562 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -552,13 +552,13 @@
# - MPS output: tensor([102.6681, inf])
# In the latter case, inf is probably correct (this is what scipy does).
'polygamma': [torch.float32, torch.uint8],
- 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
+ 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
+ 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
+ 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],
# Failures due to precision issues (due to fast-math). These has been fixed in MacOS 13.3+
'tan': [torch.float32],
@@ -615,13 +615,13 @@
# - MPS output: tensor([102.6681, inf])
# In the latter case, inf is probably correct (this is what scipy does).
'polygamma': [torch.float32, torch.uint8],
- 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
- 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int32, torch.int64, torch.int8],
+ 'polygammapolygamma_n_0': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_2': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_1': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_3': [torch.float32, torch.int16, torch.int8],
+ 'polygammapolygamma_n_4': [torch.float32, torch.int16, torch.int8],
+ 'special.polygamma': [torch.float32, torch.int16, torch.int32, torch.int8],
+ 'special.polygammaspecial_polygamma_n_0': [torch.float32, torch.int16, torch.int8],
}
# Those ops are not expected to work
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index a5e923e..5e816cb 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1931,8 +1931,6 @@
symbolic_tensor_failures.update(symbolic_tensor_segfaults)
outplace_symbolic_tensor_failures = {
- xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
-
xfail('linalg.norm', ''),
}
@@ -1957,7 +1955,6 @@
xfail('fft.ifft2', ''),
xfail('fft.ifftn', ''),
xfail('gather', ''),
- xfail('i0', ''),
xfail('linalg.cholesky', ''),
xfail('linalg.cholesky_ex', ''),
xfail('linalg.det', ''),
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 038a77b..38eedb5 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -26,6 +26,7 @@
import torch
import torch.fx
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils._sympy.value_ranges import ValueRanges
from .. import config, metrics
@@ -523,6 +524,262 @@
def load_seed(name, offset):
return ops.load(name, sympy.Integer(offset))
+ @classmethod
+ def _initialize_pointwise_overrides(cls, target):
+ assert target in {"triton", "cpp", "cppvec"}, target
+
+ def pointwise_factory_1(impl):
+ def func(x):
+ return impl.format(x=x)
+
+ return func
+
+ def pointwise_factory_2(impl):
+ def func(x, y):
+ return impl.format(x=x, y=y)
+
+ return func
+
+ for funcname, data in pointwise_overrides_data.items():
+ impl = getattr(data, target)
+ if isinstance(impl, str):
+ nof_args = 2 if "{y}" in impl else 1
+ # extend the following dictionary with factory
+ # functions for a specific number of arguments as
+ # needed:
+ factory = {1: pointwise_factory_1, 2: pointwise_factory_2}[nof_args]
+ setattr(cls, funcname, staticmethod(factory(impl)))
+
+
+@dataclasses.dataclass
+class OverridesData:
+ name: str
+ cpp: str
+ triton: Optional[str] = None # None when not impl in libdevice/triton
+ cppvec: Optional[str] = None # None when not impl in aten/.../vec
+ type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
+ ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ )
+
+
+pointwise_overrides_data: Dict[str, OverridesData] = dict(
+ airy_ai=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="airy_ai_forward({x})",
+ name="special_airy_ai",
+ ),
+ bessel_j0=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="bessel_j0_forward({x})",
+ triton="tl.math.j0({x})",
+ name="special_bessel_j0",
+ ),
+ bessel_j1=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="bessel_j1_forward({x})",
+ triton="tl.math.j1({x})",
+ name="special_bessel_j1",
+ ),
+ bessel_y0=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="bessel_y0_forward({x})",
+ triton="tl.math.y0({x})",
+ name="special_bessel_y0",
+ ),
+ bessel_y1=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="bessel_y1_forward({x})",
+ triton="tl.math.y1({x})",
+ name="special_bessel_y1",
+ ),
+ digamma=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_digamma({x})",
+ cppvec="{x}.digamma()",
+ name="digamma",
+ ),
+ # no cpp nor triton implementation for entr, it is defined as decomposition
+ # erf, erfc
+ erfcx=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_erfcx({x})",
+ triton="tl.math.erfcx({x})",
+ name="special_erfcx",
+ ),
+ # erfinv, exp2, expit, gammaln
+ igamma=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_igamma({x}, {y})",
+ name="igamma",
+ ),
+ igammac=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_igammac({x}, {y})",
+ name="igammac",
+ ),
+ gammainc=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_igamma({x}, {y})",
+ name="special_gammainc",
+ ),
+ gammaincc=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_igammac({x}, {y})",
+ name="special_gammaincc",
+ ),
+ i0=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_i0({x})",
+ triton="tl.math.cyl_bessel_i0({x})",
+ cppvec="{x}.i0()",
+ name="i0",
+ ),
+ i0e=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_i0e({x})",
+ cppvec="{x}.i0e()",
+ name="special_i0e",
+ ),
+ i1=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_i1({x})",
+ triton="tl.math.cyl_bessel_i1({x})",
+ name="special_i1",
+ ),
+ i1e=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_i1e({x})",
+ name="special_i1e",
+ ),
+ log_ndtr=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_log_ndtr({x})",
+ name="special_log_ndtr",
+ ),
+ # logit
+ modified_bessel_i0=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="modified_bessel_i0_forward({x})",
+ triton="tl.math.cyl_bessel_i0({x})",
+ name="special_modified_bessel_i0",
+ ),
+ modified_bessel_i1=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="modified_bessel_i1_forward({x})",
+ triton="tl.math.cyl_bessel_i1({x})",
+ name="special_modified_bessel_i1",
+ ),
+ modified_bessel_k0=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="modified_bessel_k0_forward({x})",
+ name="special_modified_bessel_k0",
+ ),
+ modified_bessel_k1=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="modified_bessel_k1_forward({x})",
+ name="special_modified_bessel_k1",
+ ),
+ # multigamma
+ ndtr=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_ndtr({x})",
+ name="special_ndtr",
+ ),
+ ndtri=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_ndtri({x})",
+ name="special_ndtri",
+ ),
+ polygamma=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="calc_polygamma({y}, {x})",
+ name="polygamma",
+ ),
+ # psi - alias to digamma
+ # round
+ scaled_modified_bessel_k0=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="scaled_modified_bessel_k0_forward({x})",
+ name="special_scaled_modified_bessel_k0",
+ ),
+ scaled_modified_bessel_k1=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="scaled_modified_bessel_k1_forward({x})",
+ name="special_scaled_modified_bessel_k1",
+ ),
+ # sinc
+ spherical_bessel_j0=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="spherical_bessel_j0_forward({x})",
+ name="special_spherical_bessel_j0",
+ ),
+ zeta=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="zeta({x}, {y})",
+ name="special_zeta",
+ ),
+ chebyshev_polynomial_t=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="chebyshev_polynomial_t_forward({x}, {y})",
+ name="special_chebyshev_polynomial_t",
+ ),
+ chebyshev_polynomial_u=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="chebyshev_polynomial_u_forward({x}, {y})",
+ name="special_chebyshev_polynomial_u",
+ ),
+ chebyshev_polynomial_v=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="chebyshev_polynomial_v_forward({x}, {y})",
+ name="special_chebyshev_polynomial_v",
+ ),
+ chebyshev_polynomial_w=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="chebyshev_polynomial_w_forward({x}, {y})",
+ name="special_chebyshev_polynomial_w",
+ ),
+ legendre_polynomial_p=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="legendre_polynomial_p_forward({x}, {y})",
+ name="special_legendre_polynomial_p",
+ ),
+ shifted_chebyshev_polynomial_t=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})",
+ name="special_shifted_chebyshev_polynomial_t",
+ ),
+ shifted_chebyshev_polynomial_u=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})",
+ name="special_shifted_chebyshev_polynomial_u",
+ ),
+ shifted_chebyshev_polynomial_v=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})",
+ name="special_shifted_chebyshev_polynomial_v",
+ ),
+ shifted_chebyshev_polynomial_w=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})",
+ name="special_shifted_chebyshev_polynomial_w",
+ ),
+ hermite_polynomial_h=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="hermite_polynomial_h_forward({x}, {y})",
+ name="special_hermite_polynomial_h",
+ ),
+ hermite_polynomial_he=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="hermite_polynomial_he_forward({x}, {y})",
+ name="special_hermite_polynomial_he",
+ ),
+ laguerre_polynomial_l=OverridesData(
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ cpp="laguerre_polynomial_l_forward({x}, {y})",
+ name="special_laguerre_polynomial_l",
+ ),
+)
+
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
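
For orientation, a standalone sketch of how the `pointwise_overrides_data` table above is consumed by `_initialize_pointwise_overrides` (simplified: type promotion is dropped and only a two-entry table is shown; in the real code the generated staticmethods are attached to the `OpOverrides` subclasses in `cpp.py` and `triton.py` below).

```python
# Simplified, self-contained model of the override machinery in common.py.
import dataclasses
from typing import Dict, Optional


@dataclasses.dataclass
class _OverridesData:
    name: str
    cpp: str
    triton: Optional[str] = None   # None when not implemented in libdevice/triton
    cppvec: Optional[str] = None   # None when not implemented in aten/.../vec


_table: Dict[str, _OverridesData] = {
    "i0": _OverridesData(name="i0", cpp="calc_i0({x})",
                         triton="tl.math.cyl_bessel_i0({x})", cppvec="{x}.i0()"),
    "zeta": _OverridesData(name="special_zeta", cpp="zeta({x}, {y})"),
}


class _Overrides:
    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        def factory_1(impl):
            return lambda x: impl.format(x=x)

        def factory_2(impl):
            return lambda x, y: impl.format(x=x, y=y)

        for funcname, data in _table.items():
            impl = getattr(data, target)
            if isinstance(impl, str):
                # Pick the factory matching the arity of the template string.
                factory = factory_2 if "{y}" in impl else factory_1
                setattr(cls, funcname, staticmethod(factory(impl)))


class _Cpp(_Overrides):
    pass


_Cpp._initialize_pointwise_overrides("cpp")
print(_Cpp.i0("in_ptr0[i0]"))     # -> calc_i0(in_ptr0[i0])
print(_Cpp.zeta("tmp0", "tmp1"))  # -> zeta(tmp0, tmp1)
```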
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 87051cb..3a4cdd1 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -936,17 +936,8 @@
V.kernel.compute.splice(code)
return result
- @staticmethod
- def bessel_j0(x):
- return f"bessel_j0_forward({x})"
- @staticmethod
- def bessel_j1(x):
- return f"bessel_j1_forward({x})"
-
- @staticmethod
- def modified_bessel_i0(x):
- return f"modified_bessel_i0_forward({x})"
+CppOverrides._initialize_pointwise_overrides("cpp")
class CppVecOverrides(CppOverrides):
@@ -1454,6 +1445,9 @@
return csevar
+CppVecOverrides._initialize_pointwise_overrides("cppvec")
+
+
class CppTile2DOverrides(CppVecOverrides):
@staticmethod
def index_expr(expr, dtype):
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 321218f..ce77e85 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -846,17 +846,8 @@
def ceil(x):
return f"tl.math.ceil({x})"
- @staticmethod
- def bessel_j0(x):
- return f"tl.math.j0({x})"
- @staticmethod
- def bessel_j1(x):
- return f"tl.math.j1({x})"
-
- @staticmethod
- def modified_bessel_i0(x):
- return f"tl.math.cyl_bessel_i0({x})"
+TritonOverrides._initialize_pointwise_overrides("triton")
# Use mypy to check protocol implemented correctly
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 09d2188..54e1f1a 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -396,8 +396,13 @@
override_fn_when_input_bool=None,
override_fn_when_cuda_float64=None,
allow_alpha=False,
+ triton_fallback=None,
):
def inner(*inputs: List[TensorBox], alpha=None):
+ if triton_fallback is not None and any(map(is_triton, inputs)):
+ assert not allow_alpha # not implemented
+ return triton_fallback(*inputs)
+
inputs = promote_constants(inputs, override_return_dtype)
if allow_alpha:
if alpha is not None and alpha != 1:
@@ -605,6 +610,7 @@
override_fn_when_input_bool=None,
allow_alpha=False,
use_libdevice_for_f64=False,
+ triton_fallback=None,
):
"""A pointwise function that maps ops.{name} to inputs"""
name = name or aten_fn.__name__
@@ -620,6 +626,7 @@
override_fn_when_input_bool=override_fn_when_input_bool,
override_fn_when_cuda_float64=fn_libdevice if use_libdevice_for_f64 else None, # type: ignore[possibly-undefined]
allow_alpha=allow_alpha,
+ triton_fallback=triton_fallback,
)
fn = register_lowering(
aten_fn,
@@ -2233,7 +2240,6 @@
make_fallback(aten._cdist_forward)
make_fallback(aten.cummax)
make_fallback(aten.cummin)
-make_fallback(aten.digamma, warn=False)
make_fallback(aten._efficientzerotensor)
make_fallback(aten._embedding_bag_per_sample_weights_backward)
make_fallback(aten._efficientzerotensor)
@@ -2243,9 +2249,6 @@
make_fallback(aten.frexp)
make_fallback(aten.geqrf)
make_fallback(aten.histc)
-make_fallback(aten.i0)
-make_fallback(aten.igamma, warn=False)
-make_fallback(aten.igammac, warn=False)
make_fallback(aten.isin)
make_fallback(aten.kthvalue)
make_fallback(aten.linalg_cholesky_ex)
@@ -2274,33 +2277,12 @@
make_fallback(aten.nanmedian)
make_fallback(aten.ormqr)
make_fallback(aten._pdist_forward)
-make_fallback(aten.polygamma)
make_fallback(aten.put)
make_fallback(aten.resize)
make_fallback(aten.resize_)
make_fallback(aten.resize_as)
make_fallback(aten.resize_as_)
make_fallback(aten.searchsorted)
-make_fallback(aten.special_airy_ai)
-make_fallback(aten.special_bessel_y0, warn=False)
-make_fallback(aten.special_bessel_y1)
-make_fallback(aten.special_chebyshev_polynomial_t)
-make_fallback(aten.special_chebyshev_polynomial_u)
-make_fallback(aten.special_erfcx, warn=False)
-make_fallback(aten.special_hermite_polynomial_h)
-make_fallback(aten.special_hermite_polynomial_he)
-make_fallback(aten.special_i0e, warn=False)
-make_fallback(aten.special_i1, warn=False)
-make_fallback(aten.special_i1e, warn=False)
-make_fallback(aten.special_laguerre_polynomial_l)
-make_fallback(aten.special_modified_bessel_i1)
-make_fallback(aten.special_modified_bessel_k0)
-make_fallback(aten.special_modified_bessel_k1)
-make_fallback(aten.special_ndtri, warn=False)
-make_fallback(aten.special_scaled_modified_bessel_k0)
-make_fallback(aten.special_scaled_modified_bessel_k1)
-make_fallback(aten.special_spherical_bessel_j0, warn=False)
-make_fallback(aten.special_zeta, warn=False)
make_fallback(aten._trilinear)
make_fallback(aten.uniform, warn=False)
make_fallback(aten._adaptive_avg_pool3d_backward)
@@ -5097,11 +5079,12 @@
)
-def register_pointwise_numeric(op, name=None):
+def register_pointwise_numeric(op, name=None, triton_fallback=None):
return register_pointwise(
op,
name=name,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ triton_fallback=triton_fallback,
)
@@ -5203,11 +5186,48 @@
register_pointwise_numeric(aten.log10)
register_pointwise_numeric(aten.nextafter)
-register_pointwise_numeric(aten.special_bessel_j0, name="bessel_j0")
-register_pointwise_numeric(prims.bessel_j0, name="bessel_j0")
-register_pointwise_numeric(aten.special_bessel_j1, name="bessel_j1")
-register_pointwise_numeric(prims.bessel_j1, name="bessel_j1")
-register_pointwise_numeric(aten.special_modified_bessel_i0, name="modified_bessel_i0")
+from .codegen.common import pointwise_overrides_data
+
+
+def _get_pointwise_overrides(ns, name):
+ data = pointwise_overrides_data[name]
+ op = getattr(ns, data.name, None)
+ if op is None:
+ return
+
+ def make_triton_fallback(op):
+ if data.triton is None:
+ return fallback_handler(op)
+
+ if isinstance(op, torch._ops.OpOverloadPacket):
+ for olname in op.overloads():
+ ol = getattr(op, olname)
+ yield ol, data.type_promotion_kind, make_triton_fallback(ol)
+ else:
+ yield op, data.type_promotion_kind, make_triton_fallback(op)
+
+
+for name in pointwise_overrides_data:
+ for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
+ aten, name
+ ):
+ register_pointwise(
+ op,
+ name=name,
+ type_promotion_kind=type_promotion_kind,
+ triton_fallback=triton_fallback,
+ )
+
+ for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
+ prims, name
+ ):
+ register_pointwise(
+ op,
+ name=name,
+ type_promotion_kind=type_promotion_kind,
+ triton_fallback=triton_fallback,
+ )
+
foreach_add_list = register_foreach_pointwise(
aten._foreach_add.List, add, allow_alpha=True
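
The lowering-side change above threads a `triton_fallback` through `make_pointwise` and `register_pointwise`, so ops whose table entry has `triton=None` still compile on CUDA via `fallback_handler`. A toy model of that dispatch (the dicts stand in for Inductor's `TensorBox` inputs and the device check stands in for `is_triton`; the real fallback is produced by `fallback_handler(op)`):

```python
# Toy model of the triton_fallback path added to make_pointwise above.
def make_pointwise(codegen, triton_fallback=None):
    def inner(*inputs):
        # No triton implementation available: fall back to the ATen kernel
        # whenever any input is handled by the triton backend (CUDA).
        if triton_fallback is not None and any(t["device"] == "cuda" for t in inputs):
            return triton_fallback(*inputs)
        return {"device": inputs[0]["device"],
                "expr": codegen(*(t["expr"] for t in inputs))}

    return inner


lower_erfcx = make_pointwise(
    lambda x: f"calc_erfcx({x})",
    triton_fallback=lambda *ts: {"device": ts[0]["device"],
                                 "expr": "aten fallback: special_erfcx"},
)

print(lower_erfcx({"device": "cpu", "expr": "tmp0"}))   # codegen: calc_erfcx(tmp0)
print(lower_erfcx({"device": "cuda", "expr": "tmp0"}))  # ATen fallback
```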
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index a837161..62cde04c 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -6206,9 +6206,16 @@
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
_create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
+_create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
+_create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
+_create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
_create_binary_float_meta_func(aten.special_hermite_polynomial_h)
_create_binary_float_meta_func(aten.special_hermite_polynomial_he)
_create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
+_create_binary_float_meta_func(aten.special_legendre_polynomial_p)
# We must also trigger meta registrations from PrimTorch ref
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index aaa3ffe..d18e24b 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -739,7 +739,7 @@
# TODO: if this is special maybe it should be defined there and imported here?
@_make_elementwise_unary_reference(
- ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.special_i0
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0
)
def i0(a):
return prims.bessel_i0(a)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 8d3a721..acff31b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -16170,98 +16170,43 @@
skips=(
DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
),
- sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0})),
- UnaryUfuncInfo('polygamma',
- op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
- variant_test_name='polygamma_n_1',
- ref=reference_polygamma if TEST_SCIPY else None,
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
- supports_forward_ad=True,
- supports_fwgrad_bwgrad=True,
- promotes_int_to_float=True,
- sample_inputs_func=sample_inputs_polygamma,
- skips=(
- # Redundant tests
- DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
- # Mismatch: https://github.com/pytorch/pytorch/issues/55357
- DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'),
- ),
- sample_kwargs=lambda device, dtype, input: ({'n': 1}, {'n': 1}),
- # polygamma functions have multiple singularities at x <= 0
- reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
- UnaryUfuncInfo('polygamma',
- op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
- variant_test_name='polygamma_n_2',
- ref=reference_polygamma if TEST_SCIPY else None,
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
- supports_forward_ad=True,
- supports_fwgrad_bwgrad=True,
- promotes_int_to_float=True,
- sample_inputs_func=sample_inputs_polygamma,
- skips=(
- # Redundant tests
- DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
- # Mismatch: https://github.com/pytorch/pytorch/issues/55357
- DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
- sample_kwargs=lambda device, dtype, input: ({'n': 2}, {'n': 2}),
- # polygamma functions have multiple singularities at x <= 0
- reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
- UnaryUfuncInfo('polygamma',
- op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
- variant_test_name='polygamma_n_3',
- ref=reference_polygamma if TEST_SCIPY else None,
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
- supports_forward_ad=True,
- supports_fwgrad_bwgrad=True,
- promotes_int_to_float=True,
- sample_inputs_func=sample_inputs_polygamma,
- skips=(
- # Redundant tests
- DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
- # Mismatch: https://github.com/pytorch/pytorch/issues/55357
- DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
- sample_kwargs=lambda device, dtype, input: ({'n': 3}, {'n': 3}),
- # polygamma functions have multiple singularities at x <= 0
- reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
- UnaryUfuncInfo('polygamma',
- op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
- variant_test_name='polygamma_n_4',
- ref=reference_polygamma if TEST_SCIPY else None,
- decorators=(precisionOverride({torch.float16: 5e-4, torch.float32: 5e-4}),),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
- supports_forward_ad=True,
- supports_fwgrad_bwgrad=True,
- promotes_int_to_float=True,
- sample_inputs_func=sample_inputs_polygamma,
- skips=(
- # Redundant tests
- DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
- DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
- # Mismatch: https://github.com/pytorch/pytorch/issues/55357
- DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),),
- sample_kwargs=lambda device, dtype, input: ({'n': 4}, {'n': 4}),
- # polygamma functions have multiple singularities at x <= 0
- reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)),
+ sample_kwargs=lambda device, dtype, input: ({'n': 0}, {'n': 0}),
+ # polygamma functions have multiple singularities at x having non-positive integer value
+ reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
+ safe_val=1)),
+ *(UnaryUfuncInfo('polygamma',
+ op=lambda x, n, **kwargs: torch.polygamma(n, x, **kwargs),
+ variant_test_name=f'polygamma_n_{n_}',
+ ref=reference_polygamma if TEST_SCIPY else None,
+ dtypes=all_types_and(torch.bool, torch.bfloat16),
+ dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
+ promotes_int_to_float=True,
+ sample_inputs_func=sample_inputs_polygamma,
+ decorators=(
+ DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-3)}), 'TestUnaryUfuncs'),
+ DecorateInfo(toleranceOverride({torch.bfloat16: tol(atol=1e1, rtol=1e-1),
+ torch.float32: tol(atol=1e-4, rtol=1e-2)}),
+ 'TestUnaryUfuncs', 'test_reference_numerics_normal',
+ active_if=IS_WINDOWS),
+ ),
+ skips=(
+ # Redundant tests
+ DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestJit'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'),
+ # Mismatch: https://github.com/pytorch/pytorch/issues/55357
+ DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large'),
+ ),
+ sample_kwargs=lambda device, dtype, input: ({'n': n_}, {'n': n_}),
+ # polygamma functions have multiple singularities at x having non-positive integer value
+ reference_numerics_filter=NumericsFilter(condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4),
+ safe_val=1))
+ for n_ in (1, 2, 3, 4)),
OpInfo('ravel',
ref=np.ravel,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 545d28b..0357e9f 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -4138,6 +4138,10 @@
),
ModuleInfo(torch.nn.Embedding,
module_inputs_func=module_inputs_torch_nn_Embedding,
+ decorators=[
+ DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
+ 'TestModule', 'test_non_contiguous_tensors',
+ device_type='mps')],
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
),
diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py
index 177eb4c..5bea6db 100644
--- a/torch/testing/_internal/dynamo_test_failures.py
+++ b/torch/testing/_internal/dynamo_test_failures.py
@@ -1488,7 +1488,6 @@
"TestTorchDeviceTypeCPU.test_broadcast_fn_div_cpu", # test_torch
"TestTorchDeviceTypeCPU.test_nondeterministic_resize_quantized_cpu_quint8", # test_torch
"TestTorchDeviceTypeCPU.test_broadcast_fn_lt_cpu", # test_torch
- "TestTorchDeviceTypeCPU.test_memory_format_operators_cpu", # test_torch
"TestTorch.test_pin_memory", # test_torch
"TestTorchDeviceTypeCPU.test_broadcast_fn_masked_fill_cpu", # test_torch
"TestTorchDeviceTypeCPU.test_nondeterministic_alert_MaxUnpool2d_cpu_float64", # test_torch
diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py
index cc1e41d..426eb7e 100644
--- a/torch/testing/_internal/opinfo/definitions/special.py
+++ b/torch/testing/_internal/opinfo/definitions/special.py
@@ -65,7 +65,12 @@
def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(
- make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+ make_tensor,
+ device=device,
+ # TODO: eliminate low after gh-106692 is fixed:
+ low=(1 if dtype in {torch.int32, torch.int64} else None),
+ dtype=dtype,
+ requires_grad=requires_grad,
)
tensor_shapes = ((S, S), ())
ns = (1, 2, 3, 4, 5)
@@ -101,6 +106,19 @@
yield SampleInput(make_arg(()))
+def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
+ for shape in ((L,), (1, 0, 3), ()):
+ yield SampleInput(
+ make_tensor(
+ shape,
+ device=device,
+ dtype=dtype,
+ low=-5,
+ requires_grad=requires_grad,
+ ),
+ )
+
+
op_db: List[OpInfo] = [
UnaryUfuncInfo(
"special.i0e",
@@ -195,9 +213,9 @@
),
),
sample_kwargs=lambda device, dtype, input: ({"n": 0}, {"n": 0}),
- # polygamma functions have multiple singularities at x <= 0
+ # polygamma functions have multiple singularities at x having non-positive integer value
reference_numerics_filter=NumericsFilter(
- condition=lambda x: x < 0.1, safe_val=1
+ condition=lambda x: (x < 0.1) & ((x - x.round()).abs() < 1e-4), safe_val=1
),
),
BinaryUfuncInfo(
@@ -293,6 +311,7 @@
dtypes=all_types_and(torch.bool),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
+ sample_inputs_func=sample_inputs_erfcx,
),
UnaryUfuncInfo(
"special.airy_ai",