| import dataclasses |
| import math |
| import operator |
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Type |
| |
| import torch |
| from torch._subclasses.fake_tensor import FakeTensor |
| |
| from torch.export import ExportedProgram |
| from torch.utils._pytree import ( |
| _register_pytree_node, |
| Context, |
| FlattenFunc, |
| FromDumpableContextFn, |
| KeyPath, |
| keystr, |
| MappingKey, |
| SequenceKey, |
| ToDumpableContextFn, |
| UnflattenFunc, |
| ) |
| |
| |
| def _check_input_constraints_for_graph( |
| input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints |
| ): |
| def get_keystr(key_path: KeyPath) -> str: |
| """For a given index into the flat_args, return a human readable string |
| describing how to access it, e.g. "*args["foo"][0].bar" |
| """ |
        # Prefix positional args with "*args"; for kwargs, use the keyword
        # name itself as the prefix. Ultimately we ought to serialize the
        # original arg names for the best error message here.
| args_kwargs_key_path = key_path[0] |
| assert isinstance(args_kwargs_key_path, SequenceKey) |
| if args_kwargs_key_path.idx == 0: |
| return f"*args{keystr(key_path[1:])}" |
| else: |
| kwarg_key = key_path[1] |
| assert isinstance(kwarg_key, MappingKey) |
            name = str(kwarg_key)[1:-1]  # strip the enclosing []
| return f"{name}{keystr(key_path[2:])}" |
| |
    import sympy  # local import: sympy is only needed for this check
| |
| from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( |
| _convert_range_to_int, |
| ) |
| from torch.utils._sympy.solve import try_solve |
| |
| if len(flat_args_with_path) != len(input_placeholders): |
| raise RuntimeError( |
| "Unexpected number of inputs " |
| f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})" |
| ) |
| # NOTE: export already guarantees that the same symbol is used in metadata |
| # for all InputDims related by equality constraints, so we can just unify |
| # symbols with given input dimension values to check equality constraints. |
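    # For example, if two inputs were exported with their first dimensions
    # constrained to be equal, both placeholders carry the same symbol (say
    # s0); binding s0 to the first concrete dim value seen makes the equality
    # check below catch any mismatch in the other input.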
| unification_map: "Dict[sympy.Symbol, Any]" = {} |
| for (key_path, arg), node in zip(flat_args_with_path, input_placeholders): |
| node_val = node.meta.get("val") |
| if isinstance(node_val, FakeTensor): |
| if not isinstance(arg, torch.Tensor): |
| raise RuntimeError( |
| f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}", |
| ) |
| |
| if len(node_val.shape) != len(arg.shape): |
| raise RuntimeError( |
| f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape " |
| f"(expected {node_val.shape}, got {arg.shape})" |
| ) |
| |
| for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)): |
| # TODO(avik): Assert the following property in the IR verifier: |
| # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr |
| if ( |
| isinstance(node_dim, torch.SymInt) |
| and len(node_dim.node.expr.free_symbols) == 1 |
| ): |
| symbol = next(iter(node_dim.node.expr.free_symbols)) |
| if symbol in unification_map: |
| existing_dim = node_dim.node.expr.subs(unification_map) |
| if arg_dim != existing_dim: |
| raise RuntimeError( |
| f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " |
| f"{existing_dim}, but got {arg_dim}", |
| ) |
| else: |
| if isinstance(arg_dim, torch.SymInt) and isinstance( |
| node_dim.node.expr, sympy.Symbol |
| ): |
| # this can happen when, say, arg is a fake tensor |
| unification_map[symbol] = arg_dim |
| else: |
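                            # e.g. node_dim.node.expr = s0 + 1 with concrete
                            # arg_dim = 5 solves to s0 = 4; try_solve returns a
                            # (relation, rhs) pair, so solution[1] is the value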
| solution = try_solve( |
| sympy.Eq(node_dim.node.expr, arg_dim), symbol |
| ) |
| if solution is None: |
| raise RuntimeError( # noqa: TRY200 |
| f"Expected input {node.name}.shape[{j}] = {arg_dim} to be " |
| f"of the form {node_dim.node.expr}, where {symbol} is an integer" |
| ) |
| else: |
| unification_map[symbol] = int(solution[1]) |
| |
| if node_dim.node.expr in range_constraints: |
| min_val, max_val = _convert_range_to_int( |
| range_constraints[node_dim.node.expr] |
| ) |
                        # NOTE: we allow dimensions to be 0/1 at runtime, so
                        # only enforce the lower bound when it is stricter
                        # than the default minimum of 2 for symbolic dims
                        if min_val > 2:
| if arg_dim < min_val: |
| raise RuntimeError( |
| f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= " |
| f"{min_val}, but got {arg_dim}", |
| ) |
| if max_val < math.inf: |
| if arg_dim > max_val: |
| raise RuntimeError( |
| f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= " |
| f"{max_val}, but got {arg_dim}", |
| ) |
| else: |
| if arg_dim != node_dim: |
| raise RuntimeError( |
| f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to " |
| f"{node_dim}, but got {arg_dim}", |
| ) |
| elif isinstance(node_val, (int, float, str)): |
| if type(arg) != type(node_val) or arg != node_val: |
| raise RuntimeError( |
| f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", |
| ) |
| |
| |
| def register_dataclass_as_pytree_node( |
| cls: Type[Any], |
| flatten_fn: Optional[FlattenFunc] = None, |
| unflatten_fn: Optional[UnflattenFunc] = None, |
| *, |
| serialized_type_name: Optional[str] = None, |
| to_dumpable_context: Optional[ToDumpableContextFn] = None, |
| from_dumpable_context: Optional[FromDumpableContextFn] = None, |
| return_none_fields: bool = False, |
| ) -> None: |
| assert dataclasses.is_dataclass( |
| cls |
| ), f"Only dataclasses can be registered with this function: {cls}" |
| |
| def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]: |
| flattened = [] |
| flat_names = [] |
| none_names = [] |
| for f in dataclasses.fields(obj): |
| name, val = f.name, getattr(obj, f.name) |
| if val is not None or return_none_fields: |
| flattened.append(val) |
| flat_names.append(name) |
| else: |
| none_names.append(name) |
| return flattened, [flat_names, none_names] |
| |
| def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: |
| flat_names, none_names = context |
| return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) |
| |
| flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn |
| unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn |
| |
| if (to_dumpable_context is None) ^ (from_dumpable_context is None): |
| raise ValueError( |
| f"Both to_dumpable_context and from_dumpable_context for {cls} must " |
| "be None or registered." |
| ) |
| |
| _register_pytree_node( |
| cls, |
| flatten_fn, |
| unflatten_fn, |
| serialized_type_name=serialized_type_name, |
| to_dumpable_context=to_dumpable_context, |
| from_dumpable_context=from_dumpable_context, |
| ) |
| |
| |
| def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool: |
| """ |
| Checks if the given node is a parameter within the exported program |
| """ |
| |
| return node.name in program.graph_signature.inputs_to_parameters |
| |
| |
| def get_param( |
| program: ExportedProgram, |
| node: torch.fx.Node, |
| ) -> Optional[torch.nn.Parameter]: |
| """ |
| Returns the parameter associated with the given node in the exported program. |
| Returns None if the node is not a parameter within the exported program |
| """ |
| |
| if is_param(program, node): |
| parameter_name = program.graph_signature.inputs_to_parameters[node.name] |
| return program.state_dict[parameter_name] |
| |
| return None |
| |
| |
| def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool: |
| """ |
| Checks if the given node is a buffer within the exported program |
| """ |
| |
| return node.name in program.graph_signature.inputs_to_buffers |
| |
| |
| def get_buffer( |
| program: ExportedProgram, |
| node: torch.fx.Node, |
| ) -> Optional[torch.Tensor]: |
| """ |
| Returns the buffer associated with the given node in the exported program. |
| Returns None if the node is not a buffer within the exported program |
| """ |
| |
| if is_buffer(program, node): |
| buffer_name = program.graph_signature.inputs_to_buffers[node.name] |
| if buffer_name in program.graph_signature.non_persistent_buffers: |
| return program.constants[buffer_name] |
| else: |
| return program.state_dict[buffer_name] |
| |
| return None |
| |
| |
| def is_lifted_tensor_constant( |
| program: ExportedProgram, |
| node: torch.fx.Node, |
| ) -> bool: |
| """ |
| Checks if the given node is a lifted tensor constant within the exported program |
| """ |
| |
| return node.name in program.graph_signature.inputs_to_lifted_tensor_constants |
| |
| |
| def get_lifted_tensor_constant( |
| program: ExportedProgram, |
| node: torch.fx.Node, |
| ) -> Optional[torch.Tensor]: |
| """ |
| Returns the lifted tensor constant associated with the given node in the exported program. |
| Returns None if the node is not a lifted tensor constant within the exported program |
| """ |
| |
| if is_lifted_tensor_constant(program, node): |
| lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[ |
| node.name |
| ] |
| return program.constants[lifted_tensor_name] |
| |
| return None |
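
# Example (illustrative sketch): classify the lifted placeholder inputs of an
# exported program ``ep`` using the accessors above.
#
#     for node in ep.graph.nodes:
#         if node.op != "placeholder":
#             continue
#         if is_param(ep, node):
#             print(node.name, "-> parameter", get_param(ep, node).shape)
#         elif is_buffer(ep, node):
#             print(node.name, "-> buffer", get_buffer(ep, node).shape)
#         elif is_lifted_tensor_constant(ep, node):
#             print(node.name, "-> constant", get_lifted_tensor_constant(ep, node).shape)
#         else:
#             print(node.name, "-> user input")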
| |
| |
| def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule: |
| """ |
| Splits the graph module into multiple submodules based on the node_call_back. |
    The node_call_back should return True if the node is a delimiter; each
    delimiter becomes the first node of the next submodule.
| """ |
| from torch.fx.passes.split_module import split_module |
| |
    split_map: Dict[torch.fx.Node, int] = {}
| split_id = 0 |
| for node in gm.graph.nodes: |
| if node_call_back(node): |
| split_id += 1 |
| split_map[node] = split_id |
| |
| new_gm = split_module( |
| gm, |
| gm, |
| lambda node: split_map[node], |
| keep_original_order=True, |
| keep_original_node_name=True, |
| ) |
| # Keep the codegen from original graph module to preserve e.g. pytree info. |
| new_gm.graph._codegen = gm.graph._codegen |
| new_gm.recompile() |
| return new_gm |
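
# Example (illustrative sketch): split ``gm`` so that every aten.relu call
# starts a new submodule. The callback marks the delimiter nodes.
#
#     def is_relu(node: torch.fx.Node) -> bool:
#         return (
#             node.op == "call_function"
#             and node.target == torch.ops.aten.relu.default
#         )
#
#     split_gm = sequential_split(gm, is_relu)
#     # split_gm now invokes submodules (e.g. submod_0, submod_1, ...), each
#     # beginning at a relu delimiter.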
| |
| |
| def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: |
| """Returns the nodes that match the node_call_back as a list.""" |
| return [node for node in nodes if node_call_back(node)] |
| |
| |
| def nodes_first( |
| nodes: List[torch.fx.Node], node_call_back=None |
| ) -> Optional[torch.fx.Node]: |
| """ |
| Returns the first node that matches the node_call_back. If no node matches, returns None. |
| When node_call_back is None, returns the first node in the node list. |
| """ |
| ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True) |
| if len(ret) > 0: |
| return ret[0] |
| return None |
| |
| |
| def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int: |
| """Returns the number of nodes that match the node_call_back.""" |
| return len(nodes_filter(nodes, node_call_back)) |
| |
| |
| def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]: |
| """ |
    Sequentially visits each node in the list and invokes node_call_back on it.
    Returns the same list of nodes after node_call_back has run on every element.
| """ |
| for node in nodes: |
| node_call_back(node) |
| return nodes |
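
# Example (illustrative sketch): small queries over a graph ``gm`` using the
# helpers above.
#
#     all_nodes = list(gm.graph.nodes)
#     first_ph = nodes_first(all_nodes, lambda n: n.op == "placeholder")
#     n_calls = nodes_count(all_nodes, lambda n: n.op == "call_function")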
| |
| |
| def node_replace_( |
| old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False |
| ) -> None: |
| """ |
| Replace all uses of old_node with new_node. |
| """ |
| old_node.replace_all_uses_with(new_node) |
| if delete_old: |
| old_node.users.clear() |
| old_node.graph.erase_node(old_node) |
| |
| |
def node_inline_(call_mod_node: torch.fx.Node) -> torch.fx.GraphModule:
    """
    Inline the submodule of the given call_module node into its parent module.
    Note: we only support the case where the submodule takes tensor inputs.
    Returns the mutated parent graph module.
    """
| assert call_mod_node.op == "call_module" |
| gm = call_mod_node.graph.owning_module |
| |
| assert isinstance(call_mod_node.target, str) |
| sub_gm = getattr(gm, call_mod_node.target) |
| |
| phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder") |
| body = ( |
| node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output") |
| ) |
| output = [node for node in sub_gm.graph.nodes if node.op == "output"] |
| |
| for ph, arg in zip(phs, call_mod_node.args): |
| assert isinstance(arg, torch.fx.Node) |
| node_replace_(ph, arg, delete_old=True) |
| |
| with gm.graph.inserting_before(call_mod_node): |
| for node in body: |
| new_node = gm.graph.node_copy(node) |
| node_replace_(node, new_node, delete_old=True) |
| |
| if len(output) > 0: |
| assert len(output) == 1 and len(output[0].args) == 1 |
| new_output = output[0].args[0] |
| |
| if isinstance(new_output, torch.fx.Node): |
| node_replace_(call_mod_node, new_output, delete_old=True) |
| elif isinstance(new_output, (list, tuple)): |
| # Inline the get_item calls for the output node. |
| get_item_users = nodes_filter( |
| list(call_mod_node.users.keys()), |
| lambda node: node.op == "call_function" |
| and node.target == operator.getitem, |
| ) |
| # get_item_node.args[1] is the idx referring to new_output[idx] |
| nodes_map( |
| get_item_users, |
| lambda get_item_node: node_replace_( |
| get_item_node, |
| new_output[get_item_node.args[1]], |
| delete_old=True, |
| ), |
| ) |
| call_mod_node.graph.erase_node(call_mod_node) |
| else: |
| raise NotImplementedError( |
| f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes." |
| ) |
| else: |
| call_mod_node.graph.erase_node(call_mod_node) |
| |
| gm.delete_all_unused_submodules() |
| gm.recompile() |
| return gm |
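

# Example (illustrative sketch): undo a ``sequential_split`` by inlining each
# call_module node back into its parent graph.
#
#     for node in list(split_gm.graph.nodes):
#         if node.op == "call_module":
#             node_inline_(node)
#     # split_gm is flat again; unused submodules have been deleted.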