Support SymBool input to torch.compile (#107850)
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x:torch.Tensor):
pred = x.size(0) == 3
torch.compile(f)(pred, x)
make_fx(f, tracing_mode="symbolic")(x)
```
The idea of this PR (credit to @ezyang) is to support SymBool by re-using the infra we've already had for SymInt so that we don't need to replicate a lot of stuff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 0ba53e2..d5a3818 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -19,11 +19,17 @@
from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._dynamo.exc import UserError
+from torch._dynamo.testing import normalize_gm
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._higher_order_ops.out_dtype import out_dtype
+from torch._subclasses import fake_tensor
from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
+from torch.fx.experimental.symbolic_shapes import (
+ ConstraintViolationError,
+ DimDynamic,
+ ShapeEnv,
+)
from torch.testing._internal import common_utils
@@ -3155,8 +3161,6 @@
def test_capture_symbolic_tracing_simple_within_fake_mode(self):
from torch._dynamo.output_graph import config
- from torch._subclasses import fake_tensor
- from torch.fx.experimental.symbolic_shapes import ShapeEnv
def f(x):
y = torch.randn(3)
@@ -3178,6 +3182,76 @@
+ str(aten_graph),
)
+ def test_export_with_symbool_inputs(self):
+ def f(pred: bool, x: torch.Tensor):
+ if pred:
+ return x.sin()
+ else:
+ return x.cos()
+
+ x = torch.randn([3, 4])
+
+ def test_symbool_guards(
+ f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards
+ ):
+ shape_env = ShapeEnv()
+ with fake_tensor.FakeTensorMode(
+ shape_env=shape_env,
+ ) as fake_mode:
+ fake_x = fake_mode.from_tensor(
+ x, dynamic_dims=[DimDynamic.DYNAMIC for _ in range(x.dim())]
+ )
+ for i, size in enumerate(size_tests):
+ pred = fake_x.size(0) == size
+ gm, guards = torch._dynamo.export(f)(pred, x)
+ actual = normalize_gm(gm.print_readable(print_output=False))
+ self.assertExpectedInline(actual, exp_graph[i])
+ dynamo_shape_env_guards = [
+ guard
+ for guard in guards
+ if guard.guard_types is not None
+ and "SHAPE_ENV" in guard.guard_types
+ ]
+ self.assertEqual(len(dynamo_shape_env_guards), 1)
+ guard_code_on_predicate = [
+ code
+ for code in dynamo_shape_env_guards[0].code_list
+ if "L['pred']" in code
+ ]
+ self.assertEqual(guard_code_on_predicate, exp_guard_code[i])
+ outter_shape_env_guards = [
+ str(guard.expr) for guard in shape_env.guards
+ ]
+ self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i])
+
+ true_graph = """\
+class GraphModule(torch.nn.Module):
+ def forward(self, pred, x):
+ arg0, arg1: f32[s1, s2], = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+ sin = arg1.sin(); arg1 = None
+ return pytree.tree_unflatten([sin], self._out_spec)
+"""
+ false_graph = """\
+class GraphModule(torch.nn.Module):
+ def forward(self, pred, x):
+ arg0, arg1: f32[s1, s2], = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+ cos = arg1.cos(); arg1 = None
+ return pytree.tree_unflatten([cos], self._out_spec)
+"""
+ true_guard_code = ["cast_symbool_to_symint_guardless(L['pred']) == 1"]
+ false_guard_code = [
+ "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
+ "-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
+ ]
+ test_symbool_guards(
+ f,
+ [3, 3, 4, 5],
+ [true_graph, true_graph, false_graph, false_graph],
+ [true_guard_code, true_guard_code, false_guard_code, false_guard_code],
+ # Outter shape env should have no guards in it because we never specialize on the outter symbool.
+ [[], [], [], []],
+ )
+
def test_invalid_input_global(self) -> None:
global bulbous_bouffant
bulbous_bouffant = torch.randn(3)
diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py
index 38513a6..9dcd89a 100644
--- a/test/dynamo/test_subclasses.py
+++ b/test/dynamo/test_subclasses.py
@@ -436,6 +436,96 @@
self.assertEqual(lower_bound_str, expected_lower_bound)
self.assertEqual(upper_bound_str, expected_upper_bound)
+ def test_recompile_with_symbool_inputs(self):
+ def f(pred: bool):
+ if pred:
+ return torch.ones([3, 4])
+ else:
+ return torch.ones([4, 3])
+
+ def test_recompilation(
+ f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
+ ):
+ torch._dynamo.reset()
+ shape_env = ShapeEnv()
+ backend = torch._dynamo.testing.EagerAndRecordGraphs()
+ cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
+ f_cond = torch.compile(f, backend=cnt, fullgraph=True)
+ with torch._subclasses.fake_tensor.FakeTensorMode(
+ shape_env=shape_env
+ ) as fake_mode:
+ fake_inp = fake_mode.from_tensor(
+ x, dynamic_dims=[DimDynamic.DYNAMIC for i in range(x.dim())]
+ )
+ for i, size in enumerate(sizes):
+ pred = fake_inp.size(0) == size
+ f_cond(pred)
+ actual = normalize_gm(
+ backend.graphs[exp_frame_count[i] - 1].print_readable(
+ print_output=False
+ )
+ )
+ actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
+ self.assertExpectedInline(actual, exp_graphs[i])
+ self.assertEqual(cnt.frame_count, exp_frame_count[i])
+ self.assertEqual(actual_guard_str, exp_shape_env_guards[i])
+
+ true_graph = """\
+class GraphModule(torch.nn.Module):
+ def forward(self):
+ ones = torch.ones([3, 4])
+ return (ones,)
+"""
+ false_graph = """\
+class GraphModule(torch.nn.Module):
+ def forward(self):
+ ones = torch.ones([4, 3])
+ return (ones,)
+"""
+ test_recompilation(
+ f,
+ torch.randn([3, 4]),
+ [3, 3, 4, 5],
+ exp_graphs=[true_graph, true_graph, false_graph, false_graph],
+ exp_frame_count=[1, 1, 2, 2],
+ exp_shape_env_guards=[
+ [],
+ # s0 is specialized and guarded in outter shape_env when dynamo checks the guards
+ ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
+ [
+ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+ "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
+ ],
+ [
+ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+ "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
+ "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
+ ],
+ ],
+ )
+
+ test_recompilation(
+ f,
+ torch.randn([3, 4]),
+ [4, 5, 3, 3],
+ exp_graphs=[false_graph, false_graph, true_graph, true_graph],
+ exp_frame_count=[1, 1, 2, 2],
+ exp_shape_env_guards=[
+ [],
+ # s0 is specialized and guarded in outter shape_env when dynamo checks the guards
+ ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
+ [
+ "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
+ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+ ],
+ [
+ "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
+ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+ "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
+ ],
+ ],
+ )
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 963f8c2..92fa0b7 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -267,7 +267,6 @@
graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False), torch.tensor(False))
self.assertEqual(graph(x, torch.tensor(True), torch.tensor(True)), f(x, torch.tensor(True), torch.tensor(True)))
- @unittest.expectedFailure
def test_cond_functionalized(self):
def true_fn(x):
y = x.sin()
@@ -313,7 +312,6 @@
gm_functional = make_fx(torch.func.functionalize(gm_non_functional), tracing_mode="real")(inp)
self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))
- @unittest.expectedFailure
def test_cond_functionalized_nested(self):
def true_true_fn(x):
y = x.cos()
@@ -1216,7 +1214,6 @@
self.assertEqual(res, main(p, pred, xs, y))
self.check_map_count(gm, 2)
- @unittest.expectedFailure
def test_cond_with_sym_pred(self):
def true_fn(x):
return x + x
@@ -1228,10 +1225,28 @@
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
+ # The symbols in make_fx's shape_env should not be speciliazed.
+ self.assertEqual(len(gm.shape_env.guards), 0)
+
+ exp_code = """\
+def forward(self, x_1):
+ sym_size = torch.ops.aten.sym_size(x_1, 0)
+ eq = sym_size == 4; sym_size = None
+ true_graph_0 = self.true_graph_0
+ false_graph_0 = self.false_graph_0
+ conditional = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]); \
+eq = true_graph_0 = false_graph_0 = x_1 = None
+ return conditional
+"""
+ self._expected_inline_normalized(gm.code, exp_code)
+
+
+ # We expect the traced graph module to work even if input size changes.
x = torch.ones(4, 3, 2)
self.assertEqual(gm(x), true_fn(x))
self.assertEqual(foo(x), true_fn(x))
+
def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
assert isinstance(args, (tuple, list))
self.assertEqual(f(*args), exp_res)
diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py
index e3cf0e4..7728b64 100644
--- a/torch/_dynamo/source.py
+++ b/torch/_dynamo/source.py
@@ -271,6 +271,21 @@
@dataclasses.dataclass(frozen=True)
+class ConvertIntSource(ChainedSource):
+ def __post_init__(self):
+ assert self.base is not None
+
+ def reconstruct(self, codegen):
+ return self.base.reconstruct(codegen)
+
+ def guard_source(self):
+ return self.base.guard_source()
+
+ def name(self):
+ return f"cast_symbool_to_symint_guardless({self.base.name()})"
+
+
+@dataclasses.dataclass(frozen=True)
class DefaultsSource(ChainedSource):
idx_key: Union[int, str]
is_kw: bool = False
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 2ed000d..95dc05b 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -44,6 +44,7 @@
from ..source import (
AttrSource,
ConstantSource,
+ ConvertIntSource,
GetItemSource,
GlobalWeakRefSource,
is_constant_source,
@@ -713,6 +714,46 @@
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
+ elif isinstance(value, torch.SymBool):
+ # Note: the idea here is to re-use the infra we've built for SymInt by simulating the
+ # user provided SymBool with a SymInt in dynamo.
+
+ # Concretely,
+ # 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source).
+ # so that guards on the SymInts can be effectively applied on the original SymBool in user program.
+ # 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program
+ # depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly.
+
+ value_hint = value.node.require_hint()
+ new_source = ConvertIntSource(self.source)
+
+ new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol(
+ int(value_hint),
+ new_source,
+ dynamic_dim=DimDynamic.DYNAMIC,
+ )
+
+ sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
+ re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
+ type(new_symint),
+ source=new_source,
+ )
+
+ sym_node_proxy.node.meta["grapharg"] = GraphArg(
+ new_source,
+ new_symint,
+ False,
+ None,
+ is_tensor=False,
+ example_strong_ref=new_symint,
+ )
+ self.tx.output.tracked_fakes.append(
+ TrackedFake(new_symint, new_source, None)
+ )
+ return SymNodeVariable(
+ sym_node_proxy,
+ new_symint == 1,
+ )
else:
result = UserDefinedObjectVariable(
value,
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index 22af501..8b4dad7 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -124,7 +124,7 @@
return cond_op(pred, true_fn, false_fn, operands)
def _validate_input(pred, true_fn, false_fn, operands):
- if not isinstance(pred, (bool, torch.Tensor)):
+ if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")
if isinstance(pred, torch.Tensor) and pred.numel() != 1:
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 985986d..3ffc52c 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -1286,6 +1286,10 @@
op = getattr(operator, method_attr)
return op
+def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
+ int_sym = sympy.Piecewise((1, symbool.node.expr), (0, True))
+ return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()))
+
SYMPY_INTERP = {
'Eq': operator.eq,
'Ne': operator.ne,
@@ -1301,6 +1305,7 @@
'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense,
'floor': math.floor,
'ceiling': math.ceil,
+ 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless,
}
always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt", "pow"}
@@ -2723,7 +2728,10 @@
) -> "sympy.Expr":
# 'positive' is None for unspecified symbols, since we can't
# assume that it will be neither positive nor negative.
- return self.create_symbol(val, source, dynamic_dim, constraint_dim, positive=None)
+
+ # We don't want to specialize zero one val for unspecified symbol
+ # so that we can always get a new symbol despite val.
+ return self.create_symbol(val, source, dynamic_dim, constraint_dim, positive=None, do_not_specialize_zero_one=True)
@record_shapeenv_event()
def create_symbol(
@@ -2733,7 +2741,13 @@
dynamic_dim: DimDynamic = DimDynamic.DUCK,
constraint_dim: DimConstraint = None, # NB: includes None
positive: Optional[bool] = True,
+ do_not_specialize_zero_one: bool = False,
) -> "sympy.Expr":
+ if do_not_specialize_zero_one:
+ specialize_zero_one = False
+ else:
+ specialize_zero_one = self.specialize_zero_one
+
assert isinstance(source, Source), f"{type(source)} {source}"
assert not (positive and val < 0), f"positive set for negative value: {val}"
# It's always sound to allocate a symbol as DYNAMIC. If the user
@@ -2753,7 +2767,7 @@
else:
raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")
- if val in (0, 1) and self.specialize_zero_one:
+ if val in (0, 1) and specialize_zero_one:
r = self.val_to_var[val]
elif not duck or val not in self.val_to_var:
# If we're not duck shaping, we always create a new symbol
diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py
index 3eaf4e1..742737b 100644
--- a/torch/utils/_sympy/interp.py
+++ b/torch/utils/_sympy/interp.py
@@ -57,6 +57,8 @@
sympy.Min: "minimum",
sympy.Max: "maximum",
ModularIndexing: "modular_indexing",
+ sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
+ sympy.Piecewise: "piecewise",
}
return HANDLERS
diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py
index 87db5f7..f40db2c 100644
--- a/torch/utils/_sympy/value_ranges.py
+++ b/torch/utils/_sympy/value_ranges.py
@@ -6,7 +6,7 @@
import math
import logging
import torch
-from typing import Union, Dict, Optional
+from typing import Union, Dict, Optional, SupportsFloat
from torch._prims_common import dtype_to_type
from .interp import sympy_interp
@@ -190,7 +190,7 @@
# using nan makes subsequent computation throw, and for the purposes of optimization
# returning -math.inf - math.inf is equivalent to giving up
- if math.isnan(value):
+ if isinstance(value, SupportsFloat) and math.isnan(value):
return ValueRanges.unknown()
if is_python:
@@ -456,6 +456,30 @@
else:
return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))
+ # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise.
+ # We just return the value range of the expression and its corresponding condition as a tuple
+ # and defer the analysis to piecewise
+ @staticmethod
+ def expr_cond_pair(a, b):
+ assert b.is_bool, f"expect cond_expr's ValueRange to be a boolean range but got {b}"
+ return (a, b)
+
+ # piecewise function can be used to convert a SymBool to SymInt:
+ # int_expr = Piecewise((1, bool_expr), (0, True)), it evalutes to 1 when sym_bool is True and 0 otherwise.
+ #
+ # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair.
+ # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True.
+ @staticmethod
+ def piecewise(*ranges):
+ init_range = None
+ for expr_range, cond_range in ranges:
+ if sympy.true in cond_range:
+ if init_range is None:
+ init_range = expr_range
+ else:
+ init_range = init_range | expr_range
+ return init_range
+
class ValueRangeAnalysis(SymPyValueRangeAnalysis):
def __init__(self):