| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| """ |
| Base Intermediate Representation for Productivity SDK consumers |
| (e.g. TensorBoard, Terminal Debugger) |
| """ |
| |
| from __future__ import annotations |
| |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional |
| |
| |
| # Base Representation of a generic node within a ModelGraph |
| @dataclass |
| class Node: |
| name: str |
| # Nodes that this Node consumes/in-edges |
| inputs: Optional[List[Node]] = None |
| # List of output shapes |
| output_shapes: Optional[List[List[int]]] = None |
| # Generic Node level metadata |
| metadata: Optional[Dict[str, Any]] = None |
| # Names of the arguments derived from the op schema: |
| named_args: Optional[List[str]] = None |
| |
| |
| # Base Representation of an operator subgraph with metadata |
| @dataclass |
| class OperatorGraph: |
| # Identifier used for grouping nodes (e.g. expand/minimize Module) |
| graph_name: str |
| # Nodes and Sub-Graphs |
| elements: List[Node | OperatorGraph] |
| # Graph Level Metadata |
| metadata: Optional[Dict[str, Any]] = None |
| |
| |
| """ |
| Node SubClasses Types |
| """ |
| |
| |
| # Representation of a "Value" node within a ModelGraph |
| # i.e. Non-Operator Nodes |
| @dataclass |
| class ValueNode(Node): |
| dtype: str = "" |
| val: Optional[Any] = None |
| |
| |
| # Representation of an "OP" node within a ModelGraph |
| @dataclass |
| class OperatorNode(Node): |
| op: Optional[str] = None |