[ONNX] Fix remainder export (#64230) (#64578)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64578

* Fix the remainder export for the edge case where the input is negative. The new export relies on the true_divide export.
* Simplify the true_divide export: remove redundant casting code that is already handled by the scalar type analysis pass. Remove the dependency on `onnx::Where`, so the export also supports opsets 7 and 8.

Fixes #60179

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D30919601

Pulled By: malfet

fbshipit-source-id: 0f78621c0ac3bdb6bf4225e049ba5f470dc8ab12

Co-authored-by: BowenBao <bowbao@microsoft.com>
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3451138..e1c5c6b 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1713,7 +1713,7 @@
                     x.to(dtype=torch.float64) // y.to(dtype=torch.int64), x.to(dtype=torch.float64) // y.to(dtype=torch.float64), \
                     x.to(dtype=torch.int64) // y.to(dtype=torch.int64), x.to(dtype=torch.int64) // y
 
-        x = torch.randn(2, 3, 4)
+        x = torch.arange(-2, 4).reshape(2, 3, 1)
         y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
         self.run_test(FloorDivModule(), (x, y))
 
@@ -1723,7 +1723,7 @@
             def forward(self, x, y):
                 return x // 3, x // 2., x // y
 
-        x = torch.randn(2, 3, 4)
+        x = torch.arange(-2, 4).reshape(2, 3, 1)
         y = torch.randn(2, 3, 4)
         self.run_test(FloorDivModule(), (x, y))
 
@@ -1821,9 +1821,7 @@
                 return (x.div(y, rounding_mode="floor"),
                         torch.div(x, y, rounding_mode="floor"))
 
-        modules = [TrueDivModule(), TruncDivModule()]
-        if self.opset_version >= 9:
-            modules.append(FloorDivModule())
+        modules = [TrueDivModule(), TruncDivModule(), FloorDivModule()]
 
         x = (torch.randn(2, 3, 4) * 100).to(torch.int)
         y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
@@ -4258,6 +4256,13 @@
         y = torch.randint(10, (2, 3, 5), dtype=torch.long)
         self.run_test(XorModel(), input=(x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(11)  # ONNX Equal accepts float inputs only from opset 11 onwards
+    def test_eq(self):
+        class EqualModel(torch.nn.Module):
+            def forward(self, input, other):
+                return input == other
+        self._test_compare_ops(EqualModel(), 2)
+
     def test_gt(self):
         class GreaterModel(torch.nn.Module):
             def forward(self, input, other):
@@ -5954,14 +5959,34 @@
         y = torch.randn(1, 2, 1)
         self.run_test(RemainderModel(), (x, y))
 
+        x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
+        y = torch.tensor([2], dtype=torch.long)
+        self.run_test(RemainderModel(), (x, y))
+
+        x = x.to(torch.float)
+        self.run_test(RemainderModel(), (x, y))
+
+        y = y.to(torch.float)
+        self.run_test(RemainderModel(), (x, y))
+
+        x = x.to(torch.int32)
+        self.run_test(RemainderModel(), (x, y))
+
     def test_remainder_scalar(self):
         class RemainderModel(torch.nn.Module):
+            def __init__(self, scalar=2.55):
+                super().__init__()
+                self.scalar = scalar
+
             def forward(self, input):
-                return torch.remainder(input, 2.55)
+                return torch.remainder(input, self.scalar)
 
         x = torch.randint(10, (2, 3))
         self.run_test(RemainderModel(), x)
 
+        x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
+        self.run_test(RemainderModel(2), x)
+
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_fmod(self):
         class FModModel(torch.nn.Module):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 13bc480..7f24809 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -376,6 +376,15 @@
         return g.op("TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2)
 
 
+def _lt_helper(g, input, other):
+    if _export_onnx_opset_version <= 8:
+        from torch.onnx.symbolic_opset8 import lt as _lt8
+        return _lt8(g, input, other)
+    else:
+        from torch.onnx.symbolic_opset9 import lt as _lt9
+        return _lt9(g, input, other)
+
+
 def _interpolate_warning(interpolate_mode):
     onnx_op = "onnx:Resize" if _export_onnx_opset_version >= 10 else "onnx:Upsample"
     warnings.warn("You are trying to export the model with " + onnx_op + " for ONNX opset version "
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index e772bbf..9aa9e16 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -393,6 +393,13 @@
     return g.op("Round", self)
 
 
+def remainder(g, input, other):
+    if sym_help._is_fp(input) or sym_help._is_fp(other):
+        from torch.onnx.symbolic_opset9 import remainder as _remainder_9
+        return _remainder_9(g, input, other)
+    return g.op("Mod", input, other, fmod_i=0)
+
+
 @parse_args("v", "v", "i", "i")
 def split(g, self, split_size_or_sizes, dim, _outputs=None):
     if not sym_help._is_split_static(split_size_or_sizes, _outputs):
diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py
index a94010b..4bb62b1 100644
--- a/torch/onnx/symbolic_opset7.py
+++ b/torch/onnx/symbolic_opset7.py
@@ -1,5 +1,4 @@
-from torch.onnx.symbolic_helper import _block_list_in_opset, parse_args
-import torch.onnx.symbolic_helper as sym_help
+from torch.onnx.symbolic_helper import _block_list_in_opset
 
 import torch.onnx.symbolic_opset9 as sym_opset9
 
@@ -44,28 +43,5 @@
     return sym_opset9.min(g, self, dim_or_y, keepdim)
 
 
-def div(g, self, other, *args):
-    if len(args) == 0:
-        return sym_opset9.true_divide(g, self, other)
-    else:
-        return _div_rounding_mode(g, self, other, *args)
-
-
-@parse_args("v", "v", "s")
-def _div_rounding_mode(g, self, other, rounding_mode):
-    if rounding_mode == "floor":
-        return _floor_divide(g, self, other)
-    else:
-        return sym_opset9._div_rounding_mode(g, self, other, rounding_mode)
-
-
-def _floor_divide(g, self, other):
-    if sym_help._is_fp(self) or sym_help._is_fp(other):
-        out = sym_opset9.true_divide(g, self, other)
-        return g.op("Floor", out)
-    else:
-        raise RuntimeError("Integer floor division requires ONNX opset 9 or greater")
-
-
 for block_listed_op in block_listed_operators:
     vars()[block_listed_op] = _block_list_in_opset(block_listed_op)
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 6b1152a..cec192b 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -5,7 +5,6 @@
 
 from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
 from torch.onnx.symbolic_opset9 import _cast_Float  # type: ignore[attr-defined]
-from torch.onnx.symbolic_opset7 import div  # noqa: F401
 
 import warnings
 
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 3522e77..5686737 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -163,8 +163,8 @@
         # Division is negative if: self < 0 != other < 0
         zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
         negative = g.op("Xor",
-                        g.op("Less", self, zero),
-                        g.op("Less", other, zero))
+                        sym_help._lt_helper(g, self, zero),
+                        sym_help._lt_helper(g, other, zero))
 
         # For negative numbers with self % other != 0, subtract 1 to round down instead of up
         mod = g.op("Sub", self, g.op("Mul", div, other))
@@ -172,8 +172,8 @@
                           g.op("Not", g.op("Equal", mod, zero)))
 
         one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
-        fixup = g.op("Sub", div, one)
-        return g.op("Where", fixup_mask, fixup, div)
+        fixup = g.op("Mul", fixup_mask, one)
+        return g.op("Sub", div, fixup)
 
 
 def floor_divide(g, self, other):
@@ -189,24 +189,13 @@
 # If only one input is a floating type, the other input is cast to its type
 # If neither input is a floating type, both inputs are cast to the default scalar type
 def true_divide(g, self, other):
-    # Case 1: both values are floating
-    # Performs div as usual
-    if sym_help._is_fp(self) and sym_help._is_fp(other):
+    # Case 1: at least one input is floating
+    # Performs div as usual.
+    # Implicit casting will be handled in the scalar type analysis pass.
+    if sym_help._is_fp(self) or sym_help._is_fp(other):
         return g.op("Div", self, other)
 
-    # Case 2: self is floating, other is not
-    # Casts other to self's dtype
-    if sym_help._is_fp(self):
-        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()])
-        return g.op("Div", self, other)
-
-    # Case 3: other is floating, self is not
-    # Casts self to other's dtype
-    if sym_help._is_fp(other):
-        self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other.type().scalarType()])
-        return g.op("Div", self, other)
-
-    # Case 4: neither is floating
+    # Case 2: neither is floating
     # Casts both inputs to the default scalar type
     scalar_type = torch.get_default_dtype()
     onnx_scalar_type = sym_help.cast_pytorch_to_onnx["Float"]
@@ -2952,9 +2941,7 @@
 
 
 def remainder(g, input, other):
-    div = g.op("Div", input, other)
-    if sym_help._is_fp(input) or sym_help._is_fp(other):
-        div = g.op("Floor", div)
+    div = _floor_divide(g, input, other)
     quo = g.op("Mul", div, other)
     return g.op("Sub", input, quo)