[dynamo] Support control flow map() operator. (#91939)
Fixes #ISSUE_NUMBER
We want to add support for the control flow map() operator at the dynamo level, to unblock an internal model that has to use the map() operator in its captured graph. This largely replicates the pattern used to implement the cond() op in https://github.com/pytorch/pytorch/pull/90286
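For context, here is a minimal sketch of the map() semantics being captured (inferred from the tests in this PR, not an authoritative spec): map(body, xs, *args) applies body to each slice of xs along dim 0, passes any extra args through unchanged, and stacks the results.

```python
import torch
from functorch.experimental.control_flow import map

def body(x, scale):
    # body sees one slice xs[i] at a time; extra args are passed through unchanged
    return x * scale

xs = torch.randn(3, 2)
out = map(body, xs, torch.tensor(2.0))
# Roughly torch.stack([body(x, torch.tensor(2.0)) for x in xs]), so the
# output shape is [xs.shape[0], *body_output.shape].
assert out.shape == (3, 2)
```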
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91939
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 3652606..5168e4e 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -1456,6 +1456,62 @@
dynamo_result_2 = out_graph(pred, x)
self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2))
+ @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+ def test_export_with_map_cond(self):
+ from functorch.experimental.control_flow import cond, map
+
+ class Module(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def inner(self, x, pred):
+ def true_fn(x):
+ return x + x
+
+ def false_fn(x):
+ return x * x
+
+ return cond(pred, true_fn, false_fn, [x])
+
+ def forward(self, pred, xs):
+ def body(x, pred):
+ return self.inner(x, pred)
+
+ return map(body, xs, pred)
+
+ mod = Module()
+ x = torch.randn(3, 2, 1)
+ pred_x = torch.tensor(True)
+
+ y = torch.randn(4, 3, 2)
+ pred_y = torch.tensor(False)
+ real_result = mod(pred_y, y)
+
+ out_graph, _ = torch._dynamo.export(mod, pred_x, x)
+ self.assertEqual(real_result, out_graph(pred_y, y))
+
+ @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+ def test_export_with_map_zero_sized_tensor(self):
+ from functorch.experimental.control_flow import map
+
+ class Module(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, xs):
+ def body(x):
+ return x + 1
+
+ return map(body, xs)
+
+ mod = Module()
+ xs = torch.randn(0, 2)
+ with self.assertRaisesRegex(
+ torch._dynamo.exc.Unsupported,
+ "zero-sized tensor",
+ ):
+ out_graph, _ = torch._dynamo.export(mod, xs)
+
def test_export_meta_val(self):
def f(x, y, z):
return x * y + z
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 95d9e59..330ca66 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2376,6 +2376,29 @@
a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
self.assertTrue(same(torch.tensor([1.25, 1.25]), a))
+ def test_map_side_effects(self):
+ from functorch.experimental.control_flow import map
+
+ class Module(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.w = torch.tensor(1)
+
+ def forward(self, xs):
+ def body(x):
+ self.w += 1
+ return x
+
+ return map(body, xs)
+
+ mod = Module()
+ with self.assertRaisesRegex(
+ torch._dynamo.exc.Unsupported,
+ "Graph state change detected",
+ ):
+ opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod)
+ opt_fn(torch.randn(3, 2))
+
def test_cond_nested(self):
from functorch.experimental.control_flow import cond
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index ccb719a..75eaec9 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -10,6 +10,7 @@
import torch.fx
import torch.nn
import torch.onnx.operators
+from torch._dynamo.utils import get_fake_value
from torch._guards import GuardsCheckpointState
from .. import config, variables
@@ -699,6 +700,69 @@
tx.output.register_attr_or_module(gm, next_name, source=src)
return next_name
+ def get_comparable_state(state):
+ # Nub out bits of state that we don't require to be
+ # equal
+ return state._replace(
+ output=state.output._replace(
+ guard_state=GuardsCheckpointState(set()),
+ nn_modules=None,
+ # Timestamp is monotonically increasing so we don't
+ # care about divergence
+ timestamp=0,
+ # Meh (problem is the nodes don't compare equal;
+ # maybe nub out outputs only)
+ name_to_input=OrderedDict(),
+ # Unused in branches
+ graphargs=[],
+ )
+ )
+
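+ # speculate_subgraph traces f(*sub_args) into a fresh FX graph and
+ # returns the output variable, the traced graph, the guards and
+ # nn_modules it accumulated, and a comparable snapshot of the tracing
+ # state; the parent graph and graphstate are restored before returning.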
+ def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint):
+ # Setup the subgraph we're going to capture into
+ tx.output.graph = torch.fx.Graph()
+ tx.output.graphargs = []
+ tx.output.name_to_input.clear()
+
+ args = []
+ # One argument to graph per sub_args
+ for a in sub_args:
+ if isinstance(a, TensorVariable):
+ tx.output.create_graph_input(a.as_proxy().node.name)
+ args.append(a)
+ else:
+ # call_function() expects TensorVariables, so we wrap
+ # the inner graph proxy in one.
+ assert isinstance(a, torch.Tensor)
+ proxy = tx.output.create_graph_input("arg")
+ args.append(wrap_fx_proxy(tx=tx, proxy=proxy, example_value=a))
+ # NB: we don't bother populating graphargs, as
+ # they won't actually get used by anything
+
+ output = f.call_function(tx, args, {})
+
+ # Register output to graph
+ # Modeled off of compile_and_call_fx_graph
+ # TODO: support non single Tensor output
+ assert isinstance(output, TensorVariable)
+ tx.output.guards.update(output.guards)
+ tx.output.create_node(
+ "output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
+ )
+
+ tx.output.side_effects.prune_dead_object_new(tx)
+ state = tx.copy_graphstate()
+
+ guards = state.output.guards
+ nn_modules = state.output.nn_modules
+
+ comparable_state = get_comparable_state(state)
+ graph = tx.output.graph
+ tx.output.graph = graph_checkpoint
+ tx.restore_graphstate(checkpoint)
+
+ return output, graph, guards, nn_modules, comparable_state
+
if self.value.__name__ == "cond":
# TODO(voz): Support fake tensor dispatch for recursive
# ops - see torch/dispatch/_dispatcher.py
@@ -735,59 +799,12 @@
sub_args = args[3].unpack_var_sequence(tx)
def speculate_branch(branch):
- # Setup the subgraph we're going to capture into
- tx.output.graph = torch.fx.Graph()
- tx.output.graphargs = []
- tx.output.name_to_input.clear()
-
- # One argument to graph per sub_args
- for a in sub_args:
- assert isinstance(a, TensorVariable)
- tx.output.create_graph_input(a.as_proxy().node.name)
- # NB: we don't bother populating graphargs, as
- # they won't actually get used by anything
-
# NB: 0 is predicate
ix = 1 if branch else 2
-
- output = args[ix].call_function(tx, sub_args, {})
-
- # Register output to graph
- # Modeled off of compile_and_call_fx_graph
- # TODO: support non single Tensor output
- assert isinstance(output, TensorVariable)
- tx.output.guards.update(output.guards)
- tx.output.create_node(
- "output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
+ return speculate_subgraph(
+ args[ix], sub_args, graph_checkpoint, checkpoint
)
- tx.output.side_effects.prune_dead_object_new(tx)
- state = tx.copy_graphstate()
-
- guards = state.output.guards
- nn_modules = state.output.nn_modules
-
- # Nub out bits of state that we don't require to be
- # equal
- comparable_state = state._replace(
- output=state.output._replace(
- guard_state=GuardsCheckpointState(set()),
- nn_modules=None,
- # Timestamp is monotonically increasing so we don't
- # care about divergence
- timestamp=0,
- # Meh (problem is the nodes don't compare equal;
- # maybe nub out outputs only)
- name_to_input=OrderedDict(),
- )
- )
-
- graph = tx.output.graph
- tx.output.graph = graph_checkpoint
- tx.restore_graphstate(checkpoint)
-
- return output, graph, guards, nn_modules, comparable_state
-
(
true_r,
true_graph,
@@ -827,11 +844,62 @@
args[0].as_proxy(),
true_node,
false_node,
- tuple(a.as_proxy() for a in sub_args),
+ list(a.as_proxy() for a in sub_args),
)
# TODO: assert that the true/false return values are
# consistent
example_value = true_r.as_proxy().node.meta["example_value"]
+ elif self.value.__name__ == "map":
+ assert type(args[0]) in (UserFunctionVariable, NestedUserFunctionVariable)
+ assert type(args[1]) is TensorVariable
+
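+ # map(body, xs, *extra_args) iterates body over dim 0 of xs, passing
+ # extra_args through unchanged; we trace body once on a fake xs[0].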
+ sample_shape = args[1].get_real_value().size()
+ if len(sample_shape) < 1 or sample_shape[0] == 0:
+ unimplemented(
+ "map() operator doesn't support scalar or zero-sized tensors during tracing."
+ )
+
+ checkpoint = tx.copy_graphstate()
+ # To get the example output from map() we need to provide at least one sample to
+ # the loop body. We always use xs[0], so our map() doesn't support zero-sized
+ # tensors during tracing.
+ (
+ body_r,
+ body_graph,
+ body_guards,
+ body_nn_modules,
+ body_cmp,
+ ) = speculate_subgraph(
+ args[0],
+ [
+ get_fake_value(args[1].as_proxy().node, tx)[0],
+ *args[2:],
+ ],
+ tx.output.graph,
+ checkpoint,
+ )
+
+ # We don't support side effects inside a map loop body for simplicity.
+ parent_cmp = get_comparable_state(checkpoint)
+ if parent_cmp != body_cmp:
+ diff = parent_cmp.diff(body_cmp)
+ # unimplemented() raises Unsupported itself; no need to re-raise its result
+ unimplemented(
+ f"Graph state change detected in map() loop body. Diagnostics: {diff}"
+ )
+
+ # Add guards
+ tx.output.tracing_context.guards_context.dynamo_guards |= body_guards
+
+ body_name = add_subgraph(
+ "body", torch.fx.GraphModule(body_nn_modules, body_graph)
+ )
+
+ body_node = make_attr(body_name)
+ p_args = (body_node, *(arg.as_proxy() for arg in args[1:]))
+ r = body_r.as_proxy().node.meta["example_value"]
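+ # map() stacks one body output per element along xs' leading dimension,
+ # so the example value has shape [xs.shape[0], *r.shape].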
+ example_value = r.new_empty(
+ [get_fake_value(args[1].as_proxy().node, tx).shape[0], *r.shape]
+ )
else:
unimplemented(f"PyOperator {self.value.__name__}")