# blob: 4f35bb4888a28c401f1e320283829d4f85289442 [file] [log] [blame]
import itertools
import operator
from collections.abc import Iterable
from typing import Set
import torch
from functorch.experimental import control_flow
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
from torch.fx._compatibility import compatibility
# Metadata keys that Verifier.check_valid requires to be present (non-None)
# in ``node.meta`` for every call_function node whose target is an OpOverload.
PRESERVED_META_KEYS: Set[str] = {"val", "stack_trace"}
@compatibility(is_backward_compatible=False)
class SpecViolationError(Exception):
    """Raised when a graph module or operator violates the verifier's spec."""
@compatibility(is_backward_compatible=False)
def is_functional(op: OpOverload) -> bool:
    """Report whether ``op`` is functional, i.e. its schema is not mutable.

    Args:
        op: Operator overload whose ``_schema`` is inspected.

    Returns:
        ``True`` when ``op._schema.is_mutable`` is falsy, ``False`` otherwise.
    """
    schema = op._schema
    mutates = schema.is_mutable
    return not mutates
@compatibility(is_backward_compatible=False)
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    """Ensure ``node.meta["val"]`` is made up solely of ``FakeTensor`` objects.

    The value may be a single ``FakeTensor`` or an arbitrarily nested iterable
    whose leaves are all ``FakeTensor`` instances.

    Args:
        node: The FX node whose ``meta`` dictionary is inspected.

    Raises:
        SpecViolationError: If ``"val"`` is absent, ``None``, or contains
            anything other than ``FakeTensor`` leaves.
    """

    def _check_is_fake_tensor(val) -> bool:
        if isinstance(val, FakeTensor):
            return True
        # A str is Iterable and a one-character str iterates to itself, so
        # recursing into it would never terminate (RecursionError). A string
        # can never be a fake tensor, so reject it up front.
        if isinstance(val, str):
            return False
        if isinstance(val, Iterable):
            return all(_check_is_fake_tensor(x) for x in val)
        return False

    val = node.meta.get("val", None)
    if val is None or not _check_is_fake_tensor(val):
        raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
@compatibility(is_backward_compatible=False)
def _check_tensors_are_contiguous(gm: GraphModule) -> None:
    """Require every parameter and buffer tensor of ``gm`` to be contiguous.

    Args:
        gm: Graph module whose named parameters and named buffers are checked.

    Raises:
        SpecViolationError: On the first tensor for which ``is_contiguous()``
            returns ``False``.
    """
    # Parameters are visited first, then buffers.
    named_tensors = itertools.chain(gm.named_parameters(), gm.named_buffers())
    for tensor_name, tensor in named_tensors:
        if not isinstance(tensor, torch.Tensor) or tensor.is_contiguous():
            continue
        raise SpecViolationError(
            f"Tensors in Aten dialect must be contiguous, {tensor_name} is not contiguous"
        )
@compatibility(is_backward_compatible=False)
class Verifier:
    """Base verifier that checks the nodes of a ``GraphModule``.

    Instances are callable: ``verifier(gm)`` is equivalent to
    ``verifier.check_valid(gm)`` and raises ``SpecViolationError`` on the
    first violation found.
    """

    def __call__(self, gm: GraphModule) -> None:
        """Run :meth:`check_valid` on ``gm``."""
        self.check_valid(gm)

    @staticmethod
    def _op_display_name(op) -> str:
        """Return a printable name for ``op`` for use in error messages.

        ``OpOverload`` exposes ``name`` as a method rather than an attribute,
        so a callable ``name`` is invoked. Objects with neither ``name`` nor
        ``__name__`` fall back to ``repr(op)`` instead of raising
        ``AttributeError``.
        """
        name = getattr(op, "name", None)
        if callable(name):
            try:
                name = name()
            except TypeError:
                # ``name`` needs arguments we cannot supply; try __name__.
                name = None
        if name is None:
            name = getattr(op, "__name__", None)
        return repr(op) if name is None else str(name)

    @compatibility(is_backward_compatible=False)
    def valid_builtin_funcs(self):
        """Return the call_function targets exempt from :meth:`check_valid_op`."""
        return [
            operator.getitem,
            control_flow.cond,
            torch.ops.map_impl,
        ]

    @compatibility(is_backward_compatible=False)
    def check_valid_op(self, op):
        """Validate a single call_function target.

        Args:
            op: The node target to validate.

        Raises:
            SpecViolationError: If ``op`` is not an ``OpOverload`` or is not
                functional according to ``is_functional``.
        """
        op_name = self._op_display_name(op)
        if not isinstance(op, OpOverload):
            raise SpecViolationError(
                f"Operator '{op_name}' is not a registered Op",
            )
        # All ops functional
        if not is_functional(op):
            raise SpecViolationError(
                f"operator '{op_name}' is not functional"
            )

    @compatibility(is_backward_compatible=False)
    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
        """Validate every node in ``gm.graph``.

        Args:
            gm: The graph module to verify.

        Raises:
            SpecViolationError: If a node is ``call_module``/``call_method``,
                a ``call_function`` node lacks fake-tensor ``val`` metadata,
                its target fails :meth:`check_valid_op`, or an ``OpOverload``
                node is missing a key from ``PRESERVED_META_KEYS``.
        """
        # Resolved on the first call_function node and then reused, so the
        # list is not rebuilt per node and is never built for graphs that
        # contain no call_function nodes.
        builtin_funcs = None
        for node in gm.graph.nodes:
            # TODO(T140410192): should have fake tensor for all dialects
            if node.op in {"call_module", "call_method"}:
                # Report the opcode actually seen rather than always
                # claiming call_module.
                raise SpecViolationError(
                    f"{node.op} is not valid: got a class '{node.target}' ",
                )
            if node.op != "call_function":
                continue

            _check_has_fake_tensor(node)
            if builtin_funcs is None:
                builtin_funcs = self.valid_builtin_funcs()
            if node.target not in builtin_funcs:
                self.check_valid_op(node.target)
            if isinstance(node.target, OpOverload):
                # Check preserved metadata
                for meta in PRESERVED_META_KEYS:
                    if node.meta.get(meta, None) is None:
                        raise SpecViolationError(
                            f"node {node} is missing metadata {meta}"
                        )

    @compatibility(is_backward_compatible=False)
    def is_valid(self, gm: GraphModule) -> bool:
        """Return ``True`` if :meth:`check_valid` passes for ``gm``.

        Only ``SpecViolationError`` is converted to ``False``; any other
        exception propagates to the caller.
        """
        try:
            self.check_valid(gm)
        except SpecViolationError:
            return False
        else:
            return True
class ATenDialectVerifier(Verifier):
    """Verifier for the ATen dialect.

    Extends the base checks by requiring each operator to be tagged ``core``
    or ``view_copy`` and every parameter/buffer tensor to be contiguous.
    """

    @compatibility(is_backward_compatible=False)
    def check_valid_op(self, op) -> None:
        """Validate ``op`` against the base rules plus ATen canonicality.

        Args:
            op: The node target to validate.

        Raises:
            SpecViolationError: If the base check fails, or if ``op`` carries
                neither ``torch.Tag.core`` nor ``torch.Tag.view_copy``.
        """
        # The base check raises for anything that is not an OpOverload, so no
        # second isinstance check is needed before reading ``op.tags``.
        super().check_valid_op(op)
        if (
            torch.Tag.core not in op.tags  # type: ignore[attr-defined]
            and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
        ):
            # NOTE(qihan): whether view_copy operators are marked as canonical is still under
            #            discussion.
            raise SpecViolationError(
                f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
            )

    @compatibility(is_backward_compatible=False)
    def check_valid(self, gm: GraphModule) -> None:
        """Run the base node checks, then require contiguous tensors.

        Args:
            gm: The graph module to verify.

        Raises:
            SpecViolationError: If any base check fails or any parameter or
                buffer of ``gm`` is not contiguous.
        """
        super().check_valid(gm)
        _check_tensors_are_contiguous(gm)