[ONNX] Fix remainder export (#64230) (#64578)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64578
* Fix remainder export for the edge case where the input is negative (a worked example is included below). The new export relies on the true_divide export.
* Simplify the true_divide export: drop redundant casting code that is already handled by the scalar type analysis pass, and remove the dependency on `onnx::Where` from the floor-division fixup, so opsets 7 and 8 are now supported as well (see the sketch below).
Fixes #60179
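
For context (illustrative only, not part of the patch): ONNX integer `Div` truncates toward zero, so the previous opset 9 export, which skipped the `Floor` for integer inputs, effectively matched `torch.fmod` instead of `torch.remainder` on negative values. A minimal repro of the difference and of the floored formulation the new export follows:

```python
import torch

x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
y = torch.tensor([2], dtype=torch.long)

# PyTorch semantics: the remainder takes the sign of the divisor.
print(torch.remainder(x, y))   # tensor([1, 0, 1, 0])

# Truncating division (what ONNX integer Div does) gives fmod behavior instead.
print(torch.fmod(x, y))        # tensor([ 1,  0, -1,  0])

# The fixed export computes x - floor_divide(x, y) * y, which matches remainder.
print(x - torch.div(x, y, rounding_mode="floor") * y)  # tensor([1, 0, 1, 0])
```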
Test Plan: Imported from OSS
Reviewed By: jansel
Differential Revision: D30919601
Pulled By: malfet
fbshipit-source-id: 0f78621c0ac3bdb6bf4225e049ba5f470dc8ab12
Co-authored-by: BowenBao <bowbao@microsoft.com>
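
For reference, why dropping `onnx::Where` is safe (a sketch, with plain tensor ops standing in for the ONNX graph nodes built by `_div_rounding_mode`): because the fixup mask is 0/1, `Where(mask, div - 1, div)` is equivalent to `div - mask * 1`, and the latter needs no `Where`, which only exists from opset 9 onward.

```python
import torch

def floor_divide_sketch(x, y):
    # Integer Div in ONNX truncates toward zero.
    div = torch.div(x, y, rounding_mode="trunc")
    # The quotient is negative when exactly one operand is negative.
    negative = (x < 0) ^ (y < 0)                 # Xor(Less(x, 0), Less(y, 0))
    mod = x - div * y                            # Sub(x, Mul(div, y))
    # Round down (subtract 1) only when the division is negative and inexact.
    fixup_mask = negative & (mod != 0)           # And(negative, Not(Equal(mod, 0)))
    # Before: Where(fixup_mask, div - 1, div).  After: same result, no Where.
    return div - fixup_mask.to(div.dtype)        # Sub(div, Mul(fixup_mask, 1))

x = torch.tensor([7, 6, -7, -6])
y = torch.tensor([2])
print(floor_divide_sketch(x, y))                 # tensor([ 3,  3, -4, -3])
print(torch.div(x, y, rounding_mode="floor"))    # same values, for comparison
```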
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3451138..e1c5c6b 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1713,7 +1713,7 @@
x.to(dtype=torch.float64) // y.to(dtype=torch.int64), x.to(dtype=torch.float64) // y.to(dtype=torch.float64), \
x.to(dtype=torch.int64) // y.to(dtype=torch.int64), x.to(dtype=torch.int64) // y
- x = torch.randn(2, 3, 4)
+ x = torch.arange(-2, 4).reshape(2, 3, 1)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))
@@ -1723,7 +1723,7 @@
def forward(self, x, y):
return x // 3, x // 2., x // y
- x = torch.randn(2, 3, 4)
+ x = torch.arange(-2, 4).reshape(2, 3, 1)
y = torch.randn(2, 3, 4)
self.run_test(FloorDivModule(), (x, y))
@@ -1821,9 +1821,7 @@
return (x.div(y, rounding_mode="floor"),
torch.div(x, y, rounding_mode="floor"))
- modules = [TrueDivModule(), TruncDivModule()]
- if self.opset_version >= 9:
- modules.append(FloorDivModule())
+ modules = [TrueDivModule(), TruncDivModule(), FloorDivModule()]
x = (torch.randn(2, 3, 4) * 100).to(torch.int)
y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
@@ -4258,6 +4256,13 @@
y = torch.randint(10, (2, 3, 5), dtype=torch.long)
self.run_test(XorModel(), input=(x, y))
+ @skipIfUnsupportedMinOpsetVersion(11) # float equal added after opset 11
+ def test_eq(self):
+ class EqualModel(torch.nn.Module):
+ def forward(self, input, other):
+ return input == other
+ self._test_compare_ops(EqualModel(), 2)
+
def test_gt(self):
class GreaterModel(torch.nn.Module):
def forward(self, input, other):
@@ -5954,14 +5959,34 @@
y = torch.randn(1, 2, 1)
self.run_test(RemainderModel(), (x, y))
+ x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
+ y = torch.tensor([2], dtype=torch.long)
+ self.run_test(RemainderModel(), (x, y))
+
+ x = x.to(torch.float)
+ self.run_test(RemainderModel(), (x, y))
+
+ y = y.to(torch.float)
+ self.run_test(RemainderModel(), (x, y))
+
+ x = x.to(torch.int32)
+ self.run_test(RemainderModel(), (x, y))
+
def test_remainder_scalar(self):
class RemainderModel(torch.nn.Module):
+ def __init__(self, scalar=2.55):
+ super().__init__()
+ self.scalar = scalar
+
def forward(self, input):
- return torch.remainder(input, 2.55)
+ return torch.remainder(input, self.scalar)
x = torch.randint(10, (2, 3))
self.run_test(RemainderModel(), x)
+ x = torch.tensor([7, 6, -7, -6], dtype=torch.long)
+ self.run_test(RemainderModel(2), x)
+
@skipIfUnsupportedMinOpsetVersion(10)
def test_fmod(self):
class FModModel(torch.nn.Module):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 13bc480..7f24809 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -376,6 +376,15 @@
return g.op("TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2)
+def _lt_helper(g, input, other):
+ if _export_onnx_opset_version <= 8:
+ from torch.onnx.symbolic_opset8 import lt as _lt8
+ return _lt8(g, input, other)
+ else:
+ from torch.onnx.symbolic_opset9 import lt as _lt9
+ return _lt9(g, input, other)
+
+
def _interpolate_warning(interpolate_mode):
onnx_op = "onnx:Resize" if _export_onnx_opset_version >= 10 else "onnx:Upsample"
warnings.warn("You are trying to export the model with " + onnx_op + " for ONNX opset version "
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index e772bbf..9aa9e16 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -393,6 +393,13 @@
return g.op("Round", self)
+def remainder(g, input, other):
+ if sym_help._is_fp(input) or sym_help._is_fp(other):
+ from torch.onnx.symbolic_opset9 import remainder as _remainder_9
+ return _remainder_9(g, input, other)
+ return g.op("Mod", input, other, fmod_i=0)
+
+
@parse_args("v", "v", "i", "i")
def split(g, self, split_size_or_sizes, dim, _outputs=None):
if not sym_help._is_split_static(split_size_or_sizes, _outputs):
diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py
index a94010b..4bb62b1 100644
--- a/torch/onnx/symbolic_opset7.py
+++ b/torch/onnx/symbolic_opset7.py
@@ -1,5 +1,4 @@
-from torch.onnx.symbolic_helper import _block_list_in_opset, parse_args
-import torch.onnx.symbolic_helper as sym_help
+from torch.onnx.symbolic_helper import _block_list_in_opset
import torch.onnx.symbolic_opset9 as sym_opset9
@@ -44,28 +43,5 @@
return sym_opset9.min(g, self, dim_or_y, keepdim)
-def div(g, self, other, *args):
- if len(args) == 0:
- return sym_opset9.true_divide(g, self, other)
- else:
- return _div_rounding_mode(g, self, other, *args)
-
-
-@parse_args("v", "v", "s")
-def _div_rounding_mode(g, self, other, rounding_mode):
- if rounding_mode == "floor":
- return _floor_divide(g, self, other)
- else:
- return sym_opset9._div_rounding_mode(g, self, other, rounding_mode)
-
-
-def _floor_divide(g, self, other):
- if sym_help._is_fp(self) or sym_help._is_fp(other):
- out = sym_opset9.true_divide(g, self, other)
- return g.op("Floor", out)
- else:
- raise RuntimeError("Integer floor division requires ONNX opset 9 or greater")
-
-
for block_listed_op in block_listed_operators:
vars()[block_listed_op] = _block_list_in_opset(block_listed_op)
diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py
index 6b1152a..cec192b 100644
--- a/torch/onnx/symbolic_opset8.py
+++ b/torch/onnx/symbolic_opset8.py
@@ -5,7 +5,6 @@
from torch.onnx.symbolic_helper import parse_args, _unimplemented, _block_list_in_opset, _try_get_scalar_type
from torch.onnx.symbolic_opset9 import _cast_Float # type: ignore[attr-defined]
-from torch.onnx.symbolic_opset7 import div # noqa: F401
import warnings
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 3522e77..5686737 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -163,8 +163,8 @@
# Division is negative if: self < 0 != other < 0
zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
negative = g.op("Xor",
- g.op("Less", self, zero),
- g.op("Less", other, zero))
+ sym_help._lt_helper(g, self, zero),
+ sym_help._lt_helper(g, other, zero))
# For negative numbers with self % other != 0, subtract 1 to round down instead of up
mod = g.op("Sub", self, g.op("Mul", div, other))
@@ -172,8 +172,8 @@
g.op("Not", g.op("Equal", mod, zero)))
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
- fixup = g.op("Sub", div, one)
- return g.op("Where", fixup_mask, fixup, div)
+ fixup = g.op("Mul", fixup_mask, one)
+ return g.op("Sub", div, fixup)
def floor_divide(g, self, other):
@@ -189,24 +189,13 @@
# If only one input is a floating type, the other input is cast to its type
# If neither input is a floating type, both inputs are cast to the default scalar type
def true_divide(g, self, other):
- # Case 1: both values are floating
- # Performs div as usual
- if sym_help._is_fp(self) and sym_help._is_fp(other):
+ # Case 1: either values are floating
+ # Performs div as usual.
+ # Implicit casting will be handled in scalar type analysis pass.
+ if sym_help._is_fp(self) or sym_help._is_fp(other):
return g.op("Div", self, other)
- # Case 2: self is floating, other is not
- # Casts other to self's dtype
- if sym_help._is_fp(self):
- other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()])
- return g.op("Div", self, other)
-
- # Case 3: other is floating, self is not
- # Casts self to other's dtype
- if sym_help._is_fp(other):
- self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other.type().scalarType()])
- return g.op("Div", self, other)
-
- # Case 4: neither is floating
+ # Case 2: neither is floating
# Casts both inputs to the default scalar type
scalar_type = torch.get_default_dtype()
onnx_scalar_type = sym_help.cast_pytorch_to_onnx["Float"]
@@ -2952,9 +2941,7 @@
def remainder(g, input, other):
- div = g.op("Div", input, other)
- if sym_help._is_fp(input) or sym_help._is_fp(other):
- div = g.op("Floor", div)
+ div = _floor_divide(g, input, other)
quo = g.op("Mul", div, other)
return g.op("Sub", input, quo)