import itertools
import operator
from typing import Set

import torch
from functorch.experimental import control_flow
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
from torch.fx._compatibility import compatibility


# Node.meta keys that every OpOverload call node must carry.
PRESERVED_META_KEYS: Set[str] = {
    "val",
    "stack_trace",
}


@compatibility(is_backward_compatible=False)
class SpecViolationError(Exception):
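    """Raised when a GraphModule violates the export dialect specification."""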


@compatibility(is_backward_compatible=False)
def is_functional(op: OpOverload) -> bool:
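    """Return True if `op` mutates none of its inputs, per its schema."""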
    return not op._schema.is_mutable


@compatibility(is_backward_compatible=False)
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
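    """Raise SpecViolationError unless `node.meta["val"]` holds a FakeTensor
    (or a list/tuple of FakeTensors).
    """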
    def _check_is_fake_tensor(val):
        if isinstance(val, FakeTensor):
            return True
        # Recurse only into list/tuple containers; arbitrary iterables such
        # as str or Tensor would recurse endlessly or raise.
        if isinstance(val, (list, tuple)):
            return all(_check_is_fake_tensor(x) for x in val)
        return False

    val = node.meta.get("val", None)
    if val is None or not _check_is_fake_tensor(val):
        raise SpecViolationError(
            f"Node {node.name} is missing a 'val' metadata field, "
            "or it is not a FakeTensor."
        )


@compatibility(is_backward_compatible=False)
def _check_tensors_are_contiguous(gm: GraphModule) -> None:
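    """Raise SpecViolationError if any parameter or buffer of `gm` is not
    contiguous.
    """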
    # All parameters and buffers must be in contiguous memory format.
    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
        if isinstance(param, torch.Tensor):
            if not param.is_contiguous():
                raise SpecViolationError(
                    f"Tensors in the ATen dialect must be contiguous, "
                    f"but {name} is not contiguous"
                )


@compatibility(is_backward_compatible=False)
class Verifier:
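    """Verifies that a GraphModule conforms to the export dialect.

    Only functional, registered operators (plus the builtins returned by
    `valid_builtin_funcs`) may appear as call_function targets, and every
    call_function node must carry FakeTensor metadata.

    Usage:
        Verifier()(gm)           # raises SpecViolationError on a violation
        Verifier().is_valid(gm)  # returns a bool instead of raising
    """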
    def __call__(self, gm: GraphModule) -> None:
        self.check_valid(gm)

    @compatibility(is_backward_compatible=False)
    def valid_builtin_funcs(self):
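        """Return the non-operator callables allowed as call_function targets."""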
        return [
            operator.getitem,
            control_flow.cond,
            torch.ops.map_impl,
        ]

    @compatibility(is_backward_compatible=False)
    def check_valid_op(self, op):
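        """Raise SpecViolationError if `op` is not a functional, registered
        OpOverload.
        """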
        if not isinstance(op, OpOverload):
            op_name = getattr(op, "__name__", str(op))
            raise SpecViolationError(
                f"Operator '{op_name}' is not a registered Op",
            )

        # All operators must be functional.
        if not is_functional(op):
            raise SpecViolationError(
                f"Operator '{op.name()}' is not functional"
            )

    @compatibility(is_backward_compatible=False)
    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
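        """Check every node of `gm.graph`, raising SpecViolationError on the
        first violation found.
        """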
        for node in gm.graph.nodes:
            # TODO(T140410192): should have fake tensor for all dialects
            if node.op in {"call_module", "call_method"}:
                raise SpecViolationError(
                    f"'{node.op}' is not valid: got target '{node.target}'",
                )

            if node.op == "call_function":
                _check_has_fake_tensor(node)

                if node.target not in self.valid_builtin_funcs():
                    self.check_valid_op(node.target)

                if isinstance(node.target, OpOverload):
                    # Check that the preserved metadata keys are present.
                    for meta in PRESERVED_META_KEYS:
                        if node.meta.get(meta, None) is None:
                            raise SpecViolationError(
                                f"node {node} is missing metadata {meta}"
                            )

    @compatibility(is_backward_compatible=False)
    def is_valid(self, gm: GraphModule) -> bool:
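        """Return True if `gm` passes `check_valid`, False otherwise."""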
        try:
            self.check_valid(gm)
            return True
        except SpecViolationError:
            return False


class ATenDialectVerifier(Verifier):
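    """Verifier for the ATen dialect.

    In addition to the base checks, every operator must be ATen canonical
    (tagged `core` or `view_copy`), and all parameters and buffers must be
    contiguous.
    """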
    @compatibility(is_backward_compatible=False)
    def check_valid_op(self, op) -> None:
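        """Apply the base checks, then require `op` to carry the `core` or
        `view_copy` ATen tag.
        """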
        # The base class already rejects anything that is not a functional,
        # registered OpOverload, so no need to re-check that here.
        super().check_valid_op(op)

        if (
            torch.Tag.core not in op.tags  # type: ignore[attr-defined]
            and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
        ):
            # NOTE(qihan): whether view_copy operators are marked as canonical
            # is still under discussion.
            raise SpecViolationError(
                f"Operator {op.__module__}.{op.__name__} is not ATen canonical."
            )

    @compatibility(is_backward_compatible=False)
    def check_valid(self, gm: GraphModule) -> None:
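        """Run the base verifier, then the ATen-specific contiguity check."""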
        super().check_valid(gm)
        _check_tensors_are_contiguous(gm)