| # mypy: allow-untyped-defs |
| # Copyright (c) Meta Platforms, Inc. and affiliates |
| from typing import Dict |
| |
| import torch |
| from torch.export.unflatten import _ModuleFrame |
| |
| |
def _outline_submodules(orig_graph: torch.fx.Graph):
    """Outline a flat FX graph into a freshly created ``GraphModule``.

    Every node of ``orig_graph`` is handed to ``torch.export``'s private
    ``_ModuleFrame`` helper, which populates an initially empty root
    ``GraphModule`` (passed in via ``module=``). The rebuilt graph is then
    linted and the module's generated code is recompiled before returning.

    Args:
        orig_graph: The flat graph whose nodes are to be outlined.

    Returns:
        The newly built ``torch.fx.GraphModule``.
    """
    # Empty root module + empty graph; _ModuleFrame fills it in place.
    outlined = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
    all_nodes = tuple(orig_graph.nodes)

    # Bookkeeping maps threaded through _ModuleFrame (str-keyed node map,
    # int-keyed module map, per the annotations).
    node_lookup: Dict[str, torch.fx.Node] = {}
    module_lookup: Dict[int, torch.nn.Module] = {}

    frame = _ModuleFrame(
        orig_graph,
        all_nodes,
        node_lookup,
        module_lookup,
        None,  # NOTE(review): presumably "no parent frame" -- confirm
        [""],  # presumably the initial module stack rooted at "" -- confirm
        "",
        {},
        module=outlined,
    )
    frame.run_outer()

    # Check graph invariants, then regenerate the module's code from the graph.
    outlined.graph.lint()
    outlined.recompile()
    return outlined