[inductor] Lower divide by constant as multiplication by reciprocal (#121924)

Fixes #101039

This lowers division by a constant value to multiplication by the reciprocal.
The same optimization is applied in eager mode on CUDA:

https://github.com/pytorch/pytorch/blob/0636c11811e15d1919cdd6cf20cb2d2bed2ee1da/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu#L36-L38
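Sketched outside of inductor, the rewrite amounts to the following (illustrative helper only, not part of this change): a zero divisor is mapped to a correctly signed infinity instead of evaluating `1.0 / 0.0`, so `x / 0.0` still produces ±inf and `0.0 / 0.0` still produces nan.

```python
import math

def div_by_constant_as_mul(x, divisor: float):
    # Illustrative sketch of the lowering below, not the inductor code itself.
    # The reciprocal of a (possibly negative) zero is a signed infinity,
    # which preserves IEEE division-by-zero semantics under multiplication.
    if divisor == 0:
        reciprocal = math.copysign(float("inf"), divisor)
    else:
        reciprocal = 1.0 / divisor
    return x * reciprocal
```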

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121924
Approved by: https://github.com/lezcano
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 4a1e18f..e6ba330 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -2068,6 +2068,39 @@
             fn_int_input, (make_tensor(10, device=self.device, dtype=torch.float32), 33)
         )
 
+    def test_div_precision(self):
+        # Reproducer for https://github.com/pytorch/pytorch/issues/101039
+
+        def forward(y):
+            z = y.div(1e-06)
+            return F.softmax(z, dim=-1)
+
+        query = torch.randn(1, 10, 40)
+        key = torch.randn(1, 2, 40)
+        y = torch.matmul(query, key.transpose(-2, -1))
+        self.common(forward, (y,))
+
+    def test_div_by_zero(self):
+        def fn(x, runtime_zero, runtime_neg_zero):
+            zero = torch.zeros_like(x)
+            return (
+                x / 0.0,
+                x / -0.0,
+                zero / 0.0,
+                x / zero,
+                x / -zero,
+                zero / zero,
+                x / runtime_zero,
+                # NOTE: -runtime_zero doesn't work as -(0.0) is broken in triton
+                x / runtime_neg_zero,
+                runtime_zero / runtime_neg_zero,
+            )
+
+        a = torch.randn(10)
+        zero = torch.zeros(10)
+        neg_zero = -zero
+        self.common(fn, (a, zero, neg_zero))
+
     def test_both_scalars(self):
         def fn(a, b):
             return (
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 2f99229..4532060 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1,11 +1,13 @@
 import functools
 import itertools
 import logging
+import math
 import operator
 import os
 import warnings
 from collections import defaultdict
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from unittest.mock import patch
 
 import sympy
 
@@ -5231,6 +5233,35 @@
         return make_pointwise(fn)(a, b)
 
 
+def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
+    """Try convert an arbitrary IR node into an ir.Constant value"""
+
+    # First try unwrapping the IRNode to see if it is already an ir.Constant
+    # Optional step, but avoids unnecessary inner_fn evaluation.
+    if isinstance(x, ir.MutableBox):
+        return get_constant_value(x.data)
+    if isinstance(x, ir.BaseView):
+        return get_constant_value(x.unwrap_view())
+    if isinstance(x, ir.Constant):
+        return x
+
+    # If the unwrapped node is not an ir.Constant, try evaluating inner_fn
+    # to see if the returned value is from an `ops.constant` call
+    if not isinstance(x, ir.Loops):
+        return None
+
+    handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
+    with V.set_ops_handler(handler), patch.object(
+        ir.FlexibleLayout, "allow_indexing", True
+    ):
+        out = x.inner_fn(*x.inner_fn_args())
+
+    assert isinstance(out, torch._inductor.virtualized.OpsValue)
+    if isinstance(out.value, ir.Constant):
+        return out.value
+    return None
+
+
 # NOTE: prims.div maps to a / b in C, so performs truncation division on
 #   integer inputs and true division for floating and complex inputs.
 @register_lowering([prims.div], broadcast=True)
@@ -5240,6 +5271,14 @@
     if is_integral:
         return truncdiv(a, b)
 
+    if (divisor := get_constant_value(b)) is not None:
+        # Replace divide by constant with multiply by reciprocal
+        if divisor.value == 0:
+            reciprocal = math.copysign(float("inf"), divisor.value)
+        else:
+            reciprocal = 1.0 / divisor.value
+        return mul(a, reciprocal)
+
     def fn(*args):
         return ops.truediv(*args)
 
diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py
index 2e8d5e0..6c5537f 100644
--- a/torch/_inductor/ops_handler.py
+++ b/torch/_inductor/ops_handler.py
@@ -494,6 +494,38 @@
         ...
 
 
+class NoopHandler:
+    def __getattr__(self, name):
+        if name == "name":
+            return "NoopHandler"
+
+        def inner(*args, **kwargs):
+            return None
+
+        return inner
+
+    @staticmethod
+    def masked(mask, body, other) -> None:
+        return None
+
+    @staticmethod
+    def frexp(x) -> Tuple[None, None]:
+        return (None, None)
+
+    @staticmethod
+    def scan(dtypes, combine_fn, values, inits) -> Tuple[None, ...]:
+        return tuple(None for i in range(len(values)))
+
+    @staticmethod
+    def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
+        return sympy.Integer(0)
+
+
+# Use mypy to check protocol implemented correctly
+def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
+    return h
+
+
 class MockHandler:
     def __getattr__(self, name):
         if name == "name":
@@ -664,3 +696,17 @@
 
 def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
     return h
+
+
+class ExtractConstantsHandler(NoopHandler):
+    def __init__(self, device):
+        self.device = device
+
+    def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
+        from torch._inductor import ir
+
+        return ir.Constant(value=value, dtype=dtype, device=self.device)
+
+
+def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]:
+    return h