[dynamo] Support control flow map() operator. (#91939)

We want to add support for the control flow map() operator at the dynamo level to unblock some internal models that need to use the map() operator in the captured graph. This basically replicates the pattern used to implement the cond() op in https://github.com/pytorch/pytorch/pull/90286.
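
For reference, a minimal usage sketch mirroring the new tests (the module name `MapModule` is illustrative and not part of this PR; the tests additionally patch `torch._dynamo.config.dynamic_shapes` to `True`). The loop body is traced once against `xs[0]`, so zero-sized tensors and graph side effects inside the body are rejected with `Unsupported`:

```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import map


class MapModule(torch.nn.Module):  # illustrative name, not part of the PR
    def forward(self, xs):
        def body(x):
            # body sees one slice of xs at a time (traced against xs[0])
            return x + 1

        return map(body, xs)


mod = MapModule()
xs = torch.randn(3, 2)
# export captures the map() call with its loop body as a "body" subgraph
out_graph, _ = torch._dynamo.export(mod, xs)
assert torch.allclose(mod(xs), out_graph(xs))
```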

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91939
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py
index 3652606..5168e4e 100644
--- a/test/dynamo/test_export.py
+++ b/test/dynamo/test_export.py
@@ -1456,6 +1456,62 @@
         dynamo_result_2 = out_graph(pred, x)
         self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2))
 
+    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+    def test_export_with_map_cond(self):
+        from functorch.experimental.control_flow import cond, map
+
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def inner(self, x, pred):
+                def true_fn(x):
+                    return x + x
+
+                def false_fn(x):
+                    return x * x
+
+                return cond(pred, true_fn, false_fn, [x])
+
+            def forward(self, pred, xs):
+                def body(x, pred):
+                    return self.inner(x, pred)
+
+                return map(body, xs, pred)
+
+        mod = Module()
+        x = torch.randn(3, 2, 1)
+        pred_x = torch.tensor(True)
+
+        y = torch.randn(4, 3, 2)
+        pred_y = torch.tensor(False)
+        real_result = mod(pred_y, y)
+
+        out_graph, _ = torch._dynamo.export(mod, pred_x, x)
+        self.assertEqual(real_result, out_graph(pred_y, y))
+
+    @patch.object(torch._dynamo.config, "dynamic_shapes", True)
+    def test_export_with_map_zero_sized_tensor(self):
+        from functorch.experimental.control_flow import map
+
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, xs):
+                def body(x):
+                    return x + 1
+
+                return map(body, xs)
+
+        mod = Module()
+        xs = torch.randn(0, 2)
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported,
+            "zero-sized tensor",
+        ):
+            out_graph, _ = torch._dynamo.export(mod, xs)
+
     def test_export_meta_val(self):
         def f(x, y, z):
             return x * y + z
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 95d9e59..330ca66 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2376,6 +2376,29 @@
         a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
         self.assertTrue(same(torch.tensor([1.25, 1.25]), a))
 
+    def test_map_side_effects(self):
+        from functorch.experimental.control_flow import map
+
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.tensor(1)
+
+            def forward(self, xs):
+                def body(x):
+                    self.w += 1
+                    return x
+
+                return map(body, xs)
+
+        mod = Module()
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported,
+            "Graph state change detected",
+        ):
+            opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod)
+            opt_fn(torch.randn(3, 2))
+
     def test_cond_nested(self):
         from functorch.experimental.control_flow import cond
 
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index ccb719a..75eaec9 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -10,6 +10,7 @@
 import torch.fx
 import torch.nn
 import torch.onnx.operators
+from torch._dynamo.utils import get_fake_value
 from torch._guards import GuardsCheckpointState
 
 from .. import config, variables
@@ -699,6 +700,69 @@
             tx.output.register_attr_or_module(gm, next_name, source=src)
             return next_name
 
+        def get_comparable_state(state):
+            # Nub out bits of state that we don't require to be
+            # equal
+            return state._replace(
+                output=state.output._replace(
+                    guard_state=GuardsCheckpointState(set()),
+                    nn_modules=None,
+                    # Timestamp is monotonically increasing so we don't
+                    # care about divergence
+                    timestamp=0,
+                    # Meh (problem is the nodes don't compare equal;
+                    # maybe nub out outputs only)
+                    name_to_input=OrderedDict(),
+                    # Unused in branches
+                    graphargs=[],
+                )
+            )
+
+        def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint):
+            # Setup the subgraph we're going to capture into
+            tx.output.graph = torch.fx.Graph()
+            tx.output.graphargs = []
+            tx.output.name_to_input.clear()
+
+            args = []
+            # One argument to graph per sub_args
+            for a in sub_args:
+                if isinstance(a, TensorVariable):
+                    tx.output.create_graph_input(a.as_proxy().node.name)
+                    args.append(a)
+                else:
+                    # call_function() needs a TensorVariable, so we construct
+                    # one from the inner graph proxy.
+                    assert isinstance(a, torch.Tensor)
+                    proxy = tx.output.create_graph_input("arg")
+                    args.append(wrap_fx_proxy(tx=tx, proxy=proxy, example_value=a))
+                # NB: we don't bother populating graphargs, as
+                # they won't actually get used by anything
+
+            output = f.call_function(tx, args, {})
+
+            # Register output to graph
+            # Modeled off of compile_and_call_fx_graph
+            # TODO: support non single Tensor output
+            assert isinstance(output, TensorVariable)
+            tx.output.guards.update(output.guards)
+            tx.output.create_node(
+                "output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
+            )
+
+            tx.output.side_effects.prune_dead_object_new(tx)
+            state = tx.copy_graphstate()
+
+            guards = state.output.guards
+            nn_modules = state.output.nn_modules
+
+            comparable_state = get_comparable_state(state)
+            graph = tx.output.graph
+            tx.output.graph = graph_checkpoint
+            tx.restore_graphstate(checkpoint)
+
+            return output, graph, guards, nn_modules, comparable_state
+
         if self.value.__name__ == "cond":
             # TODO(voz): Support fake tensor dispatch for recursive
             # ops - see torch/dispatch/_dispatcher.py
@@ -735,59 +799,12 @@
             sub_args = args[3].unpack_var_sequence(tx)
 
             def speculate_branch(branch):
-                # Setup the subgraph we're going to capture into
-                tx.output.graph = torch.fx.Graph()
-                tx.output.graphargs = []
-                tx.output.name_to_input.clear()
-
-                # One argument to graph per sub_args
-                for a in sub_args:
-                    assert isinstance(a, TensorVariable)
-                    tx.output.create_graph_input(a.as_proxy().node.name)
-                    # NB: we don't bother populating graphargs, as
-                    # they won't actually get used by anything
-
                 # NB: 0 is predicate
                 ix = 1 if branch else 2
-
-                output = args[ix].call_function(tx, sub_args, {})
-
-                # Register output to graph
-                # Modeled off of compile_and_call_fx_graph
-                # TODO: support non single Tensor output
-                assert isinstance(output, TensorVariable)
-                tx.output.guards.update(output.guards)
-                tx.output.create_node(
-                    "output", "output", (tx.output.create_arg((output.as_proxy(),))), {}
+                return speculate_subgraph(
+                    args[ix], sub_args, graph_checkpoint, checkpoint
                 )
 
-                tx.output.side_effects.prune_dead_object_new(tx)
-                state = tx.copy_graphstate()
-
-                guards = state.output.guards
-                nn_modules = state.output.nn_modules
-
-                # Nub out bits of state that we don't require to be
-                # equal
-                comparable_state = state._replace(
-                    output=state.output._replace(
-                        guard_state=GuardsCheckpointState(set()),
-                        nn_modules=None,
-                        # Timestamp is monotonically increasing so we don't
-                        # care about divergence
-                        timestamp=0,
-                        # Meh (problem is the nodes don't compare equal;
-                        # maybe nub out outputs only)
-                        name_to_input=OrderedDict(),
-                    )
-                )
-
-                graph = tx.output.graph
-                tx.output.graph = graph_checkpoint
-                tx.restore_graphstate(checkpoint)
-
-                return output, graph, guards, nn_modules, comparable_state
-
             (
                 true_r,
                 true_graph,
@@ -827,11 +844,62 @@
                 args[0].as_proxy(),
                 true_node,
                 false_node,
-                tuple(a.as_proxy() for a in sub_args),
+                list(a.as_proxy() for a in sub_args),
             )
             # TODO: assert that the true/false return values are
             # consistent
             example_value = true_r.as_proxy().node.meta["example_value"]
+        elif self.value.__name__ == "map":
+            assert type(args[0]) in (UserFunctionVariable, NestedUserFunctionVariable)
+            assert type(args[1]) is TensorVariable
+
+            sample_shape = args[1].get_real_value().size()
+            if len(sample_shape) < 1 or sample_shape[0] == 0:
+                unimplemented(
+                    "map() operator doesn't support scalar or zero-sized tensors during tracing."
+                )
+
+            checkpoint = tx.copy_graphstate()
+            # To get the example output from map() we need to provide at least one
+            # sample to the loop body. We always use xs[0], which is why map() doesn't
+            # support zero-sized tensors during tracing.
+            (
+                body_r,
+                body_graph,
+                body_guards,
+                body_nn_modules,
+                body_cmp,
+            ) = speculate_subgraph(
+                args[0],
+                [
+                    get_fake_value(args[1].as_proxy().node, tx)[0],
+                    *args[2:],
+                ],
+                tx.output.graph,
+                checkpoint,
+            )
+
+            # We don't support side effects inside a map loop body for simplicity.
+            parent_cmp = get_comparable_state(checkpoint)
+            if parent_cmp != body_cmp:
+                diff = parent_cmp.diff(body_cmp)
+                raise unimplemented(
+                    f"Graph state change detected in map() loop body. Diagnostics: {diff}"
+                )
+
+            # Add guards
+            tx.output.tracing_context.guards_context.dynamo_guards |= body_guards
+
+            body_name = add_subgraph(
+                "body", torch.fx.GraphModule(body_nn_modules, body_graph)
+            )
+
+            body_node = make_attr(body_name)
+            p_args = (body_node, *(arg.as_proxy() for arg in args[1:]))
+            r = body_r.as_proxy().node.meta["example_value"]
+            example_value = r.new_empty(
+                [get_fake_value(args[1].as_proxy().node, tx).shape[0], *r.shape]
+            )
         else:
             unimplemented(f"PyOperator {self.value.__name__}")