Allow for [-oo, oo] ranges for bools (#114362)
This fixes a problem in Seamless M4T in fairseq2 repro
instructions at https://docs.google.com/document/d/1PVy4KibfljirQDoijOwyHCV97B67r_iElWqFh7h1Acc/edit
I tried extracting a minimal repro but I couldn't actually manage it!
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114362
Approved by: https://github.com/Skylion007
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index 45e6222..bb84795 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -79,6 +79,14 @@
object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
assert isinstance(upper, SympyBoolean) == self.is_bool
+ def boolify(self):
+ if self.is_bool:
+ return self
+ elif self == ValueRanges.unknown():
+ return ValueRanges.unknown_bool()
+ else:
+ raise AssertionError(f"not bool like {self}")
+
def __contains__(self, x):
x = simple_sympify(x)
return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
@@ -120,6 +128,10 @@
return cls(-sympy.oo, sympy.oo)
@classmethod
+ def unknown_bool(cls):
+ return cls(sympy.false, sympy.true)
+
+ @classmethod
def wrap(cls, arg):
if isinstance(arg, ValueRanges):
return arg
@@ -215,6 +227,7 @@
@staticmethod
def not_(a):
a = ValueRanges.wrap(a)
+ a = a.boolify()
assert a.is_bool
return ValueRanges.decreasing_map(a, sympy.Not)
@@ -469,7 +482,7 @@
def where(a, b, c):
b = ValueRanges.wrap(b)
c = ValueRanges.wrap(c)
- assert a.is_bool
+ a = a.boolify()
assert b.is_bool == c.is_bool
if b.is_bool:
return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
@@ -481,7 +494,7 @@
# and defer the analysis to piecewise
@staticmethod
def expr_cond_pair(a, b):
- assert b.is_bool, f"expect cond_expr's ValueRange to be a boolean range but got {b}"
+ b = b.boolify()
return (a, b)
# piecewise function can be used to convert a SymBool to SymInt: