blob: 20f1393f7fa8b5b1475bbc0af1102550ae5733b6 [file] [log] [blame]
import itertools
import operator
from collections.abc import Iterable
import torch
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
ALLOWED_META_KEYS = {"spec", "stack_trace"}
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class SpecViolationError(Exception):
pass
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_functional(op: OpOverload) -> bool:
return not op._schema.is_mutable
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
def _check_is_fake_tensor(val):
if isinstance(val, FakeTensor):
return True
if isinstance(val, Iterable):
return all(_check_is_fake_tensor(x) for x in val)
return False
val = node.meta.get("val")
if not _check_is_fake_tensor(val):
raise SpecViolationError("Node.meta {} is missing val field.".format(node.name))
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid(gm: GraphModule) -> None: # noqa: C901
for node in gm.graph.nodes:
# TODO(T140410192): should have fake tensor for all dialects
if node.op == "call_method":
# what is delegates in ATen dialect?
raise SpecViolationError(
"call_module can only be used for delegates, got a object of class '{}.{}' instead".format(
type(node.args[0]).__module__, type(node.args[0]).__name__
),
)
if node.op == "call_module":
raise SpecViolationError(
"call_module is not valid: got a class '{}' ".format(node.target),
)
if node.op == "call_function":
_check_has_fake_tensor(node)
op_name = (
node.target.name
if hasattr(node.target, "name")
else node.target.__name__
)
is_builtin_func = (node.target == operator.getitem or node.target.__name__ in [
'while_loop',
'cond',
])
if not isinstance(node.target, OpOverload) and not is_builtin_func:
raise SpecViolationError(
"Operator '{}' is not a registered Op".format(op_name),
)
# All ops functional
# TODO(qihan): use node.target.is_functional: when PR/83134 lands
if not is_builtin_func and not is_functional(node.target):
raise SpecViolationError(
"operator '{}' is not functional".format(op_name),
)
if isinstance(node.target, OpOverload):
stacktrace = node.meta.get("stack_trace")
if stacktrace is None:
raise SpecViolationError(
"node of name '{}' for operator '{}' is missing stackstrace".format(
node.name, op_name
),
)
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid(gm: GraphModule) -> bool:
try:
check_valid(gm)
return True
except SpecViolationError:
return False
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid_aten_dialect(gm: GraphModule) -> None:
"""Raises exception if gm is not in aten dialect.
Args:
gm: GraphModule
"""
# need to be first valid
check_valid(gm)
# Operators be aten cannonical
for n in gm.graph.nodes:
if n.op == "call_function" and isinstance(n.target, OpOverload):
if (
torch.Tag.core not in n.target.tags # type: ignore[attr-defined]
and torch.Tag.view_copy not in n.target.tags # type: ignore[attr-defined]
):
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
# discussion.
raise SpecViolationError(
"Operator {}.{} is not Aten Canonical.".format(
n.target.__module__, n.target.__name__
)
)
# Tensors be of contiguous format
for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
if isinstance(param, torch.Tensor):
if not param.is_contiguous():
raise SpecViolationError(
f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
)
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid_aten_dialect(gm: GraphModule) -> bool:
try:
check_valid_aten_dialect(gm)
return True
except SpecViolationError:
return False
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid_edge_dialect(gm: GraphModule) -> None:
check_valid_aten_dialect(gm)
# Additionally, edge dialect's operator must have same input dtype
for n in gm.graph.nodes:
if n.op == "call_function" and isinstance(n.target, OpOverload):
_check_has_fake_tensor(n)
dtypes = set()
for arg in n.args:
if isinstance(arg, torch.Tensor):
dtypes.add(arg.dtype)
if isinstance(arg, torch.fx.Node):
dtypes.add(arg.meta["val"].dtype)
if len(dtypes) > 1:
raise SpecViolationError(
"Operators of Edge dialect in should work on tensors of same dtype"
)
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid_edge_dialect(gm: GraphModule) -> bool:
try:
check_valid_edge_dialect(gm)
return True
except SpecViolationError:
return False