blob: 2a27b62919eec68149a2a8cbfe1392355f5b2c47 [file] [log] [blame]
import operator
from collections.abc import Iterable
from typing import Any, final, List, Set, Tuple, Type
import torch
from torch._ops import HigherOrderOperator, OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule
from torch.fx.experimental.symbolic_shapes import SymBool, SymFloat, SymInt
# Node.meta keys that every OpOverload call_function node must carry with a
# non-None value (enforced by Verifier.check_valid).
PRESERVED_META_KEYS: Set[str] = {
    "val",
    "stack_trace",
    "source_fn_stack",
}
class SpecViolationError(Exception):
    """Raised when a graph or its signature violates a dialect specification."""
    pass
def is_functional(op: OpOverload) -> bool:
    """Return True iff *op*'s schema declares no argument mutation."""
    mutates_args = op._schema.is_mutable
    return not mutates_args
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
    """Backwards-compatibility shim that delegates to :func:`_check_val`."""
    # TODO(angelayi): remove this in favor of _check_val
    _check_val(node)
def _check_val(node: torch.fx.Node) -> None:
def _check_correct_val(val):
if val is None:
return True
elif isinstance(val, (int, bool, str, float)):
return True
elif isinstance(val, (torch.memory_format, torch.dtype, torch.device, torch.layout)):
return True
elif isinstance(val, (FakeTensor, torch.Tensor)):
return True
elif isinstance(val, (SymInt, SymFloat, SymBool)):
return True
elif isinstance(val, Iterable):
return all(_check_correct_val(x) for x in val)
return False
if "val" not in node.meta:
raise SpecViolationError(f"Node.meta {node.name} is missing val field.")
val = node.meta["val"]
if not _check_correct_val(val):
raise SpecViolationError(f"Node.meta {node.name} has invalid val field {val}")
class Verifier:
def __call__(self, gm: GraphModule) -> None:
self.check_valid(gm)
def allowed_builtin_ops(self) -> List:
return [operator.getitem]
def allowed_op_types(self) -> Tuple[Type[Any], ...]:
return (OpOverload, HigherOrderOperator)
def check_valid_op(self, op) -> None:
if op not in self.allowed_builtin_ops():
if not isinstance(op, self.allowed_op_types()):
raise SpecViolationError(
f"Operator '{op}' is not an allowed operator type.\n"
f"Valid op types: {self.allowed_builtin_ops}"
)
if isinstance(op, OpOverload):
# All ops functional
if not is_functional(op):
raise SpecViolationError(
f"operator '{op}' is not functional"
)
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
return (torch.fx.GraphModule,)
def check_additional(self, gm: GraphModule) -> None:
"""
Additional checks that are specific to some dialects.
"""
pass
@final
def check_valid(self, gm: GraphModule) -> None: # noqa: C901
gm.graph.lint()
if object in self.allowed_op_types():
raise SpecViolationError(
"'object' is too generic to be in the list of allowed op types"
)
if object in self.allowed_getattr_types():
raise SpecViolationError(
"'object' is too generic to be in the list of allowed getattr types"
)
for mod in gm.modules():
if not isinstance(mod, torch.fx.GraphModule):
continue
for node in mod.graph.nodes:
# TODO(T140410192): should have fake tensor for all dialects
if node.op in {"call_module", "call_method"}:
raise SpecViolationError(
f"call_module is not valid: got a class '{node.target}' ",
)
elif node.op == "call_function":
_check_val(node)
self.check_valid_op(node.target)
if isinstance(node.target, OpOverload):
# Check preserved metadata
for meta in PRESERVED_META_KEYS:
if node.meta.get(meta, None) is None:
raise SpecViolationError(
f"node {node} is missing metadata {meta}"
)
elif node.op == "get_attr":
if not isinstance(node.target, str):
raise SpecViolationError(
f"Expected get_attr target to be string, but got {type(node.target)}"
)
attr = getattr(mod, node.target)
if not isinstance(attr, self.allowed_getattr_types()):
raise SpecViolationError(
f"Invalid get_attr type {type(attr)}. \n"
f"Valid get_attr types: {self.allowed_getattr_types}"
)
elif node.op == "placeholder":
_check_val(node)
self.check_additional(gm)
def is_valid(self, gm: GraphModule) -> bool:
try:
self.check_valid(gm)
return True
except SpecViolationError:
return False
class ATenDialectVerifier(Verifier):
    """Verifier for the ATen dialect: ops must additionally be ATen-canonical."""

    def check_valid_op(self, op) -> None:
        """Run the base checks, then require a core or view_copy tag on OpOverloads."""
        super().check_valid_op(op)
        if not isinstance(op, OpOverload):
            return
        # NOTE(qihan): whether view_copy operators are marked as canonical is still under
        # discussion.
        is_canonical = (
            torch.Tag.core in op.tags or torch.Tag.view_copy in op.tags
        )
        if not is_canonical:
            raise SpecViolationError(
                f"Operator {op.__module__}.{op.__name__} is not Aten Canonical."
            )
def verify_exported_program_signature(exported_program) -> None:
    """Check that an ExportedProgram's graph_signature matches its graph.

    Verifies that:
      * every name mentioned in the signature refers to a real placeholder
        or output node of the graph,
      * parameters and buffers named by the signature exist in the state
        dict (and parameters are torch.nn.Parameter instances),
      * placeholder nodes appear in signature order (parameters, then
        buffers, then user inputs),
      * output nodes appear in signature order (mutated buffers, then user
        outputs; gradient outputs are counted but not order-checked).

    Raises:
        SpecViolationError: on any mismatch.
    """
    gs = exported_program.graph_signature

    bs_grad_to_param = {}
    bs_grad_to_user_inputs = {}
    if gs.backward_signature is not None:
        bs_grad_to_param = gs.backward_signature.gradients_to_parameters
        bs_grad_to_user_inputs = gs.backward_signature.gradients_to_user_inputs

    # Collect graph-level input/output node names once.
    # (Bug fix: the old code built this list with a comprehension and then
    # re-appended the leading placeholders in a second loop, producing
    # duplicate entries.)
    input_node_names = [
        node.name
        for node in exported_program.graph.nodes
        if node.op == "placeholder"
    ]
    output_node = list(exported_program.graph.nodes)[-1]
    assert output_node.op == "output"
    output_node_names = [node.name for node in output_node.args[0]]

    def check_exists(node_list, container):
        # Every name the signature mentions must exist in the graph.
        for node in node_list:
            if node not in container:
                raise SpecViolationError(
                    f"Node {node} found in the signature is not in the graph."
                )

    check_exists(gs.user_inputs, input_node_names)
    check_exists(gs.user_outputs, output_node_names)
    check_exists(gs.inputs_to_parameters.keys(), input_node_names)
    check_exists(gs.inputs_to_parameters.values(), gs.parameters)
    check_exists(gs.inputs_to_buffers.keys(), input_node_names)
    check_exists(gs.inputs_to_buffers.values(), gs.buffers)
    check_exists(gs.buffers_to_mutate.keys(), output_node_names)
    check_exists(gs.buffers_to_mutate.values(), gs.buffers)
    check_exists(bs_grad_to_param.keys(), output_node_names)
    check_exists(bs_grad_to_param.values(), gs.parameters)
    check_exists(bs_grad_to_user_inputs.keys(), output_node_names)
    check_exists(bs_grad_to_user_inputs.values(), gs.user_inputs)

    # Check parameters exist in the state dict with the right type.
    for param in gs.parameters:
        if param not in exported_program.state_dict:
            raise SpecViolationError(
                f"Parameter {param} is not in the state dict."
            )
        if not isinstance(exported_program.state_dict[param], torch.nn.Parameter):
            raise SpecViolationError(
                f"State dict entry for parameter {param} is not an instance of torch.nn.Parameter."
            )

    # Check buffers exist in the state dict.
    for buffer in gs.buffers:
        if buffer not in exported_program.state_dict:
            raise SpecViolationError(
                f"Buffer {buffer} is not in the state dict."
            )

    # Check inputs: placeholders must be exactly
    # [parameters..., buffers..., user inputs...], in signature order.
    placeholder_nodes = input_node_names
    total_gs_placeholders = (
        len(gs.inputs_to_parameters) + len(gs.inputs_to_buffers) + len(gs.user_inputs)
    )
    if len(placeholder_nodes) != total_gs_placeholders:
        raise SpecViolationError(
            f"Number of placeholder nodes {len(placeholder_nodes)} is different "
            "than the number of inputs specified by the graph signature: \n"
            f"Number of parameters: {len(gs.inputs_to_parameters)}. \n"
            f"Number of buffers: {len(gs.inputs_to_buffers)}. \n"
            f"Number of user inputs: {len(gs.user_inputs)}. \n"
        )

    parameter_nodes = placeholder_nodes[:len(gs.parameters)]
    buffer_nodes = placeholder_nodes[len(gs.parameters):len(gs.parameters) + len(gs.buffers)]
    user_input_nodes = placeholder_nodes[len(gs.parameters) + len(gs.buffers):]

    for param_node, param_name in zip(parameter_nodes, gs.parameters):
        if (
            param_node not in gs.inputs_to_parameters or
            gs.inputs_to_parameters[param_node] != param_name
        ):
            raise SpecViolationError(
                f"Parameter input {param_node} is not in the correct "
                "order or is not found in the exported program's parameter list. \n"
                f"List of parameters, in order: {gs.parameters} \n"
                f"Parameter node to parameter name mapping: {gs.inputs_to_parameters} \n"
            )

    for buffer_node, buffer_name in zip(buffer_nodes, gs.buffers):
        if (
            buffer_node not in gs.inputs_to_buffers or
            gs.inputs_to_buffers[buffer_node] != buffer_name
        ):
            raise SpecViolationError(
                f"Buffer input {buffer_node} is not in the correct "
                "order or is not found in the exported program's buffer list. \n"
                f"List of buffers, in order: {gs.buffers} \n"
                f"Buffer node to buffer name mapping: {gs.inputs_to_buffers} \n"
            )

    for user_input_node, user_input_name in zip(user_input_nodes, gs.user_inputs):
        if user_input_node != user_input_name:
            raise SpecViolationError(
                f"User input {user_input_node} is not in the correct "
                "order or is not found in the "
                # Bug fix: was gs.user_input, which doesn't exist and raised
                # AttributeError whenever this error path fired.
                f"exported program's user_input list: {gs.user_inputs}. "
            )

    # Check outputs: [mutated buffers..., user outputs..., gradients...].
    output_nodes = output_node_names
    total_gs_outputs = (
        len(gs.buffers_to_mutate) +
        len(gs.user_outputs) +
        len(bs_grad_to_param) +
        len(bs_grad_to_user_inputs)
    )
    if len(output_nodes) != total_gs_outputs:
        raise SpecViolationError(
            f"Number of output nodes {len(output_nodes)} is different "
            "than the number of outputs specified by the graph signature: \n"
            f"Number of mutated buffers: {len(gs.buffers_to_mutate)}. \n"
            f"Number of user outputs: {len(gs.user_outputs)}. \n"
            # Bug fix: gradient outputs are part of the expected total, so
            # report their counts as well.
            f"Number of parameter gradients: {len(bs_grad_to_param)}. \n"
            f"Number of user input gradients: {len(bs_grad_to_user_inputs)}. \n"
        )

    buffer_mutate_nodes = output_nodes[:len(gs.buffers_to_mutate)]
    user_output_nodes = output_nodes[len(gs.buffers_to_mutate):len(gs.user_outputs) + len(gs.buffers_to_mutate)]

    for buffer_node in buffer_mutate_nodes:
        if (
            buffer_node not in gs.buffers_to_mutate or
            gs.buffers_to_mutate[buffer_node] not in gs.buffers
        ):
            raise SpecViolationError(
                f"Buffer output {buffer_node} is not in buffer mutation dictionary "
                "or, it does not point to a buffer that exists. \n"
                f"Dict of buffers that are mutated, in order: {gs.buffers_to_mutate} \n"
                f"Buffer nodes available: {gs.buffers} \n"
            )

    for user_output_node, user_output_name in zip(user_output_nodes, gs.user_outputs):
        if user_output_node != user_output_name:
            raise SpecViolationError(
                f"User output {user_output_node} is not in the correct "
                "order or is not found in the "
                # Bug fix: was gs.user_output, which doesn't exist and raised
                # AttributeError whenever this error path fired.
                f"exported program's user_output list: {gs.user_outputs}. "
            )