| import inspect |
| from typing import Any, Callable, Dict, List, Optional, Set |
| from collections import OrderedDict |
| import logging |
| |
| import torch |
| from torch.fx._compatibility import compatibility |
| from torch.fx.graph_module import GraphModule |
| from torch.fx.node import Node |
| |
| |
# Names exported by ``from ... import *``.
__all__ = ["Partition", "split_module"]
# Module logger; split_module emits the computed autocast/grad regions at DEBUG level.
_LOGGER: logging.Logger = logging.getLogger(__name__)
| |
@compatibility(is_backward_compatible=True)
class Partition:
    """Book-keeping for a single partition built by ``split_module``.

    Several fields are ``Dict[str, None]`` used as insertion-ordered sets
    (entries are only ever added with ``setdefault``).
    """

    def __init__(self, name: str):
        self.name: str = name
        # attribute name the partition's GraphModule is registered under
        self.submod_name = f"submod_{name}"
        # names of the original nodes assigned to this partition
        self.node_names: List[str] = []
        # ordered sets of value names consumed from / produced for other partitions
        self.inputs: Dict[str, None] = {}
        self.outputs: Dict[str, None] = {}
        # ordered sets of partition names this one needs / that need this one
        self.dependencies: Dict[str, None] = {}
        self.dependents: Dict[str, None] = {}
        # graph under construction plus the original-node -> new-node mapping
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        self.environment: Dict[Node, Node] = {}
        # attributes/submodules the partition's GraphModule must own, by target
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        pieces = (
            f"name: {self.name}",
            f"nodes: {self.node_names}",
            f"inputs: {self.inputs}",
            f"outputs: {self.outputs}",
            f"partitions depended on: {self.dependencies}",
            f"partition dependents: {self.dependents}",
        )
        # every field after the first sits on its own line, indented by one space
        return ",\n ".join(pieces)
| |
| |
# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[Node], int],
    qualname_map: Optional[Dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
    keep_original_node_name: Optional[bool] = False,
):
    """
    Creates subgraphs out of main graph

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
            because the root nn module is usually transformed via
            torch.fx._symbolic_trace.symbolic_trace (see example below)
        split_callback (Callable[[Node], int]): Callable function
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module. It is invoked exactly
            once per partitioned node (placeholder, get_attr and output nodes are
            not partitioned), so stateful callbacks are safe.
        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
            mapping from new target names in the module after split to old target
            names in the original module.
        keep_original_order: Optional[bool]: keep the original order of the GraphModule
            or use the Topological order of the new constructed GraphModule
        keep_original_node_name: Optional[bool]: when True, nodes created in the
            partition submodules reuse the names of the original nodes, and the
            base module's placeholders are created with the original node names
            as their targets.

    Returns:
        GraphModule: the module after split.

    Example:

        This is a sample setup:

            import torch
            from torch.fx.symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from torch.fx.passes.split_module import split_module

            class MyModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w

            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3

            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition

            # split module in module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        Output looks like this. Original graph is broken into partitions

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_1): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y); x = param = y = None
                getitem = submod_0[0]
                getitem_1 = submod_0[1]; submod_0 = None
                submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
                getitem_2 = submod_1[0]
                getitem_3 = submod_1[1]; submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
                return submod_2

        Output of split module is the same as output of input traced module.
        This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
            True
    """

    def construct_graph(
        node: Node,
        base_mod_env: Dict[str, Node],
        base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
    ):
        # Copy a placeholder or get_attr node of the original graph into the
        # base graph (other ops are ignored). For get_attr, also resolve the
        # attribute on `m` so the base GraphModule owns it.
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            if keep_original_node_name:
                args = () if default_value is inspect.Signature.empty else (default_value,)
                base_mod_env[node.name] = base_mod_graph.create_node(
                    "placeholder", node.name, args=args, type_expr=node.type
                )
            else:
                base_mod_env[node.name] = base_mod_graph.placeholder(
                    node.target, type_expr=node.type, default_value=default_value
                )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
            base_mod_env[node.name].meta = node.meta.copy()
            attr_val = m
            for atom in node.target.split("."):  # type: ignore[union-attr]
                if not hasattr(attr_val, atom):
                    raise AttributeError(f"Node target {node.target} not found!")
                attr_val = getattr(attr_val, atom)
            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
        return base_mod_env, base_mod_attrs

    import sympy

    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, Node] = {}
    symbol_to_node: Dict[sympy.Symbol, Node] = {}
    # Partition id returned by split_callback for each partitioned node. The
    # callback is invoked once per node and the result reused everywhere, so a
    # stateful callback cannot place one node in different partitions/regions.
    node_to_pid: Dict[Node, int] = {}

    def record_cross_partition_use(
        def_node: Node, use_node: Optional[Node]
    ):  # noqa: B950
        # Record that `use_node` (None means the graph output) consumes the value
        # produced by `def_node`, wiring up partition inputs/outputs/dependencies.
        from torch.fx.experimental.symbolic_shapes import free_symbols

        defined = getattr(def_node, "_fx_partition", None)
        used = getattr(use_node, "_fx_partition", None)
        if defined != used:
            if defined is not None:
                def_partition = partitions[defined]
                def_partition.outputs.setdefault(def_node.name)
                if used is not None:
                    def_partition.dependents.setdefault(used)

            if used is not None:
                use_partition = partitions[used]
                use_partition.inputs.setdefault(def_node.name)
                # symbolic sizes of the consumed value must also be passed in
                if (def_val := def_node.meta.get("example_value")) is not None:
                    for s in sorted(free_symbols(def_val), key=str):
                        use_partition.inputs.setdefault(symbol_to_node[s].name)
                if defined is not None:
                    use_partition.dependencies.setdefault(defined)

    def instantiate_node_partition_mapping(node) -> int:
        # Assign `node` to its partition (creating the partition on first use)
        # and return the raw id produced by split_callback.
        pid = split_callback(node)
        partition_name = str(pid)

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name
        node_to_pid[node] = pid
        return pid

    # Global State Nodes are nodes which by their global state effects,
    # "taint" all downstream nodes while they are active.
    GLOBAL_STATE_NODES = [
        torch.amp._enter_autocast,
        torch.amp._exit_autocast,
        torch._C._set_grad_enabled,
    ]

    # For grad regions:
    # ------------------------
    # 1. first region: we do nothing
    # 2. subsequent regions: we insert the set_grad at the beginning
    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()

    # For autocast regions:
    # ------------------------
    # 1. first region: we will only insert the _exit at the end
    # 2. intermediate regions: we will insert both the
    #    _enter at the beginning and _exit at the end
    # 3. last region: we will only insert _enter at the beginning
    # We will do so in the order in which the autocasts were instantiated.
    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
    autocast_exits: Dict[Node, Optional[Node]] = {}

    active_grad = None
    active_autocasts = set()

    for node in m.graph.nodes:
        if node.op in ["placeholder", "get_attr", "output"]:
            if (
                node.op == "placeholder" and
                (val := node.meta.get("example_value")) is not None and
                isinstance(val, torch.SymInt) and
                isinstance(val.node.expr, sympy.Symbol)
            ):
                symbol_to_node[val.node.expr] = node
            continue

        pid = instantiate_node_partition_mapping(node)

        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
            if node.target == torch._C._set_grad_enabled:
                assert len(node.args) == 1
                assert isinstance(node.args[0], bool)
                active_grad = node
                grad_regions[active_grad] = {pid}
            elif node.target == torch.amp._enter_autocast:
                # Should all be python constants
                assert all(not isinstance(arg, Node) for arg in node.args)
                active_autocasts.add(node)
                autocast_regions[node] = {pid}
                autocast_exits[node] = None
            elif node.target == torch.amp._exit_autocast:
                assert len(node.args) == 1
                autocast_regions[node.args[0]].add(pid)
                active_autocasts.remove(node.args[0])
                autocast_exits[node.args[0]] = node

        if active_grad is not None:
            grad_regions[active_grad].add(pid)

        for a in active_autocasts:
            autocast_regions[a].add(pid)

    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"

    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

    if _LOGGER.isEnabledFor(logging.DEBUG):
        _LOGGER.debug("autocast_regions: %s", autocast_regions)
        _LOGGER.debug("grad_regions: %s", grad_regions)

    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)

    # split nodes into partitions
    highest_partition = -1
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO currently placeholders/parameters aren't put into random partitions,
        # rather they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue

        if assert_monotonically_increasing:
            pid = node_to_pid[node]
            assert highest_partition <= pid, (
                "autocast or set_grad_enabled require monotonically increasing partitions: "
                f"highest: {highest_partition}, this node's: {pid}"
            )
            highest_partition = pid

        # do not capture cross-partition dependencies for global state nodes as they will be
        # self-contained - their setup and unwind will be isolated to each partition submodule.
        if node.target not in GLOBAL_STATE_NODES:
            torch.fx.graph.map_arg(
                node.args, lambda def_node: record_cross_partition_use(def_node, node)
            )
            torch.fx.graph.map_arg(
                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
            )  # noqa: B950

    original_partition_order = list(partitions.keys())
    # find partitions with no dependencies
    root_partitions: List[str] = [
        partition_name
        for partition_name, partition in partitions.items()
        if not partition.dependencies
    ]

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].dependents:
            partitions[dependent].dependencies.pop(root_partition)
            if not partitions[dependent].dependencies:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # Enter prelude: every region after the first re-issues the global-state call
    # at the start of its partition graph.
    for regions_mapping in [autocast_regions, grad_regions]:
        for node, regions in regions_mapping.items():
            assert len(regions) > 0
            partitions[str(regions[0])].environment[node] = node
            for r in regions[1:]:
                partition = partitions[str(r)]
                new_node = partition.graph.create_node(
                    op=node.op,
                    target=node.target,
                    args=tuple(node.args),
                    kwargs={},
                    type_expr=node.type,
                )
                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
                partition.environment[node] = new_node

    # add placeholders to partition inputs
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        for inp in partition.inputs:
            placeholder = partition.graph.placeholder(
                inp,
                type_expr=orig_nodes[inp].type,
            )
            placeholder.meta = orig_nodes[inp].meta.copy()
            partition.environment[orig_nodes[inp]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )

            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_atoms = node.target.split(".")
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise AttributeError(f"Operator target {node.target} not found!")
                    target_attr = getattr(target_attr, atom)
                # flatten the dotted path into a single attribute name of the submodule
                target = "_".join(target_atoms)
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            name = node.name if keep_original_node_name else None
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
                name=name,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Exit epilogue: every autocast region except the last closes the autocast at
    # the end of its partition graph, innermost (latest entered) first.
    for regions_mapping in [autocast_regions]:
        for node in reversed(regions_mapping):
            regions = regions_mapping[node]
            assert len(regions) > 0
            for r in regions[:-1]:
                partition = partitions[str(r)]
                exit_node = autocast_exits[node]
                assert exit_node is not None, "Missing exit node"
                new_node = partition.graph.create_node(
                    op=exit_node.op,
                    target=exit_node.target,
                    args=(partition.environment[node],),
                    kwargs={},
                    type_expr=exit_node.type,
                )
                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?

    # Set up values to construct base module
    base_mod_env: Dict[str, Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
    if not keep_original_order:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )
    else:
        # Placeholders must be emitted in their original order, before any
        # submodule call, or the base module's signature would be wrong. Doing it
        # up front also keeps the inputs when there are no partitions at all.
        for node in m.graph.nodes:
            if node.op == "placeholder":
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order or original order specified by keep_original_order

    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )

        # skip output node generation if there are no output values
        num_output_vals = len(output_vals)
        if num_output_vals == 1:
            partition.graph.output(output_vals[0])
        elif num_output_vals > 1:
            partition.graph.output(output_vals)

        if keep_original_order:
            # Emit the get_attr nodes this partition consumes right before its
            # call; an attribute shared by several partitions is emitted once.
            for input_name in partition.inputs:
                input_node = orig_nodes[input_name]
                if input_node.op == "get_attr" and input_name not in base_mod_env:
                    base_mod_env, base_mod_attrs = construct_graph(
                        input_node, base_mod_env, base_mod_attrs
                    )

        # Construct GraphModule for this partition
        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        num_outputs = len(partition.outputs)
        if num_outputs > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        elif num_outputs == 1:
            base_mod_env[next(iter(partition.outputs))] = output_val

    for node in m.graph.nodes:
        if node.op == "output":
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)