import logging
import operator
from dataclasses import dataclass
from enum import auto, Enum
from functools import partial
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.distributed._spmd.experimental_ops
import torch.fx as fx

from torch.distributed._spmd.comm_tensor import _get_tracer
from torch.distributed._spmd.graph_utils import OP
from torch.distributed._spmd.log_utils import get_logger

from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.op_schema import OpSchema
from torch.distributed._tensor.placement_types import (
    _Partial,
    DTensorSpec,
    Placement,
    Replicate,
    Shard,
    TensorMeta,
)
from torch.distributed._tensor.redistribute import redistribute_local_tensor
from torch.fx.experimental.proxy_tensor import make_fx, proxy_slot
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_map, tree_map_only, tree_unflatten


logger: Optional[logging.Logger] = None

aten = torch.ops.aten


class TrainingPhase(Enum):
    FORWARD = auto()
    BACKWARD = auto()


@dataclass
class Schema:
    mesh: DeviceMesh
    placements: List[Placement]


@dataclass
class DSymInt:
    """DSymInt represents a value retrieved by a SymInt op from a DTensor.

    DSymInt helps view and factory ops determine the placement and shape of the
    output tensor, as those operators either have no input DTensor or the input
    DTensor alone is insufficient to determine the output tensor's placement.
    """

    global_value: int  # value that the SymInt evaluates to
    local_value: int  # value that this SymInt evaluates to on the local shard
    mesh: DeviceMesh  # device mesh of the DTensor where this SymInt is retrieved from

    def is_shard(self) -> bool:
        return self.local_value != self.global_value

    @classmethod
    def from_node(cls, node: fx.Node, dtensor: DTensor) -> "DSymInt":
        dim: int = 0
        if node.target == aten.sym_size:
            dim = cast(int, node.args[1])
            return cls(
                global_value=dtensor.size(dim),
                local_value=dtensor.to_local().size(dim),
                mesh=dtensor.device_mesh,
            )
        elif node.target == aten.sym_numel:
            return cls(
                global_value=dtensor.numel(),
                local_value=dtensor.to_local().numel(),
                mesh=dtensor.device_mesh,
            )
        elif node.target == aten.sym_stride:
            dim = cast(int, node.args[1])
            return cls(
                global_value=dtensor.stride(dim),
                local_value=dtensor.to_local().stride(dim),
                mesh=dtensor.device_mesh,
            )
        else:
            raise NotImplementedError(f"DSymInt does not support {node.target}")


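# Illustrative sketch (comments only; the mesh size and shapes below are
# assumptions, not part of this module): for a DTensor ``dt`` with global shape
# [8, 4] placed as Shard(0) on a 1-D mesh of 4 ranks, an ``aten.sym_size`` node
# on dim 0 maps to
#
#     DSymInt.from_node(size_node, dt)
#     # -> DSymInt(global_value=8, local_value=2, mesh=dt.device_mesh)
#     # -> is_shard() == True
#
# whereas dim 1 is not sharded, so its DSymInt has
# local_value == global_value == 4 and is_shard() == False.
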
def _is_partial_dtensor(obj: Any) -> bool:
    """Check whether ``obj`` is a DTensor with at least one _Partial placement."""
    if not isinstance(obj, DTensor):
        return False

    is_partial = False
    for placement in obj.placements:
        if isinstance(placement, _Partial):
            is_partial = True
            break

    return is_partial


def _dispatch_with_local_tensors(
    op: torch._ops.OpOverload,
    local_args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    specs: Optional[
        Dict[
            torch.Tensor,
            Tuple[torch.Size, DeviceMesh, Sequence[Placement], Sequence[Placement]],
        ]
    ] = None,
) -> Any:
    if kwargs is None:
        kwargs = {}
    if specs is None:
        specs = {}

    def redistribute(arg: Any) -> Any:
        # Only redistribute tensors that have an entry in ``specs``; pass every
        # other leaf (ints, lists, tensors without a spec) through unchanged.
        if not isinstance(arg, torch.Tensor) or arg not in specs:  # type: ignore[operator]
            return arg

        tensor_shape, mesh, current_placement, target_placement = specs[arg]  # type: ignore[index]
        tensor_meta = TensorMeta(
            tensor_shape,
            stride=arg.stride(),
            dtype=arg.dtype,
        )
        current_spec = DTensorSpec(
            mesh, tuple(current_placement), tensor_meta=tensor_meta
        )
        target_spec = DTensorSpec(
            mesh, tuple(target_placement), tensor_meta=tensor_meta
        )
        return redistribute_local_tensor(arg, current_spec, target_spec)

    # TODO: this is broken because it won't redistribute potential tensors in the kwargs.
    return op(*tree_map(redistribute, local_args), **kwargs)


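# Illustrative sketch (comments only; the tensors, mesh, and placements are
# assumptions): given a local shard ``t`` whose global tensor is Shard(0) on a
# 1-D mesh but which the op needs replicated,
#
#     specs = {t: (torch.Size([8, 4]), mesh, [Shard(0)], [Replicate()])}
#     out = _dispatch_with_local_tensors(aten.mm.default, (t, w), specs=specs)
#
# would all-gather ``t`` across the mesh before running the local mm, while
# leaving ``w`` (which has no entry in ``specs``) untouched.
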
# TODO: Figure out how to specify a type spec for the returned ``specs`` value
# without spelling out the entire nested structure.
# pyre-fixme
def _update_specs_for_redistribute(args, target_schema, redistribute):
    # Code adapted from pack_args_kwargs_with_local_tensor
    flatten_args, args_tree_spec = tree_flatten(args)
    flatten_args_schema = pytree.tree_leaves(target_schema.args_schema)

    specs: Dict[
        torch.Tensor,
        Tuple[
            torch.Size,
            DeviceMesh,
            Sequence[Placement],
            Sequence[Placement],
        ],
    ] = {}
    for i, arg in enumerate(flatten_args):
        if isinstance(arg, DTensor):
            if redistribute:
                specs[arg._local_tensor] = (
                    arg.size(),
                    flatten_args_schema[i].mesh,
                    arg.placements,
                    flatten_args_schema[i].placements,
                )
            flatten_args_schema[i] = arg._local_tensor

    unflattened_args = tree_unflatten(flatten_args_schema, args_tree_spec)
    return specs, unflattened_args


# When no tensor redistribution is required, we only need to update the
# non-tensor args of the node according to op_schema, and can avoid building a
# GraphModule just for this node.
def _update_node_from_op_schema(node: torch.fx.Node, op_schema: OpSchema) -> None:
    flat_args, args_tree_spec = tree_flatten(node.args)
    flat_args_schema = pytree.tree_leaves(op_schema.args_schema)

    def is_sym_int_or_int(arg: Union[int, torch.fx.Node]) -> bool:
        if isinstance(arg, torch.fx.Node):
            return arg.target in [
                aten.sym_size,
                aten.sym_numel,
                aten.sym_stride,
            ]
        return isinstance(arg, int)

    assert len(flat_args) == len(flat_args_schema)
    for i, (arg, arg_schema) in enumerate(zip(flat_args, flat_args_schema)):
        if is_sym_int_or_int(arg) and isinstance(arg_schema, int):
            flat_args[i] = arg_schema

    args = tree_unflatten(flat_args, args_tree_spec)
    for idx, arg in enumerate(args):
        node.update_arg(idx, arg)
    return None


def _remap_arg(node_to_obj: Dict[fx.Node, Any], arg: Any) -> Any:
    if isinstance(arg, torch.fx.Node):
        obj = node_to_obj[arg]
        if _get_tracer():
            # This is a shared arg that already has a tracer from previous
            # tracing. Delete the tracer.
            del cast(Dict[Any, Any], obj.__dict__)[proxy_slot]
        return obj
    else:
        return arg


def unpack_sizes_and_dims(
    sizes: List[Union[DSymInt, int]], mesh: DeviceMesh
) -> Tuple[List[int], List[Placement]]:
    local_sizes: List[int] = [
        s.local_value if isinstance(s, DSymInt) else s for s in sizes
    ]
    placements: List[Placement] = [
        Shard(i)
        for i, a in enumerate(sizes)
        if (isinstance(a, DSymInt) and a.is_shard())
    ] or [Replicate()]

    assert len(placements) == mesh.ndim, (
        f"The number of sharded dimensions ({len(placements)}) must "
        f"match the number of dimensions in the device mesh ({mesh.ndim})."
    )

    return local_sizes, placements


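# Illustrative sketch (comments only; the concrete values are assumptions): with
# a 1-D mesh of 4 ranks and
#
#     sizes = [DSymInt(global_value=8, local_value=2, mesh=mesh), 4]
#
# unpack_sizes_and_dims(sizes, mesh) returns ([2, 4], [Shard(0)]): the local
# shape plus a Shard placement for every dimension whose DSymInt is sharded.
# If no DSymInt is sharded, the placement list falls back to [Replicate()].
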
def binop_sym_int_consumer_rule(node: fx.Node, args: Tuple[Any, ...]) -> DTensor:
    assert len(args) == 2, f"Expect two args but got op {node.target} with args {args}"
    assert isinstance(
        args[0], DTensor
    ), f"Expect 1st argument to be DTensor but got {args[0]}"
    assert isinstance(args[1], list), f"Expect 2nd argument as list but got {args[1]}"

    # Extract sharded dimensions from the size list; the output DTensor should
    # follow these placements.
    local_sizes, placements = unpack_sizes_and_dims(args[1], args[0].device_mesh)

    # Set node args to real int sizes.
    node.args = (node.args[0], local_sizes)
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(args[0]._local_tensor, local_sizes),
        device_mesh=args[0].device_mesh,
        placements=placements,
        run_check=False,
    )


def slice_backwad_sym_int_consumer_rule(
    node: fx.Node, args: Tuple[Any, ...]
) -> DTensor:
    grad_output, input_sizes, dim, start, end, step = args

    local_sizes: List[int] = [
        s.local_value if isinstance(s, DSymInt) else s for s in input_sizes
    ]

    input_tensor = torch.zeros(
        local_sizes, device=grad_output.device, dtype=grad_output.dtype
    )
    return DTensor.from_local(
        local_tensor=torch.slice_scatter(
            input_tensor, grad_output.to_local(), dim, start, end, step
        ),
        device_mesh=grad_output.device_mesh,
        placements=grad_output.placements,
        run_check=False,
    )


def factory_with_sizes_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    flat_args = pytree.arg_tree_leaves(*args)
    assert not any(isinstance(a, DTensor) for a in flat_args), (
        f"Expect no DTensor argument for factory op, but got {node.target} "
        f"with arguments {args}."
    )
    assert isinstance(
        args[0], list
    ), f"Expect 1st argument to be a list but got {args[0]}"

    local_sizes, placements = unpack_sizes_and_dims(args[0], default_mesh)
    node.args = (local_sizes, *args[1:])
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(*node.args, **kwargs),
        device_mesh=default_mesh,
        placements=placements,
        run_check=False,
    )


def factory_arange_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    node.args = tree_map(lambda a: a.local_value if isinstance(a, DSymInt) else a, args)
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(*node.args, **kwargs),
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )


def default_factory_op_rule(
    node: fx.Node,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    default_mesh: DeviceMesh,
) -> DTensor:
    node.args, node.kwargs = args, kwargs
    op = cast(torch._ops.OpOverload, node.target)
    return DTensor.from_local(
        local_tensor=op(*node.args, **node.kwargs),
        device_mesh=default_mesh,
        placements=[Replicate()],
        run_check=False,
    )


# Dispatch overrides for view and factory ops that consume SymInt arguments,
# where the output spec should follow the placement of the dimension that the
# SymInt comes from.
VIEW_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    aten._unsafe_view.default: binop_sym_int_consumer_rule,
    aten.expand.default: binop_sym_int_consumer_rule,
    aten.slice_backward.default: slice_backwad_sym_int_consumer_rule,
    aten.view.default: binop_sym_int_consumer_rule,
}

FACTORY_SYM_INT_CONSUMERS: Dict[torch._ops.OpOverload, Callable] = {
    aten.full.default: factory_with_sizes_rule,
    aten.arange.default: factory_arange_rule,
    aten.arange.start: factory_arange_rule,
}


# Dispatch overrides for factory ops, as DTensor cannot propagate a sharding
# spec without DTensor inputs.
FACTORY_OPS: Dict[torch._ops.OpOverload, Callable] = {
    aten.scalar_tensor.default: default_factory_op_rule,
    aten.arange.start: default_factory_op_rule,
    aten.zeros.default: default_factory_op_rule,
}


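# Illustrative sketch (comments only; the shapes and mesh are assumptions): if
# the traced graph contains ``aten.view.default(dt, [sym_size(dt, 0), 4])`` and
# the dim-0 DSymInt is sharded, _get_dtensor_dispatch_graph below looks up
# aten.view.default in VIEW_SYM_INT_CONSUMERS and calls
# binop_sym_int_consumer_rule(node, args), which rewrites the size list to
# local sizes and returns a DTensor sharded on dim 0. Ops in FACTORY_OPS are
# handled analogously but always produce [Replicate()] placements.
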
def _get_dtensor_dispatch_graph(
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
    *,
    force_make_fx: bool = False,
    default_mesh: Optional[DeviceMesh] = None,
) -> Optional[fx.GraphModule]:
    with torch.no_grad():
        # Args should be a list of objects post remapping.
        args = tree_map(partial(_remap_arg, node_to_obj), node.args)
        kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

        op_overload = cast(torch._ops.OpOverload, node.target)

        if any(
            a.is_shard()
            for a in pytree.arg_tree_leaves(*args)
            if isinstance(a, DSymInt)
        ):
            if op_overload in VIEW_SYM_INT_CONSUMERS:
                assert len(kwargs) == 0, f"Expect empty kwargs, but got {kwargs}"
                node_to_obj[node] = VIEW_SYM_INT_CONSUMERS[op_overload](node, args)
                return None
            elif op_overload in FACTORY_SYM_INT_CONSUMERS:
                assert default_mesh is not None, "Requires default mesh for factory ops"
                node_to_obj[node] = FACTORY_SYM_INT_CONSUMERS[op_overload](
                    node, args, kwargs, default_mesh
                )
                return None
            else:
                assert isinstance(logger, logging.Logger)
                logger.warning(
                    "Assuming that using local_value from SymInt for %s "
                    "is mathematically correct. Full args are %s.",
                    op_overload,
                    args,
                )

        if node.target == aten.view.default:
            # HACK: get around the fact that some view operations on a "global"
            # tensor are invalid usage, yet the view op on the batch input may
            # still hit them. Convert the view op to reshape before calling
            # into DTensor.
            op_overload = aten.reshape.default

        # DSymInt args are not sharded on any dimension at this point, so the
        # local value and the global value should be the same.
        args = tree_map(lambda a: a.local_value if isinstance(a, DSymInt) else a, args)
        kwargs = tree_map(
            lambda a: a.local_value if isinstance(a, DSymInt) else a, kwargs
        )

        if op_overload in FACTORY_OPS:
            # Don't pass factory ops to DTensor dispatch, as DTensor cannot
            # propagate sharding specs without DTensor inputs.
            node_to_obj[node] = FACTORY_OPS[op_overload](
                node, args, kwargs, default_mesh
            )
            return None

        dispatch = partial(
            _dispatch_with_local_tensors,
            op_overload,
            kwargs=kwargs,
            specs=args,
        )

        gm = make_fx(dispatch, _allow_non_fake_inputs=False)(args)
        # FIXME(@wanchaol, @mrshenli): the above seems to accidentally capture
        # DeviceMesh tensor ops when handling inplace operators? The ``_to_copy``
        # is not connected to the graph output, so we use DCE to get rid of it,
        # but this doesn't look correct.
        #
        # The following operators appear in the captured graph, where the dtype
        # is torch.int64.
        #
        # get_attr       _tensor_constant0  _tensor_constant0      ()
        # call_function  transpose          aten.transpose.int     (_tensor_constant0, -1, 0)
        # call_function  view               aten.view.default      (transpose, [-1, 2])
        # call_function  view_1             aten.view.default      (view, [2])
        # call_function  _to_copy           aten._to_copy.default  (view_1,)
        gm.graph.eliminate_dead_code()

        return gm


def _build_dummy_add_graph(
    dt: DTensor, node_to_obj: Dict[fx.Node, Any]
) -> Tuple[fx.GraphModule, Any]:
    """Create a graph for a dummy add function from a partial DTensor.

    This dummy add is used to trigger an all_reduce on a _Partial DTensor
    during the DTensor expansion of the traced graph.
    Also returns the actual DTensor after resharding.
    """

    def dummy_add(grad: torch.Tensor, zero: torch.Tensor) -> torch.Tensor:
        return grad + zero

    grad: torch.Tensor = dt._local_tensor
    zero: torch.Tensor = torch.zeros_like(dt._local_tensor)

    traced_add = make_fx(dummy_add)(grad, zero)

    placeholders = [n for n in traced_add.graph.nodes if n.op == OP.PLACEHOLDER]
    call_functions = [n for n in traced_add.graph.nodes if n.op == OP.CALL_FUNCTION]
    assert len(placeholders) == 2
    assert len(call_functions) == 1
    node_to_obj[placeholders[0]] = dt
    node_to_obj[placeholders[1]] = DTensor.from_local(
        zero, dt.device_mesh, [Replicate()], run_check=False
    )

    traced_dispatch = _get_dtensor_dispatch_graph(
        call_functions[0], node_to_obj, force_make_fx=True
    )
    assert traced_dispatch is not None

    # TODO(anj): This depends on the call-function node -> actual DTensor output
    # mapping that we want to avoid for SPMD expansion.
    return traced_dispatch, node_to_obj[call_functions[0]]


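# Illustrative sketch (comments only): adding a Replicate() zero to a _Partial
# gradient forces DTensor to reshard the partial operand, so the traced
# dispatch graph for the dummy add contains, roughly,
#
#     placeholder(grad_shard), placeholder(zero)
#     all_reduce(grad_shard) -> wait_tensor / wait_comm
#     add(wait_tensor, zero)
#
# _convert_output below then drops the add node and keeps the wait node, which
# is exactly the collective needed to turn the partial output into a
# replicated one.
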
def _convert_output(
    gm: fx.GraphModule,
    node: fx.Node,
    node_to_obj: Dict[fx.Node, Any],
) -> fx.Node:
    new_args = []
    has_partial = False
    for argument in node.args[0]:  # type: ignore[union-attr]
        if not isinstance(argument, fx.Node):
            new_args.append(argument)
            continue

        obj = node_to_obj[argument]

        if not _is_partial_dtensor(obj):
            new_args.append(argument)
            continue

        has_partial = True

        # We know obj is a DTensor from the is-partial DTensor check above.
        dt = cast(DTensor, obj)

        traced_dispatch, result_obj = _build_dummy_add_graph(dt, node_to_obj)

        wait = [
            n
            for n in traced_dispatch.graph.nodes
            if n.name == "wait_comm" or n.name == "wait_tensor"
        ]
        add = [n for n in traced_dispatch.graph.nodes if n.name == "add"]
        assert len(wait) == 1 and len(add) == 1

        # Remove the add node and replace it with the wait node.
        add[0].replace_all_uses_with(wait[0])
        traced_dispatch.graph.eliminate_dead_code()
        # Also update the actual DTensor corresponding to the node.
        # TODO(anj): We require a mapping of the final DTensor output to the
        # wait comm node.
        node_to_obj[wait[0]] = result_obj

        value_remap: Dict[fx.Node, fx.Node] = {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                # Map the placeholder in the dispatch graph back to the
                # original output argument node.
                value_remap[dtn] = argument
            elif dtn.op == OP.OUTPUT:
                assert (
                    len(dtn.args) == 1 and len(dtn.args[0]) == 1
                ), f"Expecting a single output, but got {dtn.args} {len(dtn.args)}"
                new_args.append(value_remap[dtn.args[0][0]])
                # The concrete DTensor value of the output was added when
                # creating the inner graph (in _build_dummy_add_graph). Just add
                # it to the final output node so that we can report the final
                # output specs correctly.
                # TODO(anj): We are depending on the concrete DTensor output of
                # the dummy add.
                node_to_obj[value_remap[dtn.args[0][0]]] = node_to_obj[dtn.args[0][0]]

            else:
                if dtn.op == OP.GET_ATTR:
                    setattr(
                        gm,
                        dtn.target,
                        getattr(traced_dispatch, dtn.target),
                    )
                with gm.graph.inserting_before(node):
                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
    if has_partial:
        gm.graph.erase_node(node)
        return gm.graph.output(new_args)
    else:
        return node


def _rebuild_graph(
    gm: fx.GraphModule,
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule],
) -> None:
    # Replace nodes in the local traced graph with DTensor's dispatch graph.
    for node in gm.graph.nodes:
        if node not in node_replacements:
            continue

        traced_dispatch = node_replacements[node]
        # Map the dispatch graph's input placeholder nodes to the ones in the
        # local traced graph. This uses index-based accessing, which is brittle
        # and intended only for testing purposes.
        flatten_args = pytree.arg_tree_leaves(*node.args)
        i, value_remap = 0, {}
        for dtn in traced_dispatch.graph.nodes:
            if dtn.op == OP.PLACEHOLDER:
                value_remap[dtn] = flatten_args[i]
                i += 1

        # Insert the dispatch graph into the traced local graph.
        with gm.graph.inserting_before(node):
            for dtn in traced_dispatch.graph.nodes:
                if dtn.op == OP.PLACEHOLDER:
                    # Do nothing; placeholders have already been prepared in
                    # value_remap above.
                    pass
                elif dtn.op == OP.OUTPUT:
                    assert (
                        len(dtn.args) == 1
                    ), f"Expecting a single output, but got {dtn.args} {len(dtn.args[0])}"
                    outputs = dtn.args[0]
                    # We currently support two very specific types of output:
                    # 1. a single output
                    # 2. multiple outputs resulting from getitem of all elements
                    #    of a tuple
                    if len(outputs) == 1:
                        # For a single output, replace the node with that node.
                        output = outputs[0]
                    else:
                        # For multiple outputs, check that these outputs
                        # correspond to all elements of a tuple. In that case,
                        # replace uses of the output directly with the original
                        # tuple.
                        source = None
                        for i, out in enumerate(outputs):
                            # We allow None outputs for certain items in the tuple.
                            if out is None:
                                continue
                            assert out.op == "call_function"
                            assert out.target.__module__ == "_operator"
                            assert out.target.__name__ == "getitem"
                            assert source is None or source == out.args[0]
                            source = out.args[0]
                            assert out.args[1] == i
                        assert source is not None
                        output = source

                    new_node = value_remap[output]
                    node.replace_all_uses_with(new_node)
                else:
                    value_remap[dtn] = gm.graph.node_copy(dtn, lambda n: value_remap[n])
                    if all(
                        isinstance(n.target, torch._ops.OpOverload)
                        and n.target._schema.name.startswith(
                            ("aten::_foreach", "aten::_fused_adam")
                        )
                        for n in [dtn, node]
                    ):
                        # FIXME(@mrshenli): This is a temporary solution to
                        # enable foreach ops. The problem is that foreach ops
                        # return List[Tensor], but make_fx will flatten that
                        # before passing those tensors to the output node, which
                        # introduces additional getitem nodes. These redundant
                        # getitem nodes break graph correctness, as we cannot do
                        # getitem(getitem(foreach_out, 0), 0). This temporary
                        # solution skips getitem nodes in DTensor-expanded
                        # subgraphs.
                        node.replace_all_uses_with(value_remap[dtn])
                        break
        # Explicitly erase the node instead of relying on DCE, as DCE does not
        # remove inplace copy_ correctly.
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.recompile()


def _get_last_consumer_to_nodes(
    graph: fx.Graph,
) -> Dict[fx.Node, List[fx.Node]]:
    # Run through the nodes in reverse order and record the first instance of a
    # use of a given node. This represents the *last* use of the node in the
    # execution order of the program, which we will use to free unused values.
    node_to_last_consumer: Dict[fx.Node, fx.Node] = {}
    last_consumer_to_nodes: Dict[fx.Node, List[fx.Node]] = {}

    def _register_final_consumer(arg_node: fx.Node, consumer: fx.Node) -> None:
        if arg_node not in node_to_last_consumer:
            node_to_last_consumer[arg_node] = consumer
            last_consumer_to_nodes.setdefault(consumer, []).append(arg_node)

    for node in reversed(graph.nodes):
        fx.node.map_arg(
            node.args, lambda arg_node: _register_final_consumer(arg_node, node)
        )
        fx.node.map_arg(
            node.kwargs,
            lambda kwarg_node: _register_final_consumer(kwarg_node, node),
        )

    return last_consumer_to_nodes


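# Illustrative sketch (comments only; the node names are hypothetical): for a
# graph
#
#     x = placeholder();  a = op1(x);  b = op2(a, x);  output(b)
#
# walking the nodes in reverse registers the output node as the last consumer
# of b, and the node producing b (the op2 call) as the last consumer of both a
# and x, yielding {output: [b], b: [a, x]}. op1's use of x is ignored because x
# was already registered by a later consumer.
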
def _convert_to_distributed(
    gm: fx.GraphModule,
    inps: List[torch.Tensor],
    schemas: List[Schema],
    default_mesh: Optional[DeviceMesh] = None,
    _allow_partial: bool = False,
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
    """Transform a graph module to a distributed graph module.

    Returns:
        - the transformed graph module
        - a map from output name to its Schema (mesh and placements)

    """
    global logger
    logger = get_logger("spmd_exp")
    operators = {getattr(operator, name) for name in operator.__all__}
    node_to_obj: Dict[fx.Node, Any] = {}
    # Map a local op node in traced_f to its corresponding subgraph of
    # DTensor ops.
    node_replacements: Dict[torch.fx.Node, torch.fx.GraphModule] = {}

    last_consumer_to_nodes = _get_last_consumer_to_nodes(gm.graph)

    output_schemas: Dict[str, Schema] = {}
    for i, node in enumerate(gm.graph.nodes):
        assert logger is not None
        logger.info("node%s: op=%s target=%s", i, node.op, node.target)
        if node.op == OP.PLACEHOLDER:
            assert i < len(
                inps
            ), f"got more placeholder nodes ({i + 1}) than inputs ({len(inps)})"

            # Our example inputs are local shards. Create DTensors from them.
            node_to_obj[node] = DTensor.from_local(
                inps[i].clone(),  # use clone to avoid modifications from inplace ops
                schemas[i].mesh,
                schemas[i].placements,
                # prevent running this collective in the backward pass
                run_check=False,
            )
        elif isinstance(node.target, torch._ops.OpOverloadPacket):
            dtensor = cast(DTensor, node_to_obj[node.args[0]])
            node_to_obj[node] = DSymInt.from_node(node, dtensor)
        elif isinstance(node.target, torch._ops.OpOverload):
            replacement = _get_dtensor_dispatch_graph(
                node, node_to_obj, default_mesh=default_mesh
            )
            if replacement is not None:
                node_replacements[node] = replacement
        elif node.op == OP.OUTPUT:
            if not _allow_partial:
                # Returns an expanded dummy add node that ensures that the
                # partial output tensor has been converted to a replicated
                # tensor.
                node = _convert_output(gm, node, node_to_obj)

            # Save output sharding for the inputs to the backward pass.
            # TODO(anj): Pipe the output schema for the BW pass
            # instead of requiring the full output DTensor to be
            # materialized.
            for inp_arg in node.args[0]:
                if isinstance(inp_arg, fx.Node):
                    obj = node_to_obj[inp_arg]
                    if isinstance(obj, DTensor):
                        output_schemas[inp_arg.name] = Schema(
                            obj.device_mesh, obj.placements  # type: ignore[arg-type]
                        )
        elif node.op == OP.CALL_FUNCTION:
            args = tree_map(partial(_remap_arg, node_to_obj), node.args)
            kwargs = tree_map(partial(_remap_arg, node_to_obj), node.kwargs)

            dsymints = list(
                filter(lambda a: isinstance(a, DSymInt), args + tuple(kwargs.values()))
            )

            if node.target in operators and len(dsymints) > 0:
                assert all(
                    dsymints[0].mesh == d.mesh for d in dsymints
                ), "all DSymInts must have the same mesh."

                local_args = tree_map_only(DSymInt, lambda a: a.local_value, args)
                local_kwargs = tree_map_only(DSymInt, lambda a: a.local_value, kwargs)

                global_args = tree_map_only(DSymInt, lambda a: a.global_value, args)
                global_kwargs = tree_map_only(DSymInt, lambda a: a.global_value, kwargs)

                node.args = local_args
                node.kwargs = local_kwargs

                node_to_obj[node] = DSymInt(
                    local_value=node.target(*local_args, **local_kwargs),
                    global_value=node.target(*global_args, **global_kwargs),
                    mesh=dsymints[0].mesh,
                )
            else:
                assert len(dsymints) == 0, (
                    "SPMD expansion does not support SymInt in non-operator "
                    f"nodes, got {node.target}."
                )
                node_to_obj[node] = node.target(*args, **kwargs)
        else:
            raise ValueError(f"Unrecognized node.op type {node.op}")

        if node in last_consumer_to_nodes:
            # Save memory by deleting objects that won't be used anymore.
            for arg_node in last_consumer_to_nodes[node]:
                del node_to_obj[arg_node]

    _rebuild_graph(gm, node_replacements)

    return gm, output_schemas
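

# Illustrative usage sketch (comments only; ``train_step``, ``local_inputs``,
# ``world_size``, the mesh, and the shard layout are assumptions, not part of
# this module):
#
#     def train_step(x, w):
#         return (x @ w).sum()
#
#     gm = make_fx(train_step)(torch.randn(2, 4), torch.randn(4, 4))
#     mesh = DeviceMesh("cuda", torch.arange(world_size))
#     schemas = [Schema(mesh, [Shard(0)]), Schema(mesh, [Replicate()])]
#     # ``local_inputs`` must be this rank's local shards matching ``schemas``.
#     dist_gm, output_schemas = _convert_to_distributed(
#         gm, local_inputs, schemas, default_mesh=mesh
#     )
#
# Each placeholder is wrapped into a DTensor according to its Schema, every
# aten op is expanded into its DTensor dispatch subgraph, and partial outputs
# are all-reduced unless ``_allow_partial`` is set.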