blob: efc268e7ec3fb7ccfd3f54a4e08b5d704179be40 [file] [log] [blame]
from typing import List, Tuple, Union
import torch
import torch.fx
# A sequence of tensors, given either as a tuple or as a list.
# ``Tuple[torch.Tensor, ...]`` is the variadic form: tuples of any length are
# accepted. The bare ``Tuple[torch.Tensor]`` would mean a tuple of exactly one
# tensor.
Tensors = Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
# Either a single tensor or a sequence of tensors (see ``Tensors``).
TensorOrTensors = Union[torch.Tensor, Tensors]
# A list of torch.fx graph nodes.
Nodes = List[torch.fx.Node]
# ``torch.fx.Node.op`` strings for nodes that perform a call: invoking a
# submodule, a free function, or a method. Any other ``op`` value is not in
# this set, so ``node.op in CALLABLE_NODE_OPS`` selects only call nodes.
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}