[inductor] Lower divide by constant as multiplication by reciprocal (#121924)
Fixes #101039
This lowers division by a constant value to multiplication by the reciprocal.
The same optimization is applied in eager mode on CUDA:
https://github.com/pytorch/pytorch/blob/0636c11811e15d1919cdd6cf20cb2d2bed2ee1da/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu#L36-L38
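
As a rough sketch of the rewrite (illustrative only, not code from this PR;
reciprocal_of is a hypothetical helper):

    import math

    def reciprocal_of(c: float) -> float:
        # Python's 1.0 / c raises ZeroDivisionError for c == +/-0.0, so mirror
        # the lowering's special case: an infinity carrying the zero's sign,
        # which matches IEEE 754 division semantics.
        if c == 0.0:
            return math.copysign(float("inf"), c)
        return 1.0 / c

    # x / c becomes x * reciprocal_of(c); the reciprocal is computed once at
    # lowering time because c is a compile-time constant.
    assert reciprocal_of(4.0) == 0.25
    assert reciprocal_of(0.0) == float("inf")
    assert reciprocal_of(-0.0) == float("-inf")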
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121924
Approved by: https://github.com/lezcano
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 4a1e18f..e6ba330 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -2068,6 +2068,39 @@
fn_int_input, (make_tensor(10, device=self.device, dtype=torch.float32), 33)
)
+ def test_div_precision(self):
+ # Reproducer for https://github.com/pytorch/pytorch/issues/101039
+
+ def forward(y):
+ z = y.div(1e-06)
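+            # The new lowering turns this into y * (1 / 1e-06); softmax below
+            # is sensitive to small errors, so the result must match eager.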
+ return F.softmax(z, dim=-1)
+
+ query = torch.randn(1, 10, 40)
+ key = torch.randn(1, 2, 40)
+ y = torch.matmul(query, key.transpose(-2, -1))
+ self.common(forward, (y,))
+
+ def test_div_by_zero(self):
+ def fn(x, runtime_zero, runtime_neg_zero):
+ zero = torch.zeros_like(x)
+ return (
+ x / 0.0,
+ x / -0.0,
+ zero / 0.0,
+ x / zero,
+ x / -zero,
+ zero / zero,
+ x / runtime_zero,
+                # NOTE: -runtime_zero doesn't work, as -(0.0) is broken in Triton
+ x / runtime_neg_zero,
+ runtime_zero / runtime_neg_zero,
+ )
+
+ a = torch.randn(10)
+ zero = torch.zeros(10)
+ neg_zero = -zero
+ self.common(fn, (a, zero, neg_zero))
+
def test_both_scalars(self):
def fn(a, b):
return (
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 2f99229..4532060 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1,11 +1,13 @@
import functools
import itertools
import logging
+import math
import operator
import os
import warnings
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from unittest.mock import patch
import sympy
@@ -5231,6 +5233,35 @@
return make_pointwise(fn)(a, b)
+def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
+    """Try to convert an arbitrary IR node into an ir.Constant value."""
+
+ # First try unwrapping the IRNode to see if it is already an ir.Constant
+ # Optional step, but avoids unnecessary inner_fn evaluation.
+ if isinstance(x, ir.MutableBox):
+ return get_constant_value(x.data)
+ if isinstance(x, ir.BaseView):
+ return get_constant_value(x.unwrap_view())
+ if isinstance(x, ir.Constant):
+ return x
+
+ # If the unwrapped node is not an ir.Constant, try evaluating inner_fn
+ # to see if the returned value is from an `ops.constant` call
+ if not isinstance(x, ir.Loops):
+ return None
+
+ handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
+ with V.set_ops_handler(handler), patch.object(
+ ir.FlexibleLayout, "allow_indexing", True
+ ):
+ out = x.inner_fn(*x.inner_fn_args())
+
+ assert isinstance(out, torch._inductor.virtualized.OpsValue)
+ if isinstance(out.value, ir.Constant):
+ return out.value
+ return None
+
+
# NOTE: prims.div maps to a / b in C, so performs truncation division on
# integer inputs and true division for floating and complex inputs.
@register_lowering([prims.div], broadcast=True)
@@ -5240,6 +5271,14 @@
if is_integral:
return truncdiv(a, b)
+ if (divisor := get_constant_value(b)) is not None:
+ # Replace divide by constant with multiply by reciprocal
+ if divisor.value == 0:
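+            # 1.0 / (+/-0.0) would trap in Python, so use copysign to produce
+            # the signed infinity that IEEE 754 division by a signed zero yields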
+ reciprocal = math.copysign(float("inf"), divisor.value)
+ else:
+ reciprocal = 1.0 / divisor.value
+ return mul(a, reciprocal)
+
def fn(*args):
return ops.truediv(*args)
diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py
index 2e8d5e0..6c5537f 100644
--- a/torch/_inductor/ops_handler.py
+++ b/torch/_inductor/ops_handler.py
@@ -494,6 +494,38 @@
...
+class NoopHandler:
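+    """Ops handler whose ops all return None; a base class for handlers that
+    only care about a few specific ops."""
+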
+ def __getattr__(self, name):
+ if name == "name":
+ return "NoopHandler"
+
+ def inner(*args, **kwargs):
+ return None
+
+ return inner
+
+ @staticmethod
+ def masked(mask, body, other) -> None:
+ return None
+
+ @staticmethod
+ def frexp(x) -> Tuple[None, None]:
+ return (None, None)
+
+ @staticmethod
+ def scan(dtypes, combine_fn, values, inits) -> Tuple[None, ...]:
+ return tuple(None for i in range(len(values)))
+
+ @staticmethod
+    def indirect_indexing(index_var, size, check=True) -> sympy.Expr:
+ return sympy.Integer(0)
+
+
+# Use mypy to check that the protocol is implemented correctly
+def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
+ return h
+
+
class MockHandler:
def __getattr__(self, name):
if name == "name":
@@ -664,3 +696,17 @@
def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
return h
+
+
+class ExtractConstantsHandler(NoopHandler):
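+    """Extracts ops.constant calls as ir.Constant values; every other op
+    yields None (inherited from NoopHandler)."""
+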
+ def __init__(self, device):
+ self.device = device
+
+ def constant(self, value: Any, dtype: torch.dtype) -> "torch._inductor.ir.Constant":
+ from torch._inductor import ir
+
+ return ir.Constant(value=value, dtype=dtype, device=self.device)
+
+
+def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]:
+ return h