| # mypy: allow-untyped-defs |
| # pyre-strict |
| from __future__ import annotations |
| |
| import heapq |
| import operator |
| import sys |
| from collections import defaultdict |
| from typing import Dict, List, Set, TYPE_CHECKING |
| |
| import torch |
| |
| from . import config, ir |
| from .dependencies import WeakDep |
| from .utils import ( |
| contains_collective, |
| contains_wait, |
| find_recursive_deps_of_node, |
| find_recursive_users_of_node, |
| is_collective, |
| is_fallback_op, |
| is_wait, |
| ) |
| |
| |
| overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") |
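# The "overlap" log artifact is typically enabled via TORCH_LOGS="overlap" (or
# torch._logging.set_logs); the passes below use it to visualize how compute is
# overlapped with communication.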
| |
| if TYPE_CHECKING: |
| from .scheduler import BaseSchedulerNode |
| |
| |
| def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: |
| """ |
| Greedily schedules waits as late as possible. |
| """ |
| return _schedule_for_comm( |
| snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False |
| ) |
| |
| |
| def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: |
| """ |
| Greedily schedules comms as early as possible. |
| """ |
| return _schedule_for_comm( |
| snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False |
| ) |
| |
| |
| def reorder_compute_for_overlap( |
| snodes: List[BaseSchedulerNode], |
| ) -> List[BaseSchedulerNode]: |
| """ |
| This achieves the following overall scheduling procedure: |
    Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
    that are required for comm N + 1 but do not depend on comm N, to run at the same time as comm N.
| Step 2: If all those compute nodes are sufficient to overlap comm N, we're done. |
| Otherwise, we now need to look elsewhere to find compute that overlaps with comm N. |
| We prioritize compute nodes that are needed sooner. |
| Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1. |
| Step 4: We schedule comm N + 1. |
| Repeat this for subsequent comm nodes. |
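
    Illustration with hypothetical nodes: given comms c1, c2 and compute x, y, z,
    where x is needed by c2 and independent of c1, y depends on c1 and is needed
    by c2, and z is unrelated, the resulting order is roughly:
    c1 -> x (overlaps c1) -> z (if x alone doesn't cover c1) -> wait(c1) -> y -> c2.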
| """ |
| return _schedule_for_comm( |
| snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True |
| ) |
| |
| |
| def _schedule_for_comm( |
| snodes: List[BaseSchedulerNode], |
| raise_comms: bool, |
| sink_waits: bool, |
| reorder_for_overlap: bool, |
| ) -> List[BaseSchedulerNode]: |
| """ |
| Schedule `snodes` for various comm optimization objectives. |
| |
| Args: |
| snodes: the nodes to be scheduled. |
| raise_comms: whether to greedily schedule collectives as early as possible |
        sink_waits: whether to greedily schedule waits as late as possible
        reorder_for_overlap: whether to reorder compute nodes to
            optimize for compute/communication overlapping.
| |
| Returns: |
| The new schedule order. |
| |
| Some notes on the synergy between different options: |
    - `raise_comms` provides more overlapping opportunities for `reorder_for_overlap`.
    - When both `raise_comms` and `sink_waits` are `True`, `raise_comms` is prioritized.
| """ |
| # We assign each node a tuple of scores (score_0, score_1, score_2), |
| # decreasing in importance, with a lower value indicating a higher ranking: |
| # |
| # - score_0: the lowest comm_idx among the comm nodes that the node blocks. |
| # If a node doesn't block any comm nodes, its score_0 is set to |
| # sys.maxsize. This score ensures that comm nodes get scheduled as early as |
| # possible. |
| # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures |
| # that wait nodes are deferred as late as possible. |
| # - score_2: the index of the node in the original topological order. This |
| # score provides stability in case of ties. |
| # |
| # When only raise_comms is True, only score_0 and score_2 are considered. |
| # When only sink_waits is True, only score_1 and score_2 are considered. |
| # When neither is True, the original order is yielded. |
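    #
    # Illustrative example (hypothetical nodes, raise_comms=True, sink_waits=True):
    # for the original order [compute0, collective0, wait0, compute1], where
    # compute0 is an ancestor of collective0 and compute1 is independent, the
    # scores are:
    #   compute0:    (0, 0, 0)
    #   collective0: (0, 0, 1)
    #   wait0:       (maxsize, 1, 2)
    #   compute1:    (maxsize, 0, 3)
    # so compute1 is scheduled ahead of wait0 and overlaps with collective0.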
| buf_name_to_snode = {} |
| name_to_fused_node = {} |
| scores_0, scores_1, scores_2 = {}, {}, {} |
| for idx, snode in enumerate(snodes): |
| for buf_name in snode.get_buffer_names(): |
| buf_name_to_snode[buf_name] = snode |
| |
| for op_name in snode.get_operation_names(): |
| name_to_fused_node[op_name] = snode |
| name_to_fused_node[snode.get_name()] = snode |
| |
| node_name = snode.get_name() |
| scores_0[node_name] = sys.maxsize |
| scores_1[node_name] = 0 |
| scores_2[node_name] = idx |
| |
| comm_idx = 0 |
| for snode in snodes: |
| if raise_comms and contains_collective(snode): |
| scores_0[snode.get_name()] = comm_idx |
| for anc in snode.ancestors: |
| anc_fused_name = name_to_fused_node[anc].get_name() |
| scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx) |
| comm_idx += 1 |
| elif sink_waits and contains_wait(snode): |
| scores_1[snode.get_name()] = 1 |
| |
| class Runnable: |
| def __init__(self, snode) -> None: |
| self.snode = snode |
| name = next(iter(snode.get_operation_names())) |
| fused_name = name_to_fused_node[name].get_name() |
| self.score = ( |
| scores_0[fused_name], |
| scores_1[fused_name], |
| scores_2[fused_name], |
| ) |
| |
| def __lt__(self, other): |
| return self.score < other.score |
| |
| unmet_deps: Dict[BaseSchedulerNode, Set[str]] = { |
| snode: {dep.name for dep in snode.unmet_dependencies} for snode in snodes |
| } |
| |
| ready: List[Runnable] = [] |
| buffer_users: Dict[str, Set[BaseSchedulerNode]] = defaultdict(set) |
| snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes} |
| |
| for snode, deps in unmet_deps.items(): |
| if len(deps) == 0: |
| heapq.heappush(ready, Runnable(snode)) |
| for dep in deps: |
| buffer_users[dep].add(snode) |
| |
| scheduled = [] |
| |
| def schedule(snode): |
| """ |
        Schedules `snode` and puts all unblocked nodes onto the ready queue.
| """ |
| scheduled.append(snode) |
| for buf_name in snode.get_buffer_names(): |
            for user in buffer_users[buf_name]:
                unmet_deps[user].remove(buf_name)
                if len(unmet_deps[user]) == 0:
                    heapq.heappush(ready, Runnable(user))
| |
| def get_overlapping_candidate(): |
| """ |
        Return the next node in the ready queue that's neither a collective nor
        a wait.
| """ |
| candidates = [ |
| x |
| for x in ready |
| if not contains_collective(x.snode) and not contains_wait(x.snode) |
| ] |
| if len(candidates) == 0: |
| return None |
| return min(candidates, key=lambda x: x.score) |
| |
| def schedule_collective_for_overlap(snode): |
| """ |
| Schedules collective node `snode`, along with one or more compute nodes |
| to overlap with it. The strategy is described in the comment of |
| `reorder_compute_for_overlap`. |
| """ |
| assert contains_collective(snode) |
| schedule(snode) |
| |
| collective_cost = snode_to_cost[snode] |
| while ( |
| collective_cost > 0 |
| and (candidate := get_overlapping_candidate()) is not None |
| ): |
| ready.remove(candidate) |
| schedule(candidate.snode) |
| collective_cost -= snode_to_cost[candidate.snode] |
| heapq.heapify(ready) |
| |
| while len(ready): |
| snode = heapq.heappop(ready).snode |
| if reorder_for_overlap and contains_collective(snode): |
| schedule_collective_for_overlap(snode) |
| else: |
| schedule(snode) |
| |
| for snode, deps in unmet_deps.items(): |
| assert len(deps) == 0, ( |
| "Detected unscheduled nodes. " |
| f"Nodes with unmet dependencies: {unmet_deps}" |
| ) |
| return scheduled |
| |
| |
| def decide_global_ordering_of_comms( |
| nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node |
| ) -> List[BaseSchedulerNode]: |
| """ |
    Decide the global ordering of comms by simply enforcing the ordering present
    in the input graph (which might not match the ordering of the eager-mode program).
| TODO: Come up with a better approach |
| """ |
| # If FSDP2 is used, we apply FSDP-specific passes. |
| if any( |
| is_fallback_op( |
| x.node, |
| { |
| torch.ops.fsdp.all_gather_copy_in.default, |
| torch.ops.fsdp.chunk_cat.default, |
| }, |
| ) |
| for x in nodes |
| ): |
| nodes = enforce_comm_ordering_for_fsdp(nodes, name_to_buf, name_to_fused_node) |
| |
| comm_nodes = [n for n in nodes if contains_collective(n)] |
| |
| for i in range(1, len(comm_nodes)): |
| # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm |
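        # Illustrative effect (hypothetical buffer names): if comm_nodes[i - 1]
        # writes buf0, comm_nodes[i] gains WeakDep("buf0"), pinning the relative
        # order of the two collectives even when no real data dependency exists.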
| mutating_buf = next(iter(comm_nodes[i].get_buffer_names())) |
| for buf in comm_nodes[i - 1].get_buffer_names(): |
| comm_nodes[i].add_fake_dep(WeakDep(buf, mutating_buf=mutating_buf)) |
| |
| return nodes |
| |
| |
| def estimate_op_runtime(snode: BaseSchedulerNode) -> float: |
| """ |
| Returns estimated op runtime in nanoseconds (ns) |
| """ |
| if config.estimate_op_runtime == "default": |
| runtime = snode.get_estimated_runtime() |
| else: |
| assert callable(config.estimate_op_runtime) |
| runtime = config.estimate_op_runtime(snode) |
| return runtime |
| |
| |
| def node_summary(snode): |
| detail = "" |
| if isinstance(snode.node, ir.ExternKernelOut): |
| detail = f" ({snode.node.python_kernel_name})" |
| out_tensor_info = "" |
| if ( |
| hasattr(snode.node, "layout") |
| and hasattr(snode.node.layout, "size") |
| and hasattr(snode.node.layout, "stride") |
| ): |
| out_tensor_info = ( |
| f" (size={snode.node.layout.size}, stride={snode.node.layout.stride})" |
| ) |
| node_name = "" |
| if hasattr(snode.node, "name"): |
| node_name = snode.node.name |
| return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name})" |
| |
| |
| def visualize_overlap(order): |
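    # Walks the schedule in order, logging one line per node: compute overlapped
    # with an in-flight collective is prefixed with "| ", while exposed compute,
    # collectives, and waits are logged flush left. The accumulated runtime
    # estimate counts collectives and exposed compute only.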
| total_est_runtime: float = 0.0 |
| cur_comm_node = None |
| for snode in order: |
| if cur_comm_node is None: |
| if contains_collective(snode): |
| total_est_runtime += estimate_op_runtime(snode) |
| cur_comm_node = snode.node |
| elif is_wait(snode.node): |
| raise AssertionError( |
| "Wait is not expected when there is no collective running" |
| ) |
| else: # exposed compute op |
| total_est_runtime += estimate_op_runtime(snode) |
| overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 |
| else: # cur_comm_node is not None |
| if contains_collective(snode): |
| raise AssertionError( |
| "Found two collectives running at the same time. " |
| "`visualize_overlap` needs to be updated to handle this case" |
| ) |
| elif is_wait(snode.node): # end of this comm op |
| overlap_log.debug(f"{node_summary(snode)}") # noqa: G004 |
| cur_comm_node = None |
| else: # overlapped compute op |
| overlap_log.debug(f"| {node_summary(snode)}") # noqa: G004 |
| overlap_log.debug( |
| f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 |
| ) |
| |
| |
| def reorder_compute_and_comm_for_overlap( |
| snodes: List[BaseSchedulerNode], |
| ) -> List[BaseSchedulerNode]: |
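    # The passes are driven by the inductor config; an illustrative setup (these
    # are the built-in pass names defined in this module):
    #
    #   torch._inductor.config.reorder_for_compute_comm_overlap = True
    #   torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    #       "reorder_compute_for_overlap",
    #       "sink_waits",
    #       "raise_comms",
    #   ]
    #
    # Entries may also be callables that take and return a list of scheduler nodes.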
| order = snodes |
| |
| for p in config.reorder_for_compute_comm_overlap_passes: |
| if isinstance(p, str) and p in globals(): |
| p = globals()[p] # it is a builtin pass |
| if torch.distributed.get_rank() == 0: |
| overlap_log.debug( |
| f"==== Visualize overlap before reordering pass {p} ====" # noqa: G004 |
| ) |
| try: |
| visualize_overlap(order) |
| except Exception as e: |
| overlap_log.debug(str(e)) |
| order = p(order) # type: ignore[operator] |
| if torch.distributed.get_rank() == 0: |
| overlap_log.debug( |
| f"==== Visualize overlap after reordering pass {p} ====" # noqa: G004 |
| ) |
| try: |
| visualize_overlap(order) |
| except Exception as e: |
| overlap_log.debug(str(e)) |
| return order |
| |
| |
| def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None: |
| try: |
| import torch.distributed._composable.fsdp._fsdp_collectives |
| |
| assert torch.distributed.is_available() |
| # Assert existence of these ops |
| assert ( |
| torch.ops._c10d_functional.all_gather_into_tensor |
| and torch.ops._c10d_functional.all_gather_into_tensor_out |
| ) |
| except (ImportError, AttributeError, AssertionError): |
| return |
| |
| from .pattern_matcher import ( |
| CallFunction, |
| KeywordArg, |
| Match, |
| PatternMatcherPass, |
| register_graph_pattern, |
| ) |
| |
| """ |
| all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); |
| getitem = all_gather_copy_in[0]; |
| (getitem_1 = all_gather_copy_in[1];) # optional |
| |
| all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...); |
| |
| -> |
| |
| all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...); |
| getitem = all_gather_copy_in[0]; |
| getitem_1 = all_gather_copy_in[1]; |
| |
| all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1); |
| """ |
| |
| def remove_unused_getitem(g): |
| # Remove `getitem_X = all_gather_copy_in[1]` which is never used. |
| node_list = list(g.nodes) |
| for n in node_list: |
| if ( |
| n.target == operator.getitem |
| and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default |
| and n.args[1] == 1 |
| ): |
| g.erase_node(n) |
| |
| graph_pass = PatternMatcherPass() |
| |
| @register_graph_pattern( |
| CallFunction( |
| torch.ops._c10d_functional.all_gather_into_tensor.default, |
| CallFunction( |
| operator.getitem, |
| CallFunction( |
| torch.ops.fsdp.all_gather_copy_in.default, |
| KeywordArg("all_gather_inputs"), |
| KeywordArg("inp_split_sizes"), |
| KeywordArg("all_gather_input_numel"), |
| KeywordArg("world_size"), |
| KeywordArg("rank"), |
| KeywordArg("dtype"), |
| KeywordArg("device"), |
| ), |
| KeywordArg("item_idx"), |
| ), |
| KeywordArg("group_size"), |
| KeywordArg("group_name"), |
| ), |
| pass_dict=graph_pass, |
| extra_check=lambda match: match.kwargs["item_idx"] == 0, |
| ) |
| def reinplace_all_gather(match: Match, *args, **kwargs): |
| def repl( |
| *args, |
| ): |
| copy_in_args = args[:-2] |
| group_size = args[-2] |
| group_name = args[-1] |
| all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default( |
| *copy_in_args |
| ) |
| getitem = all_gather_copy_in[0] |
| getitem_1 = all_gather_copy_in[1] |
| all_gather_into_tensor = ( |
| torch.ops._c10d_functional.all_gather_into_tensor_out.default( |
| getitem, group_size, group_name, out=getitem_1 |
| ) |
| ) |
| return all_gather_into_tensor |
| |
| match.replace_by_example( |
| repl, |
| [ |
| kwargs["all_gather_inputs"], |
| kwargs["inp_split_sizes"], |
| kwargs["all_gather_input_numel"], |
| kwargs["world_size"], |
| kwargs["rank"], |
| kwargs["dtype"], |
| kwargs["device"], |
| kwargs["group_size"], |
| kwargs["group_name"], |
| ], |
| ) |
| |
| remove_unused_getitem(graph) |
| graph_pass.apply(graph) # type: ignore[arg-type] |
| |
| |
| def get_op_idx(snode): |
| assert not isinstance( |
| snode, |
| ( |
| torch._inductor.scheduler.FusedSchedulerNode, |
| torch._inductor.scheduler.GroupedSchedulerNode, |
| ), |
| ) |
| return int(snode.get_name()[2:]) |
| |
| |
| def enforce_comm_ordering_for_fsdp( |
| snodes: List[torch._inductor.scheduler.BaseSchedulerNode], |
| name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer], |
| name_to_fused_node: Dict[str, BaseSchedulerNode], |
| ) -> List[torch._inductor.scheduler.BaseSchedulerNode]: |
| from . import scheduler |
| |
| new_order: list[BaseSchedulerNode] = [] |
| scheduled = set() |
| ag_exists = False |
| rs_exists = False |
| ag_grouped_node_to_wait_grouped_node = {} |
| rs_grouped_node_to_wait_grouped_node = {} |
| snode_name_to_final_snode = {} |
| |
| def _create_group_node(snodes_to_group): |
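        # GroupedSchedulerNode schedules its members as a single contiguous unit;
        # mapping every member name to the group lets the reordering below treat
        # each "copy_in + comm" / "wait + copy_out" block as one node.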
| group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group) |
| for snode in snodes_to_group: |
| snode_name_to_final_snode[snode.get_name()] = group_node |
| snode_name_to_final_snode[group_node.get_name()] = group_node |
| return group_node |
| |
| # Create grouped nodes for specific sets of ops |
| for snode in snodes: |
| # Case 1: Handle AllGather |
| if is_collective( |
| snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default |
| ) and any( |
| is_fallback_op( |
| name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default |
| ) |
| for x in snode.ancestors |
| ): |
| ag_exists = True |
| ag_snode = snode |
| ag_related_snode_set: set[scheduler.BaseSchedulerNode] = set() |
| |
| # Find the "cast + copy_in + getitem + all_gather" code block |
| find_recursive_deps_of_node( |
| ag_snode, |
| ag_related_snode_set, |
| name_to_buf, |
| name_to_fused_node, |
| ) |
| |
| # Find the "all_gather + all_gather_wait_tensor + copy_out + set_" code block |
| allowed_ops = { |
| torch.ops._c10d_functional.all_gather_into_tensor_out.default, |
| torch.ops._c10d_functional.wait_tensor.default, |
| torch.ops.fsdp.split_with_sizes_copy.default, |
| torch.ops.aten.set_.source_Tensor, |
| } |
| find_recursive_users_of_node( |
| ag_snode, |
| ag_related_snode_set, |
| name_to_buf, |
| name_to_fused_node, |
| criteria_cb=lambda x: not ( |
| isinstance(x, scheduler.NopKernelSchedulerNode) |
| or ( |
| isinstance(x, scheduler.ExternKernelSchedulerNode) |
| and x.node.op_overload in allowed_ops # type: ignore[union-attr] |
| ) |
| ), |
| ) |
| |
| # sort nodes by original operation order |
| ag_related_snodes = sorted( |
| ag_related_snode_set, key=lambda x: get_op_idx(x) |
| ) |
| |
| # In the "reuse layer" case, some ops in the 2nd all-gather code block could also |
| # depend on ops in the 1st all-gather code block, and we don't want to group them together. |
| end_idx_of_current_ag_block = len(ag_related_snodes) |
| copy_out_count = 0 |
| for i in range(len(ag_related_snodes)): |
| cur_snode = ag_related_snodes[i] |
| if is_fallback_op( |
| cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default |
| ): |
| copy_out_count += 1 |
| if copy_out_count > 1: |
| end_idx_of_current_ag_block = i |
| break |
| |
| ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block] |
| |
| # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode |
| wait_node_idx = None |
| for i in range(len(ag_related_snodes) - 1): |
| if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel): |
| wait_node_idx = i + 1 |
| break |
| assert wait_node_idx is not None |
| ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx]) |
| |
| # Group "all_gather_wait_tensor + copy_out + set_" into one GroupedSchedulerNode |
| ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:]) |
| |
| ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node |
| |
| # Case 2: Handle ReduceScatter |
| elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default): |
| rs_exists = True |
| rs_snode = snode |
| |
| # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block |
| rs_related_snode_set: set[scheduler.BaseSchedulerNode] = set() |
| find_recursive_users_of_node( |
| rs_snode, |
| rs_related_snode_set, |
| name_to_buf, |
| name_to_fused_node, |
| ) |
| |
| # sort nodes by original operation order |
| rs_related_snodes = sorted( |
| rs_related_snode_set, key=lambda x: get_op_idx(x) |
| ) |
| |
| # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode |
| wait_node_idx = None |
| for i in range(len(rs_related_snodes) - 1): |
| if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel): |
| wait_node_idx = i + 1 |
| break |
| assert wait_node_idx is not None |
| rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx]) |
| |
| # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode |
| rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:]) |
| |
| rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node |
| |
| assert len(snode_name_to_final_snode) > 0 |
| if ag_exists: |
| assert len(ag_grouped_node_to_wait_grouped_node) > 0 |
| if rs_exists: |
| assert len(rs_grouped_node_to_wait_grouped_node) > 0 |
| |
| # Build the new node schedule, taking GroupedSchedulerNode into account |
| for snode in snodes: |
| if snode.get_name() in snode_name_to_final_snode: |
| snode = snode_name_to_final_snode[snode.get_name()] |
| if snode in scheduled: |
| continue |
| new_order.append(snode) |
| scheduled.add(snode) |
| |
| # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run |
| # before next AllGather's "copy_in then AG" group node |
| prev_ag_wait = None |
| for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items(): |
| if prev_ag_wait is not None: |
| mutating_buf = next(iter(ag_group_node.get_buffer_names())) |
| for o in prev_ag_wait.get_outputs(): |
| ag_group_node.add_fake_dep( |
| WeakDep(o.get_name(), mutating_buf=mutating_buf) |
| ) |
| prev_ag_wait = wait_group_node |
| |
| # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run |
| # before next ReduceScatter's "copy_in then RS" group node |
| prev_rs_wait = None |
| for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items(): |
| if prev_rs_wait is not None: |
| mutating_buf = next(iter(rs_group_node.get_buffer_names())) |
| for o in prev_rs_wait.get_outputs(): |
| rs_group_node.add_fake_dep( |
| WeakDep(o.get_name(), mutating_buf=mutating_buf) |
| ) |
| prev_rs_wait = wait_group_node |
| |
| return new_order # type: ignore[return-value] |