import torch
import torch.fx
import traceback

from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility


@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program.

    # General Tensor metadata
    shape : torch.Size
    dtype : torch.dtype
    requires_grad : bool
    stride : Tuple[int, ...]
    memory_format : Optional[torch.memory_format]

    # Quantization metadata
    is_quantized : bool
    qparams: Dict[str, Any]

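# A quick sketch of the metadata captured for an ordinary contiguous
# tensor (illustrative values; they follow from the construction):
#
#     t = torch.randn(2, 3)
#     # shape=torch.Size([2, 3]), dtype=torch.float32,
#     # requires_grad=False, stride=(3, 1),
#     # memory_format=torch.contiguous_format,
#     # is_quantized=False, qparams={}
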
def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()

    memory_formats = {
        torch.contiguous_format,
        torch.channels_last,
        torch.channels_last_3d,
    }

    memory_format = None

    # Probe the candidate layouts; record the first one the tensor is
    # contiguous in, or leave None if it matches none of them.
    for query_format in memory_formats:
        if result.is_contiguous(memory_format=query_format):
            memory_format = query_format
            break

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # In this branch, scale and zero_point are expected to be
            # tensors; we store their values as lists in TensorMetadata
            # for easier serialization downstream.
            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)

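# Sketch of the quantized path (per-tensor affine), assuming a PyTorch
# build with eager-mode quantization available:
#
#     q = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#                                   zero_point=0, dtype=torch.quint8)
#     meta = _extract_tensor_metadata(q)
#     # meta.is_quantized is True, and meta.qparams carries
#     # {'qscheme': torch.per_tensor_affine, 'scale': 0.1, 'zero_point': 0}
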
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and
    record the shape and type of the result
    into the corresponding node.

    Example:
        In this example, we record the shape
        and data type of a module given
        an example input ``torch.randn(50, D_in)``.
        We print the name, shape and dtype of each node.

            class TwoLayerNet(torch.nn.Module):
                def __init__(self, D_in, H, D_out):
                    super(TwoLayerNet, self).__init__()
                    self.linear1 = torch.nn.Linear(D_in, H)
                    self.linear2 = torch.nn.Linear(H, D_out)
                def forward(self, x):
                    h_relu = self.linear1(x).clamp(min=0)
                    y_pred = self.linear2(h_relu)
                    return y_pred

            N, D_in, H, D_out = 64, 1000, 100, 10
            x = torch.randn(N, D_in)
            y = torch.randn(N, D_out)
            model = TwoLayerNet(D_in, H, D_out)
            gm = torch.fx.symbolic_trace(model)
            sample_input = torch.randn(50, D_in)
            ShapeProp(gm).propagate(sample_input)

            for node in gm.graph.nodes:
                print(node.name, node.meta['tensor_meta'].dtype,
                      node.meta['tensor_meta'].shape)

        The output of this code is:

            x torch.float32 torch.Size([50, 1000])
            linear1 torch.float32 torch.Size([50, 100])
            clamp_1 torch.float32 torch.Size([50, 100])
            linear2 torch.float32 torch.Size([50, 10])
            output torch.float32 torch.Size([50, 10])

    Args:
        module (GraphModule): The module to be executed

    """
    def run_node(self, n : Node) -> Any:
        try:
            result = super().run_node(n)
        except Exception as e:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            ) from e

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
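        # map_aggregate mirrors the structure of `result`: for a node
        # that returns, e.g., a tuple of two tensors, `meta` is expected
        # to be a parallel tuple of two TensorMetadata entries.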
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation, recording the shape and type of
        each node along the way, and return the result.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
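

# Minimal, self-contained usage sketch (not part of the upstream module;
# guarded so importing this file stays side-effect free). It mirrors the
# ShapeProp docstring example above:
if __name__ == "__main__":
    class TwoLayerNet(torch.nn.Module):
        def __init__(self, D_in, H, D_out):
            super().__init__()
            self.linear1 = torch.nn.Linear(D_in, H)
            self.linear2 = torch.nn.Linear(H, D_out)

        def forward(self, x):
            return self.linear2(self.linear1(x).clamp(min=0))

    gm = torch.fx.symbolic_trace(TwoLayerNet(1000, 100, 10))
    ShapeProp(gm).propagate(torch.randn(50, 1000))
    for node in gm.graph.nodes:
        if 'tensor_meta' in node.meta:
            print(node.name, node.meta['tensor_meta'].dtype,
                  node.meta['tensor_meta'].shape)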