Add lowering to special.bessel_j0 (2nd try) (#118565)

This PR is a copy of https://github.com/pytorch/pytorch/pull/118464, which was merged without using pytorchbot. Sorry for the noise!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118565
Approved by: https://github.com/peterbell10
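A minimal sketch of how the new lowering is exercised (not part of this patch; the shapes and check are illustrative only): with this change, `torch.compile` should emit a pointwise Inductor kernel for `torch.special.bessel_j0` instead of falling back to the eager aten op.

```python
import torch

def fn(x):
    return torch.special.bessel_j0(x)

compiled = torch.compile(fn)
x = torch.randn(8, 8)
# Compiled output should match eager execution.
torch.testing.assert_close(compiled(x), fn(x))
```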
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 8e08670..5f0cfc4 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -1348,6 +1348,7 @@
                 cpp_op_list.append(k)
 
         diff = [
+            "bessel_j0",
             "constant",
             "index_expr",
             "signbit",
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
index dbaaee9..844e0b9 100644
--- a/test/inductor/test_perf.py
+++ b/test/inductor/test_perf.py
@@ -581,7 +581,10 @@
 
 
 def unfusible(x):
-    return aten.special_bessel_j0(x)
+    # For the purposes of the noop tests, we want Inductor to fall back to
+    # eager mode, so below we must use an aten operator that has neither a
+    # decomposition nor a lowering:
+    return aten._lazy_clone(x)
 
 
 class NoopTests(TestCase):
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 14429e8..9674849 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -8428,6 +8428,12 @@
             # should_pad_bench always returns False if has_triton returns False
             self.assertFalse(should_pad)
 
+    def test_bessel_j0(self):
+        def fn(x):
+            return torch.special.bessel_j0(x)
+
+        self.common(fn, (torch.randn(8, 8),))
+
 
 @dataclasses.dataclass
 class TestFailure:
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index d55c6c6..ef831be 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -972,6 +972,10 @@
         V.kernel.compute.splice(code)
         return result
 
+    @staticmethod
+    def bessel_j0(x):
+        return f"bessel_j0_forward({x})"
+
 
 class CppVecOverrides(CppOverrides):
     """Map element-wise ops to aten vectorization C++"""
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 12c32bd..2d89ca8 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -835,6 +835,10 @@
     def ceil(x):
         return f"tl.math.ceil({x})"
 
+    @staticmethod
+    def bessel_j0(x):
+        return f"tl.math.j0({x})"
+
 
 # Use mypy to check protocol implemented correctly
 def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]:
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index d5ee499..3431664 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -768,6 +768,15 @@
     return make_pointwise(fn)(x)
 
 
+@register_lowering(
+    [aten.special_bessel_j0, prims.bessel_j0],
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+)
+def bessel_j0(x):
+    fn = ops_wrapper("bessel_j0")
+    return make_pointwise(fn)(x)
+
+
 @register_lowering(aten.expand, type_promotion_kind=None)
 def expand(x, sizes):
     (x,) = promote_constants([x])
@@ -2052,7 +2061,7 @@
 make_fallback(aten._fused_moving_avg_obs_fq_helper)
 make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
 make_fallback(aten.grid_sampler_2d_backward, require_dense)
-make_fallback(aten.randperm)
+make_fallback(aten.randperm)  # needs sort
 
 
 def sdpa_constraint(fx_node, *args, **kwargs):
@@ -2224,7 +2233,6 @@
 make_fallback(aten.resize_as_)
 make_fallback(aten.searchsorted)
 make_fallback(aten.special_airy_ai)
-make_fallback(aten.special_bessel_j0, warn=False)
 make_fallback(aten.special_bessel_j1, warn=False)
 make_fallback(aten.special_bessel_y0, warn=False)
 make_fallback(aten.special_bessel_y1)
@@ -2265,7 +2273,7 @@
 make_fallback(aten.linalg_pinv.atol_rtol_tensor)
 make_fallback(aten.segment_reduce.default)
 make_fallback(aten._segment_reduce_backward.default)
-make_fallback(aten.angle)
+make_fallback(aten.angle)  # needs complex
 make_fallback(aten.cholesky_inverse)
 make_fallback(aten.cholesky_solve)
 make_fallback(aten._fft_r2c)