Allow for [-oo, oo] ranges for bools (#114362)

This fixes a problem in Seamless M4T in fairseq2 repro
instructions at https://docs.google.com/document/d/1PVy4KibfljirQDoijOwyHCV97B67r_iElWqFh7h1Acc/edit

I tried extracting a minimal repro but I couldn't actually manage it!

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114362
Approved by: https://github.com/Skylion007
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index 45e6222..bb84795 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -79,6 +79,14 @@
         object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean))
         assert isinstance(upper, SympyBoolean) == self.is_bool
 
+    def boolify(self):
+        if self.is_bool:
+            return self
+        elif self == ValueRanges.unknown():
+            return ValueRanges.unknown_bool()
+        else:
+            raise AssertionError(f"not bool like {self}")
+
     def __contains__(self, x):
         x = simple_sympify(x)
         return sympy_generic_le(self.lower, x) and sympy_generic_le(x, self.upper)
@@ -120,6 +128,10 @@
         return cls(-sympy.oo, sympy.oo)
 
     @classmethod
+    def unknown_bool(cls):
+        return cls(sympy.false, sympy.true)
+
+    @classmethod
     def wrap(cls, arg):
         if isinstance(arg, ValueRanges):
             return arg
@@ -215,6 +227,7 @@
     @staticmethod
     def not_(a):
         a = ValueRanges.wrap(a)
+        a = a.boolify()
         assert a.is_bool
         return ValueRanges.decreasing_map(a, sympy.Not)
 
@@ -469,7 +482,7 @@
     def where(a, b, c):
         b = ValueRanges.wrap(b)
         c = ValueRanges.wrap(c)
-        assert a.is_bool
+        a = a.boolify()
         assert b.is_bool == c.is_bool
         if b.is_bool:
             return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
@@ -481,7 +494,7 @@
     # and defer the analysis to piecewise
     @staticmethod
     def expr_cond_pair(a, b):
-        assert b.is_bool, f"expect cond_expr's ValueRange to be a boolean range but got {b}"
+        b = b.boolify()
         return (a, b)
 
     # piecewise function can be used to convert a SymBool to SymInt: