Add lowering to special.bessel_j0 (2nd try) (#118565)
This PR is a copy of https://github.com/pytorch/pytorch/pull/118464 that was merged without using pytorchbot. Sorry for the noise!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118565
Approved by: https://github.com/peterbell10
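
With this change, Inductor lowers torch.special.bessel_j0 to a pointwise kernel (bessel_j0_forward in the C++ backend, tl.math.j0 in the Triton backend) instead of falling back to the ATen op. A minimal sanity check, mirroring the test added below (the shape is illustrative and not part of the PR):

    import torch

    def fn(x):
        return torch.special.bessel_j0(x)

    x = torch.randn(8, 8)
    # The compiled result should match eager now that the op has a lowering.
    torch.testing.assert_close(torch.compile(fn)(x), fn(x))
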
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 8e08670..5f0cfc4 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1348,6 +1348,7 @@
cpp_op_list.append(k)
diff = [
+ "bessel_j0",
"constant",
"index_expr",
"signbit",
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
index dbaaee9..844e0b9 100644
--- a/test/inductor/test_perf.py
+++ b/test/inductor/test_perf.py
@@ -581,7 +581,10 @@
def unfusible(x):
- return aten.special_bessel_j0(x)
+ # For the purpose of noop tests, we want inductor to fall back to
+ # eager mode, so below we must use an aten operator that has neither
+ # a decomposition nor a lowering:
+ return aten._lazy_clone(x)
class NoopTests(TestCase):
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 14429e8..9674849 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -8428,6 +8428,12 @@
# should_pad_bench always returns False if has_triton returns False
self.assertFalse(should_pad)
+ def test_bessel_j0(self):
+ def fn(x):
+ return torch.special.bessel_j0(x)
+
+ self.common(fn, (torch.randn(8, 8),))
+
@dataclasses.dataclass
class TestFailure:
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index d55c6c6..ef831be 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -972,6 +972,10 @@
V.kernel.compute.splice(code)
return result
+ @staticmethod
+ def bessel_j0(x):
+ return f"bessel_j0_forward({x})"
+
class CppVecOverrides(CppOverrides):
"""Map element-wise ops to aten vectorization C++"""
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 12c32bd..2d89ca8 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -835,6 +835,10 @@
def ceil(x):
return f"tl.math.ceil({x})"
+ @staticmethod
+ def bessel_j0(x):
+ return f"tl.math.j0({x})"
+
# Use mypy to check protocol implemented correctly
def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]:
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index d5ee499..3431664 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -768,6 +768,15 @@
return make_pointwise(fn)(x)
+@register_lowering(
+ [aten.special_bessel_j0, prims.bessel_j0],
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+)
+def bessel_j0(x):
+ fn = ops_wrapper("bessel_j0")
+ return make_pointwise(fn)(x)
+
+
@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
(x,) = promote_constants([x])
@@ -2052,7 +2061,7 @@
make_fallback(aten._fused_moving_avg_obs_fq_helper)
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
-make_fallback(aten.randperm)
+make_fallback(aten.randperm) # needs sort
def sdpa_constraint(fx_node, *args, **kwargs):
@@ -2224,7 +2233,6 @@
make_fallback(aten.resize_as_)
make_fallback(aten.searchsorted)
make_fallback(aten.special_airy_ai)
-make_fallback(aten.special_bessel_j0, warn=False)
make_fallback(aten.special_bessel_j1, warn=False)
make_fallback(aten.special_bessel_y0, warn=False)
make_fallback(aten.special_bessel_y1)
@@ -2265,7 +2273,7 @@
make_fallback(aten.linalg_pinv.atol_rtol_tensor)
make_fallback(aten.segment_reduce.default)
make_fallback(aten._segment_reduce_backward.default)
-make_fallback(aten.angle)
+make_fallback(aten.angle) # needs complex
make_fallback(aten.cholesky_inverse)
make_fallback(aten.cholesky_solve)
make_fallback(aten._fft_r2c)
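
For reference, the remaining special-function fallbacks (e.g. aten.special_bessel_j1 above) could be lowered with the same three-part pattern used in this PR. The excerpt below is only an illustrative sketch of where the pieces would go; it assumes prims.bessel_j1, ATen's bessel_j1_forward helper, and tl.math.j1 exist analogously to the j0 variants, and it is not part of this change.

    # torch/_inductor/lowering.py -- registration mirroring bessel_j0 above
    @register_lowering(
        [aten.special_bessel_j1, prims.bessel_j1],
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    def bessel_j1(x):
        fn = ops_wrapper("bessel_j1")
        return make_pointwise(fn)(x)

    # torch/_inductor/codegen/cpp.py, inside CppOverrides -- assumes an
    # ATen bessel_j1_forward helper analogous to bessel_j0_forward
    @staticmethod
    def bessel_j1(x):
        return f"bessel_j1_forward({x})"

    # torch/_inductor/codegen/triton.py, inside TritonOverrides -- assumes
    # tl.math.j1 is exposed alongside tl.math.j0
    @staticmethod
    def bessel_j1(x):
        return f"tl.math.j1({x})"

The corresponding make_fallback(aten.special_bessel_j1, warn=False) line would then be dropped, and a test analogous to test_bessel_j0 added.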