from typing import Dict, List, NamedTuple, Any
import torch
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.experimental.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.node import _get_qualified_name
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
from torch.fx.node import Node, Target, map_arg
def replace_target_nodes_with(
fx_module: GraphModule,
old_op: str,
old_target: Target,
new_op: str,
new_target: Target,
):
"""Modifies all nodes in fx_module.graph.nodes which match the specified op code and target,
and updates them to match the new op code and target"""
new_graph = Graph()
val_map: Dict[Node, Node] = {}
for node in fx_module.graph.nodes:
if node.op == old_op and node.target == old_target:
args = map_arg(node.args, lambda n: val_map[n])
kwargs = map_arg(node.kwargs, lambda n: val_map[n])
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
val_map[node] = new_graph.create_node(
new_op, new_target, args, kwargs, node.name
)
else:
val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
fx_module.graph = new_graph
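

# A minimal usage sketch of replace_target_nodes_with (illustrative only;
# `_DemoAdd` below is a hypothetical module, not part of this file): rewrite
# every `torch.add` call_function node into a `torch.mul` node.
def _demo_replace_target_nodes() -> None:
    import torch.fx

    class _DemoAdd(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    traced = torch.fx.symbolic_trace(_DemoAdd())
    replace_target_nodes_with(
        traced, "call_function", torch.add, "call_function", torch.mul
    )
    # Assigning to .graph regenerates forward(), so the module now multiplies.
    x, y = torch.randn(2), torch.randn(2)
    assert torch.equal(traced(x, y), torch.mul(x, y))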
class size_bytes(NamedTuple):
output_size: int
total_size: int
def get_size_of_all_nodes(fx_module: GraphModule, args: List[torch.Tensor]) -> None:
"""Given a fx graph module, update each node with its total size (weights + bias + output)
and its output_size(output). For a non-module node, the total size is the output size.
return total size"""
# Mark shape and dtype for each node (node.shape and node.dtype)
ShapeProp(fx_module).propagate(*args)
    # Annotate each node (up to, but not including, the output node) with its size
for node in fx_module.graph.nodes:
if node.op == "output":
break
node.size_bytes = get_size_of_node(fx_module, node)
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
"""Given a node with node.dtype and node.shape, return its total size and its output size.
total_size = weights + bias + output_size
"""
# Total num of elements
total_num_of_elems = 0
    # For a module, consider all parameters
if node.op == "call_module":
submodule_dict = dict(fx_module.named_modules())
submodule = submodule_dict[node.target]
parameters = submodule.named_parameters()
        # named_parameters() yields (name, parameter) pairs
for name, p in parameters:
total_num_of_elems += p.numel()
# Don't forget the output size
# node.shape is the shape of this node's output
shape = getattr(node, "shape", None)
    if shape is not None:
output_elem = shape.numel()
else:
raise RuntimeError("Node has no shape attr")
total_num_of_elems += output_elem
size_per_elem_bytes = 0
dtype = getattr(node, "dtype", None)
    if dtype is not None:
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
else:
raise RuntimeError("Node has no dtype attr")
total_size = size_per_elem_bytes * total_num_of_elems
output_size = size_per_elem_bytes * output_elem
return size_bytes(output_size, total_size)
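

# A minimal sketch of how the two sizing helpers above fit together
# (illustrative only; `_DemoLinear` is hypothetical). get_size_of_all_nodes
# runs shape propagation itself, then annotates every non-output node with a
# size_bytes tuple.
def _demo_node_sizes() -> None:
    import torch.fx

    class _DemoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 8)

        def forward(self, x):
            return self.linear(x)

    traced = torch.fx.symbolic_trace(_DemoLinear())
    get_size_of_all_nodes(traced, [torch.randn(2, 4)])
    for node in traced.graph.nodes:
        if node.op == "output":
            break
        # For the linear node, total_size covers weight + bias + output,
        # while output_size covers only the (2, 8) float32 output.
        print(node.name, node.size_bytes)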
def serialize_shape(shape: torch.Size) -> str:
return str(list(shape))
def serialize_tensor_quantization(tensor: torch.Tensor) -> Dict[str, Any]:
scheme: Dict[str, Any] = {}
if tensor.is_quantized:
scheme["q_scheme"] = str(tensor.qscheme())
if tensor.qscheme() in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
scheme["q_scale"] = tensor.q_scale()
scheme["q_zero_pont"] = tensor.q_zero_point()
if tensor.qscheme() in {
torch.per_channel_affine,
torch.per_channel_affine_float_qparams,
torch.per_channel_symmetric,
}:
scheme["q_per_channel_scales"] = tensor.q_per_channel_scales().tolist()
scheme[
"q_per_channel_zero_points"
] = tensor.q_per_channel_zero_points().tolist()
scheme["q_per_channel_axis"] = tensor.q_per_channel_axis()
return scheme
def serialize_weight(tensor: torch.Tensor) -> Dict[str, Any]:
weight: Dict[str, Any] = {}
weight["dtype"] = str(tensor.dtype)
weight["is_quantized"] = tensor.is_quantized
if tensor.is_quantized:
weight["quantized_type"] = serialize_tensor_quantization(tensor)
weight["shape"] = serialize_shape(tensor.shape)
return weight
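

# A minimal sketch of serialize_weight output (illustrative only): a
# per-tensor-affine quantized tensor serializes with a nested
# "quantized_type" entry produced by serialize_tensor_quantization.
def _demo_serialize_weight() -> Dict[str, Any]:
    q = torch.quantize_per_tensor(
        torch.randn(3, 3), scale=0.1, zero_point=0, dtype=torch.qint8
    )
    # Roughly: {"dtype": "torch.qint8", "is_quantized": True,
    #           "quantized_type": {"q_scheme": "torch.per_tensor_affine",
    #                              "q_scale": 0.1, "q_zero_point": 0},
    #           "shape": "[3, 3]"}
    return serialize_weight(q)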
def serialize_leaf_module(
node: Node, weights_metadata: Dict, weights: Dict, name_prefix: str
) -> Dict:
parameters: Dict[str, Any] = {}
for p_name, p_value in node.attrs_for_lowering.items(): # type: ignore
if isinstance(p_value, torch.Tensor):
weights_metadata[f"{name_prefix}.{p_name}"] = serialize_weight(p_value)
weights[f"{name_prefix}.{p_name}"] = p_value
else:
parameters[p_name] = str(p_value)
return parameters
def serialize_module(fx_module: GraphModule, weights: Dict, name_prefix="") -> Dict:
"""Recursively Serializes a graph module (fx_module) to a dictionary which is later exported to JSON.
It also adds all weights the provided weights dictionary by qualified_name.
Dictionary Schema:
MODULE
{
modules: {module_name: MODULE],
nodes: [NODE],
weights {qualified_name: WEIGHT},
}
NODE
{
shape: [],
dtype: dtype,
target: target,
op_code: op_code,
name: name,
args: [],
kwargs: {}
}
WEIGHT
{
dtype: dtype,
is_quantized: bool,
shape: [],
quantization_info: QUANTIZATION
}
QUANTIZATION
{
qscheme: qscheme,
q_scale: float,
q_zero_point: float,
q_per_channel_scales, [],
q_per_channel_zero_points: [],
q_per_channel_axis, int
}
"""
serialized_dict: Dict[str, Any] = {}
serialized_dict["modules"] = {}
serialized_dict["weights"] = {}
serialized_dict["nodes"] = []
parameters = fx_module.named_parameters()
prefix = f"{name_prefix}." if name_prefix else ""
submodules = dict(fx_module.named_modules())
for name, p in parameters:
if isinstance(p, torch.Tensor):
weight = serialize_weight(p)
serialized_dict["weights"][prefix + name] = weight
weights[prefix + name] = p
lift_lowering_attrs_to_nodes(fx_module)
for node in fx_module.graph.nodes:
node_rep: Dict[str, Any] = {}
# Get shape/type info, currently not needed for call_module.
if node.op != "call_module" or not isinstance(
submodules[node.target], GraphModule
):
shape = getattr(node, "shape", None)
        if shape is not None:
node_rep["shape"] = serialize_shape(shape)
else:
raise RuntimeError(
"Node has no shape attr, this is likely because shape propagation has not been run on this Graph."
)
dtype = getattr(node, "dtype", None)
        if dtype is not None:
node_rep["dtype"] = str(dtype)
else:
raise RuntimeError(
"Node has no dtype attr, this is likely because shape propagation has not been run on this Graph."
)
# Recurse down into any submodules we are calling.
if node.op == "call_module":
if isinstance(submodules[node.target], GraphModule):
serialized_module = serialize_module(
getattr(fx_module, node.target), weights, node.target
)
serialized_dict["modules"][node.target] = serialized_module
else:
node_rep["parameters"] = serialize_leaf_module(
node,
serialized_dict["weights"],
weights,
prefix + node.target,
)
if node.op == "call_function":
node_rep["target"] = _get_qualified_name(node.target)
else:
node_rep["target"] = str(node.target)
# Make sure we capture all constants.
if node.op == "get_attr":
target = getattr(fx_module, node.target)
qualname = prefix + node.target
if isinstance(target, torch.Tensor) and qualname not in weights:
weight = serialize_weight(target)
serialized_dict["weights"][prefix + node.target] = weight
weights[prefix + node.target] = target
node_rep["op_code"] = node.op
node_rep["name"] = node.name
node_rep["args"] = map_arg(
node.args, lambda arg: {"is_node": True, "name": str(arg)}
)
node_rep["kwargs"] = map_arg(
node.kwargs, lambda arg: {"is_node": True, "name": str(arg)}
)
serialized_dict["nodes"] += [node_rep]
return serialized_dict
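

# A minimal end-to-end sketch (illustrative only; `_DemoNet` is hypothetical):
# trace a module, run shape propagation so every node carries shape/dtype,
# then serialize the graph to a JSON string plus a side table of weights.
def _demo_serialize_module() -> str:
    import json
    import torch.fx

    class _DemoNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))

    traced = torch.fx.symbolic_trace(_DemoNet())
    # serialize_module raises if nodes lack shape/dtype, so propagate first.
    ShapeProp(traced).propagate(torch.randn(2, 4))
    weights: Dict[str, torch.Tensor] = {}
    serialized = serialize_module(traced, weights)
    # weights now maps e.g. "linear.weight" and "linear.bias" to live tensors.
    return json.dumps(serialized)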