[inductor] Misc division lowering fixes (#88603)
1. `aten.div.Tensor_mode` should allow broadcasting
2. `div` can use `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT`
3. `prims.div` on integers should be truncating division
4. Add lowering for `true_divide`, which is aliased to `div`
5. Register lowering for the in-place version of `div_mode`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88603
Approved by: https://github.com/ngimel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index ec024c6..2196f4f 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -23,6 +23,7 @@
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
+from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TEST_WITH_ASAN,
TEST_WITH_ROCM,
@@ -1166,6 +1167,45 @@
self.common(fn, (1024, 100))
+ def test_div_zero_dim(self):
+ def fn(a, b):
+ return (
+ aten.div(a, b, rounding_mode=None),
+ aten.div(a, b, rounding_mode="floor"),
+ aten.div(a, b, rounding_mode="trunc"),
+ a / b,
+ a // b,
+ )
+
+ for dtype in (torch.float32, torch.int64):
+ self.common(
+ fn,
+ (
+ make_tensor(10, device="cpu", dtype=dtype),
+ make_tensor((), device="cpu", dtype=dtype, exclude_zero=True),
+ ),
+ )
+ self.common(
+ fn,
+ (
+ make_tensor((), device="cpu", dtype=dtype),
+ make_tensor(10, device="cpu", dtype=dtype, exclude_zero=True),
+ ),
+ )
+
+ def test_div_prim(self):
+ def fn(a, b):
+ return (torch.ops.prims.div(a, b),)
+
+ for dtype in (torch.float32, torch.int64):
+ self.common(
+ fn,
+ (
+ make_tensor(100, device="cpu", dtype=dtype),
+ make_tensor(100, device="cpu", dtype=dtype, exclude_zero=True),
+ ),
+ )
+
def test_both_scalars(self):
def fn(a, b):
return (
@@ -2589,6 +2629,25 @@
shape = [1, 2, 6, 6]
self.common(fn, (torch.randn(shape), torch.randn(shape)))
+ def test_fmod_zero_dim(self):
+ def fn(a, b):
+ return (torch.fmod(a, b),)
+
+ self.common(
+ fn,
+ (
+ make_tensor(10, device="cpu", dtype=torch.float32),
+ make_tensor((), device="cpu", dtype=torch.float32),
+ ),
+ )
+ self.common(
+ fn,
+ (
+ make_tensor((), device="cpu", dtype=torch.float32),
+ make_tensor(10, device="cpu", dtype=torch.float32),
+ ),
+ )
+
def test_log2(self):
def fn(x):
return torch.log2(x), torch.log2(x + 1) - 2
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index a76a9ba..0bd9200 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3354,7 +3354,7 @@
return ops.truncdiv(a, b)
-@register_lowering(aten.div.Tensor_mode)
+@register_lowering(aten.div, broadcast=True)
def div_mode(a, b, rounding_mode=None):
both_integer = is_integer_type(a) and is_integer_type(b)
both_boolean = is_boolean_type(a) and is_boolean_type(b)
@@ -3370,23 +3370,6 @@
return div(a, b)
-@register_lowering([aten.div], broadcast=True)
-def div(a, b):
- def fn(*args):
- return ops.div(*args)
-
- dtype = get_promoted_dtype(
- a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
- )
- # truediv produces a float tensor even if both operands are integer types
- if is_integer_type(a) and is_integer_type(b):
- dtype = torch.get_default_dtype()
- return make_pointwise(fn, override_return_dtype=dtype)(
- a if isinstance(a, Number) else to_dtype(a, dtype),
- b if isinstance(b, Number) else to_dtype(b, dtype),
- )
-
-
@register_lowering([aten.mul], broadcast=True)
def mul(a, b):
both_bool = is_boolean_type(a) and is_boolean_type(b)
@@ -3397,21 +3380,29 @@
return make_pointwise(fn)(a, b)
-# TODO(lezcano) I believe the casting behaviour of prims.div is wrong
-# https://github.com/pytorch/pytorch/issues/84412
-# div prim performs truncation division on integer inputs
-# and true division for floating and complex inputs
+# NOTE: prims.div maps to a / b in C, so it performs truncating division on
+# integer inputs and true division for floating-point and complex inputs.
@register_lowering([prims.div], broadcast=True)
def div_prim(a, b):
is_integral = is_boolean_type(a) or is_integer_type(a)
if is_integral:
- return div_mode(a, b, rounding_mode="floor")
- else:
- return div(a, b)
+ return truncdiv(a, b)
+
+ def fn(*args):
+ return ops.div(*args)
+
+ return make_pointwise(fn)(a, b)
-@register_lowering([aten.fmod, prims.fmod])
+div = register_lowering(
+ [aten.true_divide, aten.div.Tensor],
+ broadcast=True,
+ type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+)(div_prim)
+
+
+@register_lowering([aten.fmod, prims.fmod], broadcast=True)
def fmod(a, b):
is_integral = is_boolean_type(a) or is_integer_type(a)
@@ -3564,7 +3555,8 @@
register_inplace(aten.add_, add)
register_inplace(aten.mul_, mul)
-register_inplace(aten.div_, div)
+register_inplace(aten.div_.Tensor, div)
+register_inplace(aten.div_.Tensor_mode, div_mode)
register_inplace(aten.sub_, sub)
register_inplace(aten.relu_, relu)
register_inplace(aten.sigmoid_, sigmoid)