| from typing import Any, Dict, List, NamedTuple, Optional, Tuple |
| |
| import torch |
| from torch.fx._compatibility import compatibility |
| from torch.fx.graph import Graph |
| from torch.fx.graph_module import GraphModule |
| from torch.fx.node import ( |
| _get_qualified_name, |
| Argument, |
| map_aggregate, |
| map_arg, |
| Node, |
| Target, |
| ) |
| from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes |
| from torch.fx.passes.shape_prop import ShapeProp |
| |
| |
| @compatibility(is_backward_compatible=False) |
| def replace_target_nodes_with( |
| fx_module: GraphModule, |
| old_op: str, |
| old_target: Target, |
| new_op: str, |
| new_target: Target, |
| ): |
|     """Modifies all nodes in fx_module.graph.nodes that match the specified op code and |
|     target, updating them to use the new op code and target instead.""" |
| new_graph = Graph() |
| val_map: Dict[Node, Node] = {} |
| for node in fx_module.graph.nodes: |
| if node.op == old_op and node.target == old_target: |
| args = map_arg(node.args, lambda n: val_map[n]) |
| kwargs = map_arg(node.kwargs, lambda n: val_map[n]) |
| assert isinstance(args, tuple) |
| assert isinstance(kwargs, dict) |
| val_map[node] = new_graph.create_node( |
| new_op, new_target, args, kwargs, node.name |
| ) |
| else: |
| val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) |
| fx_module.graph = new_graph |
| |
| |
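| # Usage sketch (illustrative): assumes `model` calls torch.relu somewhere in its |
| # forward(). After symbolic tracing, every ("call_function", torch.relu) node is |
| # rewritten to target torch.nn.functional.relu6 instead. |
| def _example_replace_relu_with_relu6(model: torch.nn.Module) -> GraphModule: |
|     gm = torch.fx.symbolic_trace(model) |
|     # Rebuilds gm.graph with every matching node swapped to the new op/target. |
|     replace_target_nodes_with( |
|         gm, "call_function", torch.relu, "call_function", torch.nn.functional.relu6 |
|     ) |
|     # Regenerate forward() from the rewritten graph. |
|     gm.recompile() |
|     return gm |
| |
| |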
| @compatibility(is_backward_compatible=False) |
| class size_bytes(NamedTuple): |
| output_size: int |
| total_size: int |
| |
| |
| @compatibility(is_backward_compatible=False) |
| def get_size_of_all_nodes( |
| fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None |
| ) -> None: |
|     """Given an FX GraphModule, annotate each node with its total size (weights + bias + output) |
|     and its output size (output only). For a non-module node, the total size equals the |
|     output size. The result is stored on each node as node.size_bytes; nothing is returned.""" |
| if args is not None: |
|         # Populate shape and dtype metadata for each node (node.meta["tensor_meta"]) |
| ShapeProp(fx_module).propagate(*args) |
|     # Annotate each node (up to, but not including, the output node) with its size. |
| for node in fx_module.graph.nodes: |
| if node.op == "output": |
| break |
| node.size_bytes = get_size_of_node(fx_module, node) |
| return |
| |
| |
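| # Usage sketch (illustrative): assumes `model` is symbolically traceable and takes a |
| # single tensor input. Passing args makes get_size_of_all_nodes run ShapeProp first, |
| # after which every non-output node carries a size_bytes annotation. |
| def _example_annotate_node_sizes( |
|     model: torch.nn.Module, sample_input: torch.Tensor |
| ) -> Dict[str, size_bytes]: |
|     gm = torch.fx.symbolic_trace(model) |
|     get_size_of_all_nodes(gm, [sample_input]) |
|     # Collect the per-node annotations into a name -> size_bytes mapping. |
|     return { |
|         node.name: node.size_bytes |
|         for node in gm.graph.nodes |
|         if hasattr(node, "size_bytes") |
|     } |
| |
| |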
| @compatibility(is_backward_compatible=False) |
| def get_tensor_meta(node: Node) -> Any: |
| tensor_meta = node.meta.get("tensor_meta") |
| |
| if not tensor_meta: |
| raise RuntimeError( |
| f"Node {node} has no tensor metadata associated with it! " |
| f"Check that shape propagation has run." |
| ) |
| |
| return tensor_meta |
| |
| |
| @compatibility(is_backward_compatible=False) |
| def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: |
|     """Given a node whose tensor metadata has been populated (e.g. by ShapeProp), return its |
|     total size and its output size in bytes. |
|     total_size = weights + bias + output_size |
|     """ |
| # Total num of elements |
| total_num_of_elems = 0 |
|     # For a module, consider all of its parameters |
| if node.op == "call_module": |
| submodule_dict = dict(fx_module.named_modules()) |
| submodule = submodule_dict[node.target] |
| parameters = submodule.named_parameters() |
|         # named_parameters() yields (name, parameter) pairs |
| for name, p in parameters: |
| total_num_of_elems += p.numel() |
|     # Don't forget the output size. |
|     # tensor_meta.shape is the shape of this node's output. |
| tensor_meta = get_tensor_meta(node) |
| output_elem = tensor_meta.shape.numel() |
| total_num_of_elems += output_elem |
|     # Assume for now that if it's quantized, it's qint8 or quint8 |
| if tensor_meta.is_quantized: |
| size_per_elem_bytes = torch._empty_affine_quantized( |
| [], dtype=tensor_meta.dtype |
| ).element_size() |
| else: |
| size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size() |
| total_size = size_per_elem_bytes * total_num_of_elems |
| output_size = size_per_elem_bytes * output_elem |
| return size_bytes(output_size, total_size) |
| |
| |
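| # Usage sketch (illustrative): assumes shape propagation has already populated tensor |
| # metadata on `gm` (e.g. via get_size_of_all_nodes or ShapeProp). For a call_module |
| # node the total counts the submodule's parameters plus its output; otherwise only |
| # the output is counted. |
| def _example_sizes_of_call_module_nodes(gm: GraphModule) -> Dict[str, size_bytes]: |
|     return { |
|         node.name: get_size_of_node(gm, node) |
|         for node in gm.graph.nodes |
|         if node.op == "call_module" |
|     } |
| |
| |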
| @compatibility(is_backward_compatible=False) |
| def serialize_shape(shape: torch.Size) -> str: |
| return str(list(shape)) |
| |
| |
| @compatibility(is_backward_compatible=False) |
| def serialize_stride(stride: Tuple[int, ...]) -> str: |
| return str(list(stride)) |
| |
| |
| @compatibility(is_backward_compatible=False) |
| def serialize_tensor_quantization( |
| tensor: torch.Tensor, weights: Dict, pcq_prefix: str |
| ) -> Tuple[Dict, Dict]: |
| """ |
| Args: |
| tensor: The tensor from which we try to extract quantization information. |
| weights: A dict that contains mapping from name to a tensor value. |
|         pcq_prefix: A string used later as the prefix for per-channel quantization information. |
|             This is usually the key under which the info of `tensor` is stored. |
| |
| Returns: |
| scheme: Dict that stores the quantization information of `tensor`. |
|         per_channel_dict: Dict that stores the information of per_channel_scales and |
|             per_channel_zero_points of `tensor`. This will be empty if `tensor` is not |
|             per-channel quantized. |
| |
|     If `tensor` is per-tensor quantized: |
| scheme: { |
| "qscheme": str(tensor.qscheme()), |
| "q_scale": tensor.q_scale(), |
| "q_zero_point": tensor.q_zero_point(), |
| } |
| |
|     If `tensor` is per-channel quantized: |
| scheme: { |
| "qscheme": str(tensor.qscheme()), |
| "q_per_channel_scales": {pcq_prefix}_per_channel_scales, |
| "q_per_channel_zero_points": {pcq_prefix}_per_channel_zero_points, |
| "q_per_channel_axis": tensor.q_per_channel_axis() |
| } |
| per_channel_dict: { |
| {pcq_prefix}_per_channel_scales: { |
| "dtype": dtype, |
| "shape": shape, |
| "is_quantized": is_quantized, |
| "stride": stride, |
| } |
| {pcq_prefix}_per_channel_zero_points: { |
| "dtype": dtype, |
| "shape": shape, |
| "is_quantized": is_quantized, |
| "stride": stride, |
| } |
| } |
|     `weights` is updated with { |
| {pcq_prefix}_per_channel_scales: tensor.q_per_channel_scales().float() |
| {pcq_prefix}_per_channel_zero_points: tensor.q_per_channel_zero_points().int() |
| } |
| """ |
| scheme: Dict[str, Any] = {} |
| per_channel_dict: Dict[str, Dict] = {} |
| |
| if not tensor.is_quantized: |
| return scheme, per_channel_dict |
| |
| scheme["qscheme"] = str(tensor.qscheme()) |
| |
|     # For a per-tensor scheme, we store the scale and zero_point. |
| if tensor.qscheme() in {torch.per_tensor_affine, torch.per_tensor_symmetric}: |
| scheme["q_scale"] = tensor.q_scale() |
| scheme["q_zero_point"] = tensor.q_zero_point() |
| |
|     # For a per-channel scheme, per_channel_scales and per_channel_zero_points are tensors. |
|     # We store their tensor values in `weights` and record their names in `scheme`. |
| if tensor.qscheme() in { |
| torch.per_channel_affine, |
| torch.per_channel_affine_float_qparams, |
| torch.per_channel_symmetric, |
| }: |
| # per_channel_scales is float64. Here we save it as float32. |
| weights[ |
| f"{pcq_prefix}_per_channel_scales" |
| ] = tensor.q_per_channel_scales().float() |
| scheme["q_per_channel_scales"] = f"{pcq_prefix}_per_channel_scales" |
| per_channel_dict.update( |
| serialize_weight( |
| weights[f"{pcq_prefix}_per_channel_scales"], |
| weights, |
| f"{pcq_prefix}_per_channel_scales", |
| ) |
| ) |
| |
|         # per_channel_zero_points is int64. Here we save it as int32. |
| weights[ |
| f"{pcq_prefix}_per_channel_zero_points" |
| ] = tensor.q_per_channel_zero_points().int() |
| scheme["q_per_channel_zero_points"] = f"{pcq_prefix}_per_channel_zero_points" |
| per_channel_dict.update( |
| serialize_weight( |
| weights[f"{pcq_prefix}_per_channel_zero_points"], |
| weights, |
| f"{pcq_prefix}_per_channel_zero_points", |
| ) |
| ) |
| |
| scheme["q_per_channel_axis"] = tensor.q_per_channel_axis() |
| return scheme, per_channel_dict |
| |
| |
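| # Usage sketch (illustrative): per-channel quantize a small float tensor and inspect |
| # the serialized scheme. The per-channel scales/zero_points tensors are stashed into |
| # `weights` under keys derived from the "example_weight" prefix. |
| def _example_serialize_per_channel_qtensor() -> Tuple[Dict, Dict]: |
|     float_weight = torch.randn(4, 3) |
|     q_weight = torch.quantize_per_channel( |
|         float_weight, |
|         torch.tensor([0.1, 0.1, 0.1, 0.1]),  # per-channel scales |
|         torch.tensor([0, 0, 0, 0]),  # per-channel zero points |
|         0,  # axis |
|         torch.qint8, |
|     ) |
|     weights: Dict[str, torch.Tensor] = {} |
|     scheme, per_channel_dict = serialize_tensor_quantization( |
|         q_weight, weights, "example_weight" |
|     ) |
|     return scheme, per_channel_dict |
| |
| |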
| @compatibility(is_backward_compatible=False) |
| def serialize_weight(tensor: torch.Tensor, weights: Dict, name: str) -> Dict: |
| weight_dict: Dict[str, Dict] = {name: {}} |
| weight_dict[name]["dtype"] = str(tensor.dtype) |
| weight_dict[name]["shape"] = serialize_shape(tensor.shape) |
| weight_dict[name]["requires_grad"] = str(tensor.requires_grad) |
| weight_dict[name]["is_quantized"] = tensor.is_quantized |
| weight_dict[name]["stride"] = serialize_stride(tensor.stride()) |
| |
| if tensor.is_quantized: |
| quantization_info, per_channel_dict = serialize_tensor_quantization( |
| tensor, weights, name |
| ) |
| weight_dict[name].update(quantization_info) |
| weight_dict.update(per_channel_dict) |
| |
| return weight_dict |
| |
| |
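| # Usage sketch (illustrative): serialize the metadata of an ordinary (non-quantized) |
| # parameter. The returned dict maps "linear.weight" to its dtype/shape/stride info; |
| # `weights` is only written to for per-channel quantized tensors. |
| def _example_serialize_linear_weight() -> Dict: |
|     linear = torch.nn.Linear(3, 2) |
|     weights: Dict[str, torch.Tensor] = {} |
|     return serialize_weight(linear.weight, weights, "linear.weight") |
| |
| |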
| @compatibility(is_backward_compatible=False) |
| def serialize_leaf_module( |
| node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str |
| ) -> Dict: |
| parameters: Dict[str, Any] = {} |
| |
| for p_name, p_value in node.attrs_for_lowering.items(): # type: ignore[attr-defined] |
| if isinstance(p_value, torch.Tensor): |
| weights_metadata.update( |
| serialize_weight(p_value, weights, f"{name_prefix}.{p_name}") |
| ) |
| weights[f"{name_prefix}.{p_name}"] = p_value |
| else: |
| parameters[p_name] = str(p_value) |
| |
| return parameters |
| |
| |
| def _update_weight_fused_dtypes(weight, name, node): |
| """ |
|     For quantized embedding tables we need to update the shape/type, so we check whether |
|     the first user of this get_attr node is a quantized embedding bag op and this node |
|     is that op's weight, and update the dtype accordingly. |
| """ |
| if len(node.users) == 0: |
| return |
| user = list(node.users)[0] |
| if user.op != "call_function": |
| return |
| user_target = _get_qualified_name(user.target) |
| if ( |
| user_target.endswith("acc_ops.embedding_bag_byte_rowwise_offsets") |
| and node == user.kwargs["weight"] |
| ): |
| weight[name]["dtype"] = "acc.uint8fused" |
| elif ( |
| user_target.endswith("acc_ops.embedding_bag_4bit_rowwise_offsets") |
| and node == user.kwargs["weight"] |
| ): |
| weight[name]["dtype"] = "acc.uint4fused" |
| |
| |
| @compatibility(is_backward_compatible=False) |
| def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict: |
|     """Recursively serializes a graph module (fx_module) to a dictionary which is later exported to JSON. |
|     It also adds all weights to the provided weights dictionary, keyed by qualified name. |
| Dictionary Schema: |
| MODULE |
| { |
|         modules: {module_name: MODULE}, |
|         nodes: [NODE], |
|         weights: {qualified_name: WEIGHT}, |
| } |
| NODE |
| { |
| shape: [], |
| stride: [], |
| dtype: dtype, |
| is_quantized: bool, |
| target: target, |
| op_code: op_code, |
| name: name, |
| args: [], |
| kwargs: {} |
| } |
| WEIGHT |
| { |
| dtype: dtype, |
| is_quantized: bool, |
| shape: [], |
| QUANTIZATION, |
| } |
| QUANTIZATION |
| { |
| qscheme: qscheme, |
| q_scale: float, |
| q_zero_point: float, |
|         q_per_channel_scales: [], |
|         q_per_channel_zero_points: [], |
|         q_per_channel_axis: int |
| } |
| """ |
| serialized_dict: Dict[str, Any] = {} |
| serialized_dict["modules"] = {} |
| serialized_dict["weights"] = {} |
| serialized_dict["nodes"] = [] |
| submodules = dict(fx_module.named_modules()) |
| prefix = f"{name_prefix}." if name_prefix else "" |
| |
| def get_node_info(node): |
| tensor_meta = get_tensor_meta(node) |
| node_rep = { |
| "shape": serialize_shape(tensor_meta.shape), |
| "dtype": str(tensor_meta.dtype), |
| "requires_grad": str(tensor_meta.requires_grad), |
| "stride": serialize_stride(tensor_meta.stride), |
| "is_quantized": tensor_meta.is_quantized, |
| } |
| |
| if tensor_meta.is_quantized: |
| node_rep["qscheme"] = str(tensor_meta.qparams["qscheme"]) |
| |
| if tensor_meta.qparams["qscheme"] in { |
| torch.per_tensor_affine, |
| torch.per_tensor_symmetric, |
| }: |
| node_rep["q_scale"] = tensor_meta.qparams["scale"] |
| node_rep["q_zero_point"] = tensor_meta.qparams["zero_point"] |
| |
| # Add all extra lowering_info that was provided in node.meta. |
| lowering_info = node.meta.get("lowering_info") |
| if lowering_info is not None: |
| overlapping_keys = node_rep.keys() & lowering_info.keys() |
| assert ( |
| len(overlapping_keys) == 0 |
| ), f"Overlap found between lowering_info and node_rep: {overlapping_keys}" |
| node_rep.update(lowering_info) |
| |
| return node_rep |
| |
| # Note: lift_lowering_attrs_to_nodes is only used to support leaf modules |
| # that cannot currently be symbolically traced into, e.g. batch norm. |
| lift_lowering_attrs_to_nodes(fx_module) |
| for node in fx_module.graph.nodes: |
| node_rep: Dict[str, Any] = {} |
|         # Get shape/type info; currently not needed for call_module nodes whose |
|         # target is a GraphModule, nor for the output node. |
| if ( |
| not ( |
| node.op == "call_module" |
| and isinstance(submodules[node.target], GraphModule) |
| ) |
| and node.op != "output" |
| ): |
| node_rep.update(get_node_info(node)) |
| |
| # Recurse down into any submodules we are calling. |
| if node.op == "call_module": |
| if isinstance(submodules[node.target], GraphModule): |
| serialized_module = serialize_module( |
| getattr(fx_module, node.target), weights, node.target |
| ) |
| serialized_dict["modules"][node.target] = serialized_module |
| else: |
| node_rep["parameters"] = serialize_leaf_module( |
| node, |
| serialized_dict["weights"], |
| weights, |
| prefix + node.target, |
| ) |
| |
| if node.op == "call_function": |
| node_rep["target"] = _get_qualified_name(node.target) |
| else: |
| node_rep["target"] = str(node.target) |
| |
| # Make sure we capture all constants. |
| if node.op == "get_attr": |
| # If we are targeting a parent constant we update the target. |
| if node.target.startswith("parent."): |
| qualname = node.target[len("parent.") :] |
| node.name = qualname |
| node_rep["target"] = qualname |
| else: |
| qualname = prefix + node.target |
| # Find the actual target parameter/buffer from the fx_module. |
| submod_path, _, target_name = node.target.rpartition(".") |
| submod: Optional[torch.nn.Module] = ( |
| fx_module.get_submodule(submod_path) if submod_path else fx_module |
| ) |
| assert submod is not None, f"submod {submod_path} not found" |
| target = getattr(submod, target_name, None) |
| assert target is not None, f"{target_name} not an attr of {submod_path}" |
| # Check that the target is a tensor, and that we haven't added it already from a leaf module. |
| if isinstance(target, torch.Tensor) and qualname not in weights: |
| weight = serialize_weight(target, weights, qualname) |
| _update_weight_fused_dtypes(weight, qualname, node) |
| serialized_dict["weights"].update(weight) |
| weights[qualname] = target |
| elif node.op == "placeholder": |
| ph_type = node.meta.get("ph_type", "") |
| assert ( |
| ph_type == "" or ph_type == "input_ph" or ph_type == "output_ph" |
|             ), "When present, placeholder type must be 'input_ph' or 'output_ph'" |
| if ph_type == "input_ph": |
| node_rep["ph_type"] = "input_ph" |
| elif ph_type == "output_ph": |
| node_rep["ph_type"] = "output_ph" |
| |
| node_rep["op_code"] = node.op |
| node_rep["name"] = node.name |
| |
| def get_user_info(user_node: Argument) -> Any: |
| return {"is_node": True, "name": str(user_node)} |
| |
| def get_arg_info(arg: Argument) -> Any: |
| if isinstance(arg, torch.fx.Node): |
| return {"is_node": True, "name": str(arg)} |
| elif isinstance(arg, (torch.dtype, torch.memory_format, torch.qscheme)): |
| return str(arg) |
| else: |
| return arg |
| |
| def get_output_arg_info(arg: Node) -> Dict[str, Any]: |
| node_rep: Dict[str, Any] = get_arg_info(arg) |
| node_rep.update(get_node_info(arg)) |
| return node_rep |
| |
| if node.op == "output": |
| node_rep["args"] = map_arg( |
| node.args, |
| get_output_arg_info, |
| ) |
| |
|             # If there are multiple outputs, node_rep["args"][0] will be a tuple or |
|             # list. In that case we want to unpack it. |
| if isinstance(node_rep["args"][0], (tuple, list)): |
| node_rep["args"] = node_rep["args"][0] |
| else: |
| node_rep["args"] = map_aggregate(node.args, get_arg_info) |
| |
| node_rep["kwargs"] = map_aggregate(node.kwargs, get_arg_info) |
| node_rep["users"] = map_aggregate(list(node.users.keys()), get_user_info) |
| serialized_dict["nodes"] += [node_rep] |
| |
| return serialized_dict |
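| |
| |
| # End-to-end usage sketch (illustrative): assumes `model` is symbolically traceable, |
| # takes a single tensor input, and that all intermediate values are tensors so that |
| # ShapeProp can attach tensor metadata to every node. json.dumps is used only to show |
| # that the result is JSON-friendly; serialize_module itself returns a plain dict. |
| def _example_serialize_module_to_json( |
|     model: torch.nn.Module, sample_input: torch.Tensor |
| ) -> str: |
|     import json |
| |
|     gm = torch.fx.symbolic_trace(model) |
|     ShapeProp(gm).propagate(sample_input) |
|     weights: Dict[str, torch.Tensor] = {} |
|     serialized = serialize_module(gm, weights) |
|     return json.dumps(serialized, indent=2) |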