import itertools
import operator
from collections.abc import Iterable

import torch
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule

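# node.meta keys allowed by the dialect spec; not referenced directly by the
# checks in this file.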
ALLOWED_META_KEYS = {"spec", "stack_trace"}


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class SpecViolationError(Exception):
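    """Raised when a GraphModule violates the dialect spec being checked."""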
    pass


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_functional(op: OpOverload) -> bool:
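    """Return True if the operator's schema is not mutable, i.e. the op does not mutate its inputs."""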
    return not op._schema.is_mutable


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
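    """Raise SpecViolationError unless node.meta["val"] is a FakeTensor (or an iterable of FakeTensors)."""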
    def _check_is_fake_tensor(val):
        if isinstance(val, FakeTensor):
            return True
        if isinstance(val, Iterable):
            return all(_check_is_fake_tensor(x) for x in val)
        return False

    val = node.meta.get("val")
    if not _check_is_fake_tensor(val):
        raise SpecViolationError(
            "Node '{}' is missing a fake tensor 'val' field in node.meta.".format(node.name)
        )

@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid(gm: GraphModule) -> None:  # noqa: C901
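    """Check that a GraphModule satisfies the base dialect spec.

    Raises SpecViolationError if the graph contains call_method or call_module
    nodes, calls an operator that is not a registered OpOverload (and not a
    builtin such as operator.getitem, cond, or while_loop), calls a
    non-functional operator, lacks fake tensor metadata on a call_function
    node, or is missing a stack trace on an OpOverload call.
    """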

    for node in gm.graph.nodes:
        # TODO(T140410192): should have fake tensor for all dialects
        if node.op == "call_method":
            raise SpecViolationError(
                "call_method is not valid: got a method call on an object of class '{}.{}'".format(
                    type(node.args[0]).__module__, type(node.args[0]).__name__
                ),
            )

        if node.op == "call_module":
            raise SpecViolationError(
                "call_module is not valid: got target '{}'".format(node.target),
            )

        if node.op == "call_function":
            _check_has_fake_tensor(node)
            op_name = (
                node.target.name
                if hasattr(node.target, "name")
                else node.target.__name__
            )
            is_builtin_func = (
                node.target == operator.getitem
                or node.target.__name__ in ["while_loop", "cond"]
            )
            if not isinstance(node.target, OpOverload) and not is_builtin_func:
                raise SpecViolationError(
                    "Operator '{}' is not a registered Op".format(op_name),
                )
            # All operators must be functional.
            # TODO(qihan): use node.target.is_functional when PR #83134 lands.
            if not is_builtin_func and not is_functional(node.target):
                raise SpecViolationError(
                    "Operator '{}' is not functional".format(op_name),
                )

            if isinstance(node.target, OpOverload):
                stacktrace = node.meta.get("stack_trace")

                if stacktrace is None:
                    raise SpecViolationError(
                        "Node '{}' for operator '{}' is missing a stack trace".format(
                            node.name, op_name
                        ),
                    )


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid(gm: GraphModule) -> bool:
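    """Return True if gm passes check_valid, False otherwise."""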
    try:
        check_valid(gm)
        return True
    except SpecViolationError:
        return False


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid_aten_dialect(gm: GraphModule) -> None:
 |     """Raises exception if gm is not in aten dialect. | 
 |  | 
 |     Args: | 
 |         gm: GraphModule | 
 |     """ | 
 |     # need to be first valid | 
 |     check_valid(gm) | 
 |     # Operators be aten cannonical | 
    for n in gm.graph.nodes:
        if n.op == "call_function" and isinstance(n.target, OpOverload):
            if (
                torch.Tag.core not in n.target.tags  # type: ignore[attr-defined]
                and torch.Tag.view_copy not in n.target.tags  # type: ignore[attr-defined]
            ):
                # NOTE(qihan): whether view_copy operators are marked as canonical
                # is still under discussion.
                raise SpecViolationError(
                    "Operator {}.{} is not ATen canonical.".format(
                        n.target.__module__, n.target.__name__
                    )
                )

    # Tensors must be in contiguous memory format.
    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
        if isinstance(param, torch.Tensor):
            if not param.is_contiguous():
                raise SpecViolationError(
                    f"Tensors in ATen dialect must be contiguous, {name} is not contiguous"
                )


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid_aten_dialect(gm: GraphModule) -> bool:
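    """Return True if gm passes check_valid_aten_dialect, False otherwise."""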
    try:
        check_valid_aten_dialect(gm)
        return True
    except SpecViolationError:
        return False


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid_edge_dialect(gm: GraphModule) -> None:
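    """Raise SpecViolationError if gm is not in the Edge dialect.

    The graph must first be valid ATen dialect; additionally, every OpOverload
    call must take tensor inputs that all share the same dtype.
    """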
    check_valid_aten_dialect(gm)

    # Additionally, operators in the Edge dialect must have inputs of the same dtype.
    for n in gm.graph.nodes:
        if n.op == "call_function" and isinstance(n.target, OpOverload):
            _check_has_fake_tensor(n)
            dtypes = set()
            for arg in n.args:
                if isinstance(arg, torch.Tensor):
                    dtypes.add(arg.dtype)
                if isinstance(arg, torch.fx.Node):
                    dtypes.add(arg.meta["val"].dtype)
            if len(dtypes) > 1:
                raise SpecViolationError(
                    "Operators in the Edge dialect should work on tensors of the same dtype"
                )


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid_edge_dialect(gm: GraphModule) -> bool:
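    """Return True if gm passes check_valid_edge_dialect, False otherwise."""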
    try:
        check_valid_edge_dialect(gm)
        return True
    except SpecViolationError:
        return False
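

# Illustrative usage sketch (assumption: `gm` is a GraphModule produced elsewhere
# by a tracing/export pipeline that populates node.meta["val"] with FakeTensors;
# it is not defined in this module):
#
#     if not is_valid_edge_dialect(gm):
#         check_valid_edge_dialect(gm)  # raises SpecViolationError describing the violation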