| # Copyright 2025 Arm Limited and/or its affiliates. |
| # Copyright 2025-2026 NXP |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import json |
| import subprocess |
| import time |
| from typing import Any, Callable, Type |
| |
| from executorch.exir import EdgeProgramManager, ExecutorchProgramManager # type: ignore |
| from executorch.exir.program._program import ( # type: ignore |
| _update_exported_program_graph_module, |
| ) |
| |
| from torch._export.verifier import Verifier |
| from torch.export.exported_program import ExportedProgram # type: ignore |
| from torch.fx import GraphModule, Node # type: ignore |
| |
| try: |
| from model_explorer import config, consts, visualize_from_config # type: ignore |
| from model_explorer.config import ModelExplorerConfig # type: ignore |
| from model_explorer.pytorch_exported_program_adater_impl import ( # type: ignore |
| PytorchExportedProgramAdapterImpl, |
| ) |
| except ImportError: |
| print( |
| "Error: 'model_explorer' is not installed. Install using devtools/install_requirements.sh" |
| ) |
| raise |
| |
| |
| class SingletonModelExplorerServer: |
| """Singleton context manager for starting a model-explorer server. |
| If multiple ModelExplorerServer contexts are nested, a single |
| server is still used. |
| """ |
| |
| server: None | subprocess.Popen = None |
| num_open: int = 0 |
| wait_after_start = 3.0 |
| |
| def __init__(self, open_in_browser: bool = True, port: int | None = None): |
| if SingletonModelExplorerServer.server is None: |
| command = ["model-explorer"] |
| if not open_in_browser: |
| command.append("--no_open_in_browser") |
| if port is not None: |
| command.append("--port") |
| command.append(str(port)) |
| SingletonModelExplorerServer.server = subprocess.Popen(command) |
| |
| def __enter__(self): |
| SingletonModelExplorerServer.num_open = ( |
| SingletonModelExplorerServer.num_open + 1 |
| ) |
| time.sleep(SingletonModelExplorerServer.wait_after_start) |
| return self |
| |
| def __exit__(self, type, value, traceback): |
| SingletonModelExplorerServer.num_open = ( |
| SingletonModelExplorerServer.num_open - 1 |
| ) |
| if SingletonModelExplorerServer.num_open == 0: |
| if SingletonModelExplorerServer.server is not None: |
| SingletonModelExplorerServer.server.kill() |
| try: |
| SingletonModelExplorerServer.server.wait( |
| SingletonModelExplorerServer.wait_after_start |
| ) |
| except subprocess.TimeoutExpired: |
| SingletonModelExplorerServer.server.terminate() |
| SingletonModelExplorerServer.server = None |
| |
| |
| class ModelExplorerServer: |
| """Context manager for starting a model-explorer server.""" |
| |
| wait_after_start = 2.0 |
| |
| def __init__(self, open_in_browser: bool = True, port: int | None = None): |
| command = ["model-explorer"] |
| if not open_in_browser: |
| command.append("--no_open_in_browser") |
| if port is not None: |
| command.append("--port") |
| command.append(str(port)) |
| self.server = subprocess.Popen(command) |
| |
| def __enter__(self): |
| time.sleep(self.wait_after_start) |
| |
| def __exit__(self, type, value, traceback): |
| self.server.kill() |
| try: |
| self.server.wait(self.wait_after_start) |
| except subprocess.TimeoutExpired: |
| self.server.terminate() |
| |
| |
def _get_exported_program(
    visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
) -> ExportedProgram:
    """Return the `ExportedProgram` behind any of the supported visualizable objects.

    :param visualizable: An `ExportedProgram` (returned as is), or an `EdgeProgramManager` /
                          `ExecutorchProgramManager` (its `exported_program()` is returned).
    :return: The resolved `ExportedProgram`.
    :raises RuntimeError: If `visualizable` is none of the supported types.
    """
    if isinstance(visualizable, ExportedProgram):
        program = visualizable
    elif isinstance(visualizable, (EdgeProgramManager, ExecutorchProgramManager)):
        program = visualizable.exported_program()
    else:
        raise RuntimeError(f"Cannot get ExportedProgram from {visualizable}")
    return program
| |
| |
def visualize(
    visualizable: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
    reuse_server: bool = True,
    no_open_in_browser: bool = False,
    **kwargs,
):
    """Visualize a program in the Model Explorer via `visualize_from_config`.

    The exported program is looked up automatically, so an EdgeProgramManager
    or an ExecutorchProgramManager can be passed directly.

    See https://github.com/google-ai-edge/model-explorer/wiki/4.-API-Guide#visualize-pytorch-models
    for full documentation.

    :param visualizable: Program (or program manager) to visualize.
    :param reuse_server: If True, `set_reuse_server()` is called on the config created by this function.
    :param no_open_in_browser: Forwarded unchanged to the visualization call.
    :param kwargs: Additional kwargs forwarded to the visualization call. A `config` entry, if present, is
                   used instead of the config created by this function.
    """
    default_config = config()
    default_config.add_model_from_pytorch(
        "Executorch",
        exported_program=_get_exported_program(visualizable),
        settings=consts.DEFAULT_SETTINGS,
    )
    if reuse_server:
        default_config.set_reuse_server()

    # A caller-supplied `config` takes precedence over the one built above.
    selected_config = kwargs.pop("config", default_config)
    visualize_model_explorer(
        config=selected_config,
        no_open_in_browser=no_open_in_browser,
        **kwargs,
    )
| |
| |
def visualize_model_explorer(**kwargs):
    """Forward all keyword arguments unchanged to model_explorer's `visualize_from_config()`.

    :param kwargs: Keyword arguments for `visualize_from_config()`.
    """
    visualize_from_config(**kwargs)
| |
| |
def _save_model_as_json(cur_config: ModelExplorerConfig, file_name: str):
    """Write the graphs held by `cur_config` to a JSON file which the Model Explorer GUI can load.

    Only the first entry of the config's `graphs_list` is exported.

    :param cur_config: ModelExplorerConfig containing the graph for visualization.
    :param file_name: Name of the JSON file for storage.
    """
    label = "Executorch"

    # The transferrable data holds a JSON encoded list, whose entries are JSON encoded graph collections.
    encoded_collections = json.loads(
        cur_config.get_transferrable_data()["graphs_list"]
    )
    graphs = json.loads(encoded_collections[0])["graphs"]

    leveled_graphs = []
    for level, graph in enumerate(graphs):
        # The decoded graphs lack the `collectionLabel` entry, so it is filled in here.
        graph["collectionLabel"] = label
        leveled_graphs.append({"graph": graph, "level": level})

    # Layout expected by the Model Explorer GUI.
    payload = {
        "label": label,
        "graphs": graphs,
        "graphsWithLevel": leveled_graphs,
    }

    with open(file_name, "w") as json_file:
        json.dump(payload, json_file)
| |
| |
def visualize_with_clusters(
    exported_program: ExportedProgram,
    json_file_name: str | None = None,
    no_open_in_browser: bool = False,
    reuse_server: bool = False,
    get_node_partition_name: Callable[[Node], str | None] = lambda node: node.meta.get(
        "delegation_tag", None
    ),
    get_node_qdq_cluster_name: Callable[
        [Node], str | None
    ] = lambda node: node.meta.get("cluster", None),
    **kwargs,
):
    """Visualize an exported program in the Model Explorer, grouping nodes by partition and by QDQ cluster.

    To install the Model Explorer, run `devtools/install_requirements.sh`.
    To display a stored json file, first launch the Model Explorer server by running `model-explorer`, and then
    use the GUI to open the json.

    NOTE: FireFox seems to have issues rendering the graphs. Other browsers seem to work well.

    :param exported_program: Program to visualize.
    :param json_file_name: If not None, the visualization is only stored as JSON in the given file, which can be
                            opened in the Model Explorer GUI later.
                           If None, a Model Explorer instance is launched with the model visualization.
    :param no_open_in_browser: If `True`, a browser instance with the model explorer will NOT be launched, and only
                                the URI to the model explorer server with the visualization will be printed.
    :param reuse_server: If True, an existing instance of the Model Explorer server will be used (if one exists).
                          Otherwise, a new instance at a separate port will start.
    :param get_node_partition_name: Function which takes a `Node` and returns the name of the partition the `Node`
                                     belongs to, or `None` if it has no partition. By default, the
                                     `delegation_tag` entry of the node's `meta` is used.
    :param get_node_qdq_cluster_name: Function which takes a `Node` and returns the name of the QDQ cluster the
                                       `Node` belongs to, or `None` if it has no cluster. By default, the
                                       `cluster` entry of the node's `meta` is used.
    :param kwargs: Additional kwargs for the `visualize_from_config()` function.
    """
    cur_config = config()

    # Convert the `exported_program` to a Model Explorer graph collection.
    adapter = PytorchExportedProgramAdapterImpl(
        exported_program, consts.DEFAULT_SETTINGS
    )
    converted = adapter.convert()

    fx_nodes = list(exported_program.graph.nodes)
    explorer_nodes = converted["graphs"][0].nodes

    # Maps every partition name to its index in order of first appearance (0, 1, 2, ...).
    partition_ids: dict[str, int] = {}

    # `strict=True` makes the pairing fail loudly if the two node sequences differ in length.
    for explorer_node, fx_node in zip(explorer_nodes, fx_nodes, strict=True):
        # The namespace determines the node grouping in the visualizer. Its levels are separated by "/", so any
        # "/" inside the partition and cluster names is replaced by ":".
        namespace_levels = []

        partition_name = get_node_partition_name(fx_node)
        if partition_name is not None:
            # Outer level: "partition <i> (<partition_name>)".
            partition_id = partition_ids.setdefault(partition_name, len(partition_ids))
            sanitized_partition = partition_name.replace("/", ":")
            namespace_levels.append(f"partition {partition_id} ({sanitized_partition})")

        cluster_name = get_node_qdq_cluster_name(fx_node)
        if cluster_name is not None:
            # Inner level: "cluster (<cluster_name>)".
            sanitized_cluster = cluster_name.replace("/", ":")
            namespace_levels.append(f"cluster ({sanitized_cluster})")

        explorer_node.namespace = "/".join(namespace_levels)

    # Register the modified graph collection in the config.
    collection_name = "Executorch"
    graphs_index = len(cur_config.graphs_list)
    cur_config.graphs_list.append(converted)
    cur_config.model_sources.append(
        {"url": f"graphs://{collection_name}/{graphs_index}"}
    )

    if json_file_name is not None:
        # Only store the visualization, no server is started.
        _save_model_as_json(cur_config, json_file_name)
        return

    # Start the ModelExplorer server, and visualize the graph in the browser.
    if reuse_server:
        cur_config.set_reuse_server()
    visualize_from_config(
        cur_config,
        no_open_in_browser=no_open_in_browser,
        **kwargs,
    )
| |
| |
def visualize_graph(
    graph_module: GraphModule,
    exported_program: ExportedProgram | EdgeProgramManager | ExecutorchProgramManager,
    reuse_server: bool = True,
    no_open_in_browser: bool = False,
    **kwargs,
):
    """Visualize `graph_module` in place of the graph module of the supplied `exported_program`.

    The operator validation is relaxed as well, so graphs containing custom ops can be visualized.

    A typical use case is visualizing the result of running passes, which return a graph_module rather than
    an ExportedProgram.

    :param graph_module: Graph module to visualize.
    :param exported_program: Program (or program manager) whose graph module is overridden.
    :param reuse_server: Forwarded to `visualize()`.
    :param no_open_in_browser: Forwarded to `visualize()`.
    :param kwargs: Additional kwargs forwarded to `visualize()`.
    """

    class _any_op(Verifier):
        # Verifier which allows every callable as an operator type.
        dialect = "ANY_OP"

        def allowed_op_types(self) -> tuple[Type[Any], ...]:
            return (Callable,)  # type: ignore

    source_program = _get_exported_program(exported_program)
    patched_program = _update_exported_program_graph_module(
        source_program, graph_module, override_verifiers=[_any_op]
    )
    visualize(
        patched_program,
        reuse_server=reuse_server,
        no_open_in_browser=no_open_in_browser,
        **kwargs,
    )