Support SymBool input to torch.compile (#107850)

We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x:torch.Tensor):
  pred = x.size(0) == 3
  torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```

The idea of this PR (credit to @ezyang) is to support SymBool by re-using the infra we've already had for SymInt so that we don't need to replicate a lot of stuff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 0ba53e2..d5a3818 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -19,11 +19,17 @@
 from functorch.experimental.control_flow import cond
 from torch._dynamo import config
 from torch._dynamo.exc import UserError
+from torch._dynamo.testing import normalize_gm
 from torch._export import dynamic_dim
 from torch._export.constraints import constrain_as_size, constrain_as_value
 from torch._higher_order_ops.out_dtype import out_dtype
+from torch._subclasses import fake_tensor
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
+from torch.fx.experimental.symbolic_shapes import (
+    ConstraintViolationError,
+    DimDynamic,
+    ShapeEnv,
+)
 from torch.testing._internal import common_utils
 
 
@@ -3155,8 +3161,6 @@
 
     def test_capture_symbolic_tracing_simple_within_fake_mode(self):
         from torch._dynamo.output_graph import config
-        from torch._subclasses import fake_tensor
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
         def f(x):
             y = torch.randn(3)
@@ -3178,6 +3182,76 @@
                     + str(aten_graph),
                 )
 
+    def test_export_with_symbool_inputs(self):
+        def f(pred: bool, x: torch.Tensor):
+            if pred:
+                return x.sin()
+            else:
+                return x.cos()
+
+        x = torch.randn([3, 4])
+
+        def test_symbool_guards(
+            f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards
+        ):
+            shape_env = ShapeEnv()
+            with fake_tensor.FakeTensorMode(
+                shape_env=shape_env,
+            ) as fake_mode:
+                fake_x = fake_mode.from_tensor(
+                    x, dynamic_dims=[DimDynamic.DYNAMIC for _ in range(x.dim())]
+                )
+                for i, size in enumerate(size_tests):
+                    pred = fake_x.size(0) == size
+                    gm, guards = torch._dynamo.export(f)(pred, x)
+                    actual = normalize_gm(gm.print_readable(print_output=False))
+                    self.assertExpectedInline(actual, exp_graph[i])
+                    dynamo_shape_env_guards = [
+                        guard
+                        for guard in guards
+                        if guard.guard_types is not None
+                        and "SHAPE_ENV" in guard.guard_types
+                    ]
+                    self.assertEqual(len(dynamo_shape_env_guards), 1)
+                    guard_code_on_predicate = [
+                        code
+                        for code in dynamo_shape_env_guards[0].code_list
+                        if "L['pred']" in code
+                    ]
+                    self.assertEqual(guard_code_on_predicate, exp_guard_code[i])
+                    outter_shape_env_guards = [
+                        str(guard.expr) for guard in shape_env.guards
+                    ]
+                    self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i])
+
+        true_graph = """\
+class GraphModule(torch.nn.Module):
+    def forward(self, pred, x):
+        arg0, arg1: f32[s1, s2], = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+        sin = arg1.sin();  arg1 = None
+        return pytree.tree_unflatten([sin], self._out_spec)
+"""
+        false_graph = """\
+class GraphModule(torch.nn.Module):
+    def forward(self, pred, x):
+        arg0, arg1: f32[s1, s2], = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+        cos = arg1.cos();  arg1 = None
+        return pytree.tree_unflatten([cos], self._out_spec)
+"""
+        true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"]
+        false_guard_code = [
+            "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
+            "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
+        ]
+        test_symbool_guards(
+            f,
+            [3, 3, 4, 5],
+            [true_graph, true_graph, false_graph, false_graph],
+            [true_guard_code, true_guard_code, false_guard_code, false_guard_code],
+            # Outter shape env should have no guards in it because we never specialize on the outter symbool.
+            [[], [], [], []],
+        )
+
     def test_invalid_input_global(self) -> None:
         global bulbous_bouffant
         bulbous_bouffant = torch.randn(3)
diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 38513a6..9dcd89a 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -436,6 +436,96 @@
         self.assertEqual(lower_bound_str, expected_lower_bound)
         self.assertEqual(upper_bound_str, expected_upper_bound)
 
+    def test_recompile_with_symbool_inputs(self):
+        def f(pred: bool):
+            if pred:
+                return torch.ones([3, 4])
+            else:
+                return torch.ones([4, 3])
+
+        def test_recompilation(
+            f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
+        ):
+            torch._dynamo.reset()
+            shape_env = ShapeEnv()
+            backend = torch._dynamo.testing.EagerAndRecordGraphs()
+            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
+            f_cond = torch.compile(f, backend=cnt, fullgraph=True)
+            with torch._subclasses.fake_tensor.FakeTensorMode(
+                shape_env=shape_env
+            ) as fake_mode:
+                fake_inp = fake_mode.from_tensor(
+                    x, dynamic_dims=[DimDynamic.DYNAMIC for i in range(x.dim())]
+                )
+                for i, size in enumerate(sizes):
+                    pred = fake_inp.size(0) == size
+                    f_cond(pred)
+                    actual = normalize_gm(
+                        backend.graphs[exp_frame_count[i] - 1].print_readable(
+                            print_output=False
+                        )
+                    )
+                    actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
+                    self.assertExpectedInline(actual, exp_graphs[i])
+                    self.assertEqual(cnt.frame_count, exp_frame_count[i])
+                    self.assertEqual(actual_guard_str, exp_shape_env_guards[i])
+
+        true_graph = """\
+class GraphModule(torch.nn.Module):
+    def forward(self):
+        ones = torch.ones([3, 4])
+        return (ones,)
+"""
+        false_graph = """\
+class GraphModule(torch.nn.Module):
+    def forward(self):
+        ones = torch.ones([4, 3])
+        return (ones,)
+"""
+        test_recompilation(
+            f,
+            torch.randn([3, 4]),
+            [3, 3, 4, 5],
+            exp_graphs=[true_graph, true_graph, false_graph, false_graph],
+            exp_frame_count=[1, 1, 2, 2],
+            exp_shape_env_guards=[
+                [],
+                # s0 is specialized and guarded in outter shape_env when dynamo checks the guards
+                ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
+                [
+                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
+                ],
+                [
+                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
+                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
+                ],
+            ],
+        )
+
+        test_recompilation(
+            f,
+            torch.randn([3, 4]),
+            [4, 5, 3, 3],
+            exp_graphs=[false_graph, false_graph, true_graph, true_graph],
+            exp_frame_count=[1, 1, 2, 2],
+            exp_shape_env_guards=[
+                [],
+                # s0 is specialized and guarded in outter shape_env when dynamo checks the guards
+                ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
+                [
+                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
+                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+                ],
+                [
+                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
+                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+                ],
+            ],
+        )
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 963f8c2..92fa0b7 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -267,7 +267,6 @@
         graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False), torch.tensor(False))
         self.assertEqual(graph(x, torch.tensor(True), torch.tensor(True)), f(x, torch.tensor(True), torch.tensor(True)))
 
-    @unittest.expectedFailure
     def test_cond_functionalized(self):
         def true_fn(x):
             y = x.sin()
@@ -313,7 +312,6 @@
         gm_functional = make_fx(torch.func.functionalize(gm_non_functional), tracing_mode="real")(inp)
         self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))
 
-    @unittest.expectedFailure
     def test_cond_functionalized_nested(self):
         def true_true_fn(x):
             y = x.cos()
@@ -1216,7 +1214,6 @@
         self.assertEqual(res, main(p, pred, xs, y))
         self.check_map_count(gm, 2)
 
-    @unittest.expectedFailure
     def test_cond_with_sym_pred(self):
         def true_fn(x):
             return x + x
@@ -1228,10 +1225,28 @@
             return cond(x.shape[0] == 4, true_fn, false_fn, [x])
 
         gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
+        # The symbols in make_fx's shape_env should not be speciliazed.
+        self.assertEqual(len(gm.shape_env.guards), 0)
+
+        exp_code = """\
+def forward(self, x_1):
+    sym_size = torch.ops.aten.sym_size(x_1, 0)
+    eq = sym_size == 4;  sym_size = None
+    true_graph_0 = self.true_graph_0
+    false_graph_0 = self.false_graph_0
+    conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]);  \
+eq = true_graph_0 = false_graph_0 = x_1 = None
+    return conditional
+"""
+        self._expected_inline_normalized(gm.code, exp_code)
+
+
+        # We expect the traced graph module to work even if input size changes.
         x = torch.ones(4, 3, 2)
         self.assertEqual(gm(x), true_fn(x))
         self.assertEqual(foo(x), true_fn(x))
 
+
     def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
         assert isinstance(args, (tuple, list))
         self.assertEqual(f(*args), exp_res)
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index e3cf0e4..7728b64 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -271,6 +271,21 @@
 
 
 @dataclasses.dataclass(frozen=True)
+class ConvertIntSource(ChainedSource):
+    def __post_init__(self):
+        assert self.base is not None
+
+    def reconstruct(self, codegen):
+        return self.base.reconstruct(codegen)
+
+    def guard_source(self):
+        return self.base.guard_source()
+
+    def name(self):
+        return f"cast_symbool_to_symint_guardless({self.base.name()})"
+
+
+@dataclasses.dataclass(frozen=True)
 class DefaultsSource(ChainedSource):
     idx_key: Union[int, str]
     is_kw: bool = False
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 2ed000d..95dc05b 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -44,6 +44,7 @@
 from ..source import (
     AttrSource,
     ConstantSource,
+    ConvertIntSource,
     GetItemSource,
     GlobalWeakRefSource,
     is_constant_source,
@@ -713,6 +714,46 @@
                 source=self.source,
                 guards=make_guards(GuardBuilder.FUNCTION_MATCH),
             )
+        elif isinstance(value, torch.SymBool):
+            # Note: the idea here is to re-use the infra we've built for SymInt by simulating the
+            # user provided SymBool with a SymInt in dynamo.
+
+            # Concretely,
+            # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
+            # so that guards on the SymInts can be effectively applied on the original SymBool in user program.
+            # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
+            # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.
+
+            value_hint = value.node.require_hint()
+            new_source = ConvertIntSource(self.source)
+
+            new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol(
+                int(value_hint),
+                new_source,
+                dynamic_dim=DimDynamic.DYNAMIC,
+            )
+
+            sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
+                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
+                type(new_symint),
+                source=new_source,
+            )
+
+            sym_node_proxy.node.meta["grapharg"] = GraphArg(
+                new_source,
+                new_symint,
+                False,
+                None,
+                is_tensor=False,
+                example_strong_ref=new_symint,
+            )
+            self.tx.output.tracked_fakes.append(
+                TrackedFake(new_symint, new_source, None)
+            )
+            return SymNodeVariable(
+                sym_node_proxy,
+                new_symint == 1,
+            )
         else:
             result = UserDefinedObjectVariable(
                 value,
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 22af501..8b4dad7 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -124,7 +124,7 @@
         return cond_op(pred, true_fn, false_fn, operands)
 
     def _validate_input(pred, true_fn, false_fn, operands):
-        if not isinstance(pred, (bool, torch.Tensor)):
+        if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
             raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
 
         if isinstance(pred, torch.Tensor) and pred.numel() != 1:
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 985986d..3ffc52c 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1286,6 +1286,10 @@
         op = getattr(operator, method_attr)
     return op
 
+def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
+    int_sym = sympy.Piecewise((1, symbool.node.expr), (0, True))
+    return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()))
+
 SYMPY_INTERP = {
     'Eq': operator.eq,
     'Ne': operator.ne,
@@ -1301,6 +1305,7 @@
     'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
     'floor': math.floor,
     'ceiling': math.ceil,
+    'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
 }
 
 always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"}
@@ -2723,7 +2728,10 @@
     ) -> "sympy.Expr":
         # 'positive' is None for unspecified symbols, since we can't
         # assume that it will be neither positive nor negative.
-        return self.create_symbol(val, source, dynamic_dim, constraint_dim, positive=None)
+
+        # We don't want to specialize zero one val for unspecified symbol
+        # so that we can always get a new symbol despite val.
+        return self.create_symbol(val, source, dynamic_dim, constraint_dim, positive=None, do_not_specialize_zero_one=True)
 
     @record_shapeenv_event()
     def create_symbol(
@@ -2733,7 +2741,13 @@
         dynamic_dim: DimDynamic = DimDynamic.DUCK,
         constraint_dim: DimConstraint = None,  # NB: includes None
         positive: Optional[bool] = True,
+        do_not_specialize_zero_one: bool = False,
     ) -> "sympy.Expr":
+        if do_not_specialize_zero_one:
+            specialize_zero_one = False
+        else:
+            specialize_zero_one = self.specialize_zero_one
+
         assert isinstance(source, Source), f"{type(source)} {source}"
         assert not (positive and val < 0), f"positive set for negative value: {val}"
         # It's always sound to allocate a symbol as DYNAMIC.  If the user
@@ -2753,7 +2767,7 @@
         else:
             raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")
 
-        if val in (0, 1) and self.specialize_zero_one:
+        if val in (0, 1) and specialize_zero_one:
             r = self.val_to_var[val]
         elif not duck or val not in self.val_to_var:
             # If we're not duck shaping, we always create a new symbol
diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py
index 3eaf4e1..742737b 100644
--- a/torch/utils/_sympy/interp.py
+++ b/torch/utils/_sympy/interp.py
@@ -57,6 +57,8 @@
         sympy.Min: "minimum",
         sympy.Max: "maximum",
         ModularIndexing: "modular_indexing",
+        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
+        sympy.Piecewise: "piecewise",
     }
     return HANDLERS
 
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index 87db5f7..f40db2c 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -6,7 +6,7 @@
 import math
 import logging
 import torch
-from typing import Union, Dict, Optional
+from typing import Union, Dict, Optional, SupportsFloat
 
 from torch._prims_common import dtype_to_type
 from .interp import sympy_interp
@@ -190,7 +190,7 @@
 
         # using nan makes subsequent computation throw, and for the purposes of optimization
         # returning -math.inf - math.inf is equivalent to giving up
-        if math.isnan(value):
+        if isinstance(value, SupportsFloat) and math.isnan(value):
             return ValueRanges.unknown()
 
         if is_python:
@@ -456,6 +456,30 @@
         else:
             return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
 
+    # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise.
+    # We just return the value range of the expression and its corresponding condition as a tuple
+    # and defer the analysis to piecewise
+    @staticmethod
+    def expr_cond_pair(a, b):
+        assert b.is_bool, f"expect cond_expr's ValueRange to be a boolean range but got {b}"
+        return (a, b)
+
+    # piecewise function can be used to convert a SymBool to SymInt:
+    # int_expr = Piecewise((1, bool_expr), (0, True)), it evalutes to 1 when sym_bool is True and 0 otherwise.
+    #
+    # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair.
+    # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True.
+    @staticmethod
+    def piecewise(*ranges):
+        init_range = None
+        for expr_range, cond_range in ranges:
+            if sympy.true in cond_range:
+                if init_range is None:
+                    init_range = expr_range
+                else:
+                    init_range = init_range | expr_range
+        return init_range
+
 
 class ValueRangeAnalysis(SymPyValueRangeAnalysis):
     def __init__(self):