# A Node represents the definition of a single value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict
import torch

if TYPE_CHECKING:
    from .graph import Graph
| |
| |
# Leaf (non-container) value types that are accepted as part of an Argument.
BaseArgumentTypes = Union[str, int, float, bool, torch.dtype, torch.Tensor]
# The member types of BaseArgumentTypes as a plain tuple, in declaration order.
base_types = BaseArgumentTypes.__args__  # type: ignore

# What a Node refers to: either a callable object or a string name.
Target = Union[Callable[..., Any], str]

# An Argument is None, a (possibly nested) tuple/list/dict/slice of Arguments,
# another Node, or one of the leaf types in BaseArgumentTypes.
Argument = Optional[Union[
    Tuple[Any, ...],  # actually Argument, but mypy can't represent recursive types
    List[Any],  # actually Argument
    Dict[str, Any],  # actually Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
    'Node',
    BaseArgumentTypes
]]
| |
class Node:
    """The definition of one value in a graph of operators.

    A Node records which graph it belongs to, its name, the kind of
    operation it performs, the target of that operation, and the positional
    and keyword arguments supplied to it.
    """

    def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
                 args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> None:
        """Store the node's fields and start its use counter at zero.

        Args:
            graph: the Graph this node is associated with.
            name: unique name of the value being created.
            op: the kind of operation (see the comment on ``self.op``).
            target: for method/module/function ops, what is being invoked.
            args: positional arguments of the operation.
            kwargs: keyword arguments of the operation.
        """
        self.graph = graph
        self.name = name  # unique name of value being created
        self.op = op  # the kind of operation = placeholder|call_method|call_module|call_function|getattr
        self.target = target  # for method/module/function, the name of the method/module/function/attr
        # being invoked, e.g add, layer1, or torch.add
        self.args = args
        self.kwargs = kwargs
        # Starts at 0 and is never updated inside this class.
        # NOTE(review): presumably a use count maintained by the owning graph — confirm.
        self.uses = 0

    def __repr__(self) -> str:
        """Return the node's name as its representation."""
        return self.name