[inductor] Misc division lowering fixes (#88603)

1. `aten.div.Tensor_mode` should allow broadcasting
2. `div` can use `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT`
3. `prims.div` on integers should be truncating division
4. Add lowering for `true_divide` which is aliased to `div`
5. register lowering for inplace version of `div_mode`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88603
Approved by: https://github.com/ngimel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index ec024c6..2196f4f 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -23,6 +23,7 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.nn import functional as F
+from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
@@ -1166,6 +1167,45 @@
 
         self.common(fn, (1024, 100))
 
+    def test_div_zero_dim(self):
+        def fn(a, b):
+            return (
+                aten.div(a, b, rounding_mode=None),
+                aten.div(a, b, rounding_mode="floor"),
+                aten.div(a, b, rounding_mode="trunc"),
+                a / b,
+                a // b,
+            )
+
+        for dtype in (torch.float32, torch.int64):
+            self.common(
+                fn,
+                (
+                    make_tensor(10, device="cpu", dtype=dtype),
+                    make_tensor((), device="cpu", dtype=dtype, exclude_zero=True),
+                ),
+            )
+            self.common(
+                fn,
+                (
+                    make_tensor((), device="cpu", dtype=dtype),
+                    make_tensor(10, device="cpu", dtype=dtype, exclude_zero=True),
+                ),
+            )
+
+    def test_div_prim(self):
+        def fn(a, b):
+            return (torch.ops.prims.div(a, b),)
+
+        for dtype in (torch.float32, torch.int64):
+            self.common(
+                fn,
+                (
+                    make_tensor(100, device="cpu", dtype=dtype),
+                    make_tensor(100, device="cpu", dtype=dtype, exclude_zero=True),
+                ),
+            )
+
     def test_both_scalars(self):
         def fn(a, b):
             return (
@@ -2589,6 +2629,25 @@
         shape = [1, 2, 6, 6]
         self.common(fn, (torch.randn(shape), torch.randn(shape)))
 
+    def test_fmod_zero_dim(self):
+        def fn(a, b):
+            return (torch.fmod(a, b),)
+
+        self.common(
+            fn,
+            (
+                make_tensor(10, device="cpu", dtype=torch.float32),
+                make_tensor((), device="cpu", dtype=torch.float32),
+            ),
+        )
+        self.common(
+            fn,
+            (
+                make_tensor((), device="cpu", dtype=torch.float32),
+                make_tensor(10, device="cpu", dtype=torch.float32),
+            ),
+        )
+
     def test_log2(self):
         def fn(x):
             return torch.log2(x), torch.log2(x + 1) - 2
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index a76a9ba..0bd9200 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3354,7 +3354,7 @@
     return ops.truncdiv(a, b)
 
 
-@register_lowering(aten.div.Tensor_mode)
+@register_lowering(aten.div, broadcast=True)
 def div_mode(a, b, rounding_mode=None):
     both_integer = is_integer_type(a) and is_integer_type(b)
     both_boolean = is_boolean_type(a) and is_boolean_type(b)
@@ -3370,23 +3370,6 @@
     return div(a, b)
 
 
-@register_lowering([aten.div], broadcast=True)
-def div(a, b):
-    def fn(*args):
-        return ops.div(*args)
-
-    dtype = get_promoted_dtype(
-        a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
-    )
-    # truediv produces a float tensor even if both operands are integer types
-    if is_integer_type(a) and is_integer_type(b):
-        dtype = torch.get_default_dtype()
-    return make_pointwise(fn, override_return_dtype=dtype)(
-        a if isinstance(a, Number) else to_dtype(a, dtype),
-        b if isinstance(b, Number) else to_dtype(b, dtype),
-    )
-
-
 @register_lowering([aten.mul], broadcast=True)
 def mul(a, b):
     both_bool = is_boolean_type(a) and is_boolean_type(b)
@@ -3397,21 +3380,29 @@
         return make_pointwise(fn)(a, b)
 
 
-# TODO(lezcano) I believe the casting behaviour of prims.div is wrong
-# https://github.com/pytorch/pytorch/issues/84412
-# div prim performs truncation division on integer inputs
-#   and true division for floating and complex inputs
+# NOTE: prims.div maps to a / b in C, so performs truncation division on
+#   integer inputs and true division for floating and complex inputs.
 @register_lowering([prims.div], broadcast=True)
 def div_prim(a, b):
     is_integral = is_boolean_type(a) or is_integer_type(a)
 
     if is_integral:
-        return div_mode(a, b, rounding_mode="floor")
-    else:
-        return div(a, b)
+        return truncdiv(a, b)
+
+    def fn(*args):
+        return ops.div(*args)
+
+    return make_pointwise(fn)(a, b)
 
 
-@register_lowering([aten.fmod, prims.fmod])
+div = register_lowering(
+    [aten.true_divide, aten.div.Tensor],
+    broadcast=True,
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+)(div_prim)
+
+
+@register_lowering([aten.fmod, prims.fmod], broadcast=True)
 def fmod(a, b):
     is_integral = is_boolean_type(a) or is_integer_type(a)
 
@@ -3564,7 +3555,8 @@
 
 register_inplace(aten.add_, add)
 register_inplace(aten.mul_, mul)
-register_inplace(aten.div_, div)
+register_inplace(aten.div_.Tensor, div)
+register_inplace(aten.div_.Tensor_mode, div_mode)
 register_inplace(aten.sub_, sub)
 register_inplace(aten.relu_, relu)
 register_inplace(aten.sigmoid_, sigmoid)