blob: 2f9890fb838bcba5a352b52089f0b604d7daddd0 [file] [log] [blame]
import dataclasses
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
import torch
from torch._export import ExportedProgram
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import (
_register_pytree_node,
Context,
DumpableContext,
FlattenFunc,
FromDumpableContextFn,
ToDumpableContextFn,
tree_flatten,
UnflattenFunc,
)
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
@torch._dynamo.disable
def _check_input_constraints_pre_hook(self, *args, **kwargs):
flat_args, _ = tree_flatten(args)
return _check_input_constraints_for_graph(
[node for node in self.graph.nodes if node.op == "placeholder"],
flat_args,
self.range_constraints,
)
def _check_input_constraints_for_graph(
input_placeholders: List[torch.fx.Node], args, range_constraints
):
def check(cond, msg):
if not cond:
# TODO(avik): maybe add more context, e.g., graph signature
raise RuntimeError(msg)
import sympy
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
_convert_range_to_int,
)
check(
len(args) == len(input_placeholders),
"Unexpected number of inputs "
f"(expected {len(input_placeholders)}, got {len(args)})",
)
# NOTE: export already guarantees that the same symbol is used in metadata
# for all InputDims related by equality constraints, so we can just unify
# symbols with given input dimension values to check equality constraints.
unification_map: "Dict[sympy.Symbol, Any]" = {}
for arg, node in zip(args, input_placeholders):
node_val = node.meta["val"]
if isinstance(node_val, FakeTensor):
check(
isinstance(arg, torch.Tensor),
f"Expected input {node.name} to be a tensor, but got {type(arg)}",
)
check(
len(node_val.shape) == len(arg.shape),
f"Unexpected number of dimensions in input {node.name}.shape "
f"(expected {node_val.shape}, got {arg.shape})",
)
for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
if isinstance(node_dim, torch.SymInt):
if node_dim.node.expr in unification_map:
existing_dim = unification_map[node_dim.node.expr]
check(
arg_dim == existing_dim,
f"Expected input {node.name}.shape[{j}] to be equal to "
f"{existing_dim}, but got {arg_dim}",
)
else:
unification_map[node_dim.node.expr] = arg_dim
if node_dim.node.expr in range_constraints:
min_val, max_val = _convert_range_to_int(
range_constraints[node_dim.node.expr]
)
# NOTE: we allow dimensions to be 0/1 at runtime
if min_val > 2:
check(
arg_dim >= min_val,
f"Expected input {node.name}.shape[{j}] to be >= "
f"{min_val}, but got {arg_dim}",
)
if max_val < math.inf:
check(
arg_dim <= max_val,
f"Expected input {node.name}.shape[{j}] to be <= "
f"{max_val}, but got {arg_dim}",
)
else:
check(
arg_dim == node_dim,
f"Expected input {node.name}.shape[{j}] to be equal to "
f"{node_dim}, but got {arg_dim}",
)
elif isinstance(node_val, (int, float, str)):
check(
type(arg) == type(node_val) and arg == node_val,
f"Expected input {node.name} to be equal to {node_val}, but got {arg}",
)
def register_dataclass_as_pytree_node(
cls: Type[Any],
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
*,
serialized_type_name: Optional[str] = None,
to_dumpable_context: Optional[ToDumpableContextFn] = None,
from_dumpable_context: Optional[FromDumpableContextFn] = None,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(
cls
), f"Only dataclasses can be registered with this function: {cls}"
serialized_type = f"{cls.__module__}.{cls.__qualname__}"
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
flattened = []
flat_names = []
none_names = []
for f in dataclasses.fields(obj):
name, val = f.name, getattr(obj, f.name)
if val is not None or return_none_fields:
flattened.append(val)
flat_names.append(name)
else:
none_names.append(name)
return flattened, (cls, flat_names, none_names)
def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
typ, flat_names, none_names = context
return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})
def default_to_dumpable_context(context: Context) -> DumpableContext:
return (serialized_type, context[1], context[2])
def default_from_dumpable_context(dumpable_context: DumpableContext) -> Context:
return (
SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[dumpable_context[0]],
dumpable_context[1],
dumpable_context[2],
)
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
if (to_dumpable_context is None) ^ (from_dumpable_context is None):
raise ValueError(
f"Both to_dumpable_context and from_dumpable_context for {cls} must "
"be None or registered."
)
to_dumpable_context = (
to_dumpable_context
if to_dumpable_context is not None
else default_to_dumpable_context
)
from_dumpable_context = (
from_dumpable_context
if from_dumpable_context is not None
else default_from_dumpable_context
)
_register_pytree_node(
cls,
flatten_fn,
unflatten_fn,
serialized_type_name=serialized_type_name,
to_dumpable_context=to_dumpable_context,
from_dumpable_context=from_dumpable_context,
)
def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
"""
Checks if the given node is a parameter within the exported program
"""
return node.name in program.graph_signature.inputs_to_parameters
def get_param(
program: ExportedProgram,
node: torch.fx.Node,
) -> Optional[torch.nn.Parameter]:
"""
Returns the parameter associated with the given node in the exported program.
Returns None if the node is not a parameter within the exported program
"""
if is_param(program, node):
parameter_name = program.graph_signature.inputs_to_parameters[node.name]
return program.state_dict[parameter_name]
return None
def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
"""
Checks if the given node is a buffer within the exported program
"""
return node.name in program.graph_signature.inputs_to_buffers
def get_buffer(
program: ExportedProgram,
node: torch.fx.Node,
) -> Optional[torch.Tensor]:
"""
Returns the buffer associated with the given node in the exported program.
Returns None if the node is not a buffer within the exported program
"""
if is_buffer(program, node):
buffer_name = program.graph_signature.inputs_to_buffers[node.name]
return program.state_dict[buffer_name]
return None