blob: df0a3a1c01e71e3538bd98a1578b53523c1fad58 [file] [log] [blame]
import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import sys
import types
import warnings
from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Set, Union
np: Optional[types.ModuleType] = None
try:
import numpy as np
except ModuleNotFoundError:
pass
import torch
import torch._functorch.deprecated as deprecated_func
from torch.fx._symbolic_trace import is_fx_tracing
from . import config
from .external_utils import is_compiling
from .utils import is_safe_constant, NP_SUPPORTED_MODULES
"""
A note on allowed functions:
Dynamo consults this file to determine if a particular function/module
is allowed to appear as a node in its fx output.
If a function is disallowed, it may either be traced-through, or skipped.
Trace-through means dynamo will continue to trace the interior code for
the function/module rather than stopping at its boundary and recording it
as a node in the fx graph. Whether tracing through or allowing, the functionality
of the function/module is part of the dynamo graph. Caveat: if tracing through,
any interior operation could trigger its own graph-break.
Skips are determined by torch/_dynamo/skipfiles.py - see "a note on
skipfiles" there.
"""
class FunctionIdSet:
    """
    Track a set of `id()`s of objects which are either allowed or not
    allowed to go into the generated FX graph. Use to test for torch.*,
    numpy.*, builtins.*, etc.

    Support user modification to permit customization of what can be
    added to the graph and what will cause a graph break.

    The contents are computed on first use by calling ``lazy_initializer``,
    which must return either a ``set`` of ids or a ``dict`` mapping ids to
    human-readable names. Typically used as a decorator on that initializer.
    """

    # Tracked ids; stays None until the first lookup triggers initialization.
    function_ids: Optional[Set[int]] = None
    # id -> name; only populated when the initializer returned a dict.
    function_names: Optional[Dict[int, str]] = None

    def __init__(self, lazy_initializer: Callable[[], Union[Dict[int, str], Set[int]]]):
        self.lazy_initializer = lazy_initializer

    def __call__(self) -> Set[int]:
        """Return the tracked id set, running the initializer on first use only."""
        if self.function_ids is None:
            value = self.lazy_initializer()
            if isinstance(value, dict):
                self.function_ids = set(value.keys())
                self.function_names = value
            else:
                assert isinstance(value, set)
                self.function_ids = value
        return self.function_ids

    def get_name(self, idx: int, default: str):
        """
        Return the recorded name for ``idx``, or ``default`` if there is none.

        Set-backed instances carry no names at all, so they always return
        ``default`` instead of failing.
        """
        self()  # lazy init
        if self.function_names is None:
            return default
        return self.function_names.get(idx, default)

    def add(self, idx: int):
        """Start tracking ``idx`` (no name is recorded for it)."""
        function_ids = self()  # lazy init
        function_ids.add(idx)

    def remove(self, idx: int):
        """Stop tracking ``idx``; silently ignores ids that are not tracked."""
        self().discard(idx)

    def __contains__(self, idx: int):
        return idx in self()
@FunctionIdSet
def _disallowed_function_ids() -> Set[int]:
    """
    Build the set of ids that must never be treated as allowed graph nodes.

    The set consists of:
    * a fixed list of constants and functions;
    * every torch dtype object;
    * every legacy typed-storage class exposed on ``torch``;
    * when ``torch.distributed`` is available, the c10d ops it marks as
      unsupported by Dynamo.
    """
    blocked: List[Any] = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.cuda.set_device,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.jit.isinstance,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
        torch.Tensor.__init__,
        torch.resize_as_,
        torch._tensor._convert,
    ]

    # One snapshot of the torch namespace, scanned twice below.
    torch_members = list(torch.__dict__.values())
    dtype_cls = type(torch.float32)
    storage_meta = type(torch.FloatStorage)

    # Every dtype object exposed on torch (torch.float32, torch.int64, ...).
    blocked.extend(member for member in torch_members if isinstance(member, dtype_cls))
    # Every legacy storage class exposed on torch (torch.FloatStorage, ...).
    blocked.extend(
        member for member in torch_members if isinstance(member, storage_meta)
    )

    # Distributed APIs don't work well with torch.compile.
    if torch.distributed.is_available():
        c10d = torch.distributed.distributed_c10d
        blocked.extend(c10d.dynamo_unsupported_distributed_c10d_ops)

    return set(map(id, blocked))
# We are in the process of refactoring and moving the following functions to test_trace_rules.py.
# If you change any of the following functions, please update them there as well.
# If you are unsure how to update them, please contact @yanboliang.
@FunctionIdSet
def _allowed_function_ids() -> Dict[int, str]:
    """
    Walk torch.* (and math.*) and get the ids of all the stuff in it.

    Returns a dict that maps ``id(obj)`` to a dotted name. It covers every
    object that may appear directly as a node in Dynamo's FX graph. Entries
    come from:

    * a recursive walk of ``torch`` and ``math`` (see ``_find_torch_objects``);
    * the private functional-collective implementations in
      ``torch.distributed._functional_collectives_impl``, only when
      ``config.trace_distributed`` is set;
    * every ``torch.Tensor`` attribute that is a method or wrapper descriptor.

    Ids listed in ``_disallowed_function_ids`` are then removed. After that,
    ``is_fx_tracing`` and ``is_compiling`` are always added.

    Side effect: this installs a process-wide warnings filter. The filter
    ignores ``UserWarning`` from modules matching ``torch.distributed``. It is
    not removed here.
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    # id(obj) -> dotted name. This also acts as the visited set for the walk
    # below, so the first name under which an object is found is the one kept.
    torch_object_ids: Dict[int, str] = dict()

    def _is_allowed_module_prefix(obj):
        # True iff obj's defining module is torch/math or one of their
        # submodules, and that module is not under any disallowed prefix below.
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaque-ly in the graph.
        disallowed_modules = [
            "torch.optim.",
            "torch.utils._foreach_utils",  # omit the period so we match all the functions in this module
            "torch.utils._pytree",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
            "torch.distributed._tensor.",
            # Inline through the ActivationWrapper in
            # torch.distributed.algorithms._checkpoint.checkpoint_wrapper. This
            # nn module calls torch.utils.checkpoint internally. If Dynamo does
            # not trace this, AOT Autograd will try to trace this and can cause
            # issues observed in
            # https://github.com/pytorch/pytorch/issues/108269
            "torch.distributed.algorithms.",
        ]
        if config.trace_distributed:
            # With distributed tracing on, nothing under torch.distributed is
            # treated as an allowed graph node. NOTE(review): presumably this
            # is so Dynamo traces into it instead; confirm against config docs.
            disallowed_modules.append("torch.distributed.")
        # "torch." / "math." prefixes, used to match submodules.
        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False
        mod_name = module.__name__
        # Disallowed prefixes take precedence over the allowed ones.
        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False
        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        # Record `module` itself, then (recursively) the allowed objects in its
        # namespace. Modules whose name starts with any entry of
        # config.allowed_functions_module_string_ignorelist are skipped entirely.
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            # Skip anything already recorded. Modules are recorded before they
            # are recursed into, so this also stops the walk from revisiting a
            # module it has already entered.
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperator
                # first, decide they are safe, and then allow them into the graph).
                # So we exclude HigherOrderOperator from being a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`. None of the
                # objects in this tuple are recorded as allowed.
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    # Submodules: only allowed torch.* modules are recorded
                    # and recursed into.
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    # Non-module object defined in an allowed module.
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    # Object whose defining module cannot be determined and
                    # which is_safe_constant does not classify as a constant.
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    if config.trace_distributed:
        # Explicitly allow these private collective implementations. They are
        # keyed by id and named by their repr().
        from torch.distributed import _functional_collectives_impl as fci

        for f in [
            fci._all_gather_into_tensor,
            fci._all_reduce,
            fci._reduce_scatter_tensor,
            fci._all_reduce_coalesced,
            fci._all_gather_into_tensor_coalesced,
            fci._reduce_scatter_tensor_coalesced,
        ]:
            torch_object_ids[id(f)] = repr(f)

    # torch.Tensor.{fn}
    # Only attributes that are method or wrapper descriptors are recorded.
    # Plain Python functions defined on the class are not.
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(
            method, (types.MethodDescriptorType, types.WrapperDescriptorType)
        ):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    # Explicitly disallowed ids win over anything collected above.
    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    # Always allow these two helpers, named "<defining module>.<function name>".
    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids
@FunctionIdSet
def _allowed_user_defined_function_ids() -> Dict[int, str]:
    """
    Seed the user-defined allow-list; it starts out empty.

    Entries are expected to be added later through ``FunctionIdSet.add``.
    """
    return dict()
@FunctionIdSet
def _builtin_function_ids() -> Dict[int, str]:
    """
    Map the ids of builtin-like callables to their qualified names.

    Covers:
    * every public callable in ``builtins`` and in ``operator``;
    * ``itertools.chain`` and ``itertools.islice``;
    * ``typing.cast``;
    * ``functools.reduce``.

    When an id appears more than once, the later entry's name wins.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    # chain/islice live in itertools, not functools. Label them with their
    # real module so the recorded names are accurate.
    rv.update(
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv.update({id(cast): "typing.cast"})
    rv[id(functools.reduce)] = "functools.reduce"
    return rv
@FunctionIdSet
def _numpy_function_ids() -> Dict[int, str]:
    """
    Map the ids of callables in each supported numpy module to "<module>.<attr>".

    A callable is kept only when its ``__module__`` is that module's name, or
    when it has no (or an empty) ``__module__``. Callables merely re-exported
    from elsewhere are skipped.
    """
    result: Dict[int, str] = {}
    for np_mod in NP_SUPPORTED_MODULES:
        mod_name = np_mod.__name__
        for attr, val in np_mod.__dict__.items():
            if not callable(val):
                continue
            owner = getattr(val, "__module__", None) or mod_name
            if owner == mod_name:
                result[id(val)] = f"{mod_name}.{attr}"
    return result
@FunctionIdSet
def _builtin_constant_ids() -> Dict[int, str]:
    """
    Collect the public, non-callable members of ``builtins``.

    Each id is mapped to "builtins.<name>". Names starting with an underscore
    and all callables are skipped.
    """
    constants: Dict[int, str] = {}
    for attr, val in builtins.__dict__.items():
        if attr.startswith("_") or callable(val):
            continue
        constants[id(val)] = f"builtins.{attr}"
    return constants
# Root module name -> init functions deferred until an object from that module
# is first inspected (consumed by _maybe_init_lazy_module).
_lazy_module_init: Dict[str, List[Callable[[], None]]] = defaultdict(list)


def add_module_init_func(name: str, init_func: Callable[[], None]) -> None:
    """
    Register ``init_func`` for root module ``name`` without importing it.

    * If ``name`` is already in ``sys.modules``, ``init_func`` runs
      immediately and is not queued, so it runs exactly once.
    * Otherwise it is queued. ``_maybe_init_lazy_module`` runs it the first
      time an object whose ``__module__`` is rooted at ``name`` is inspected.

    Raises:
        AssertionError: if ``name`` is dotted, or if an init function is
            already pending for a module that has not been imported.
    """
    assert "." not in name, f"Expected a root module name, but got {name}"
    if name in sys.modules:
        # If the module is already imported, eagerly run init. Do not also
        # queue it; otherwise it would run a second time later.
        init_func()
        return
    # Module is not yet imported, delay processing until needed
    assert name not in _lazy_module_init
    _lazy_module_init[name].append(init_func)
def _maybe_init_lazy_module(obj: object) -> None:
    """
    Run any pending init functions registered for ``obj``'s root module.

    The pending entry is popped before running, so each registered batch runs
    at most once. Objects without a ``__module__`` (or with ``None``) are
    ignored.
    """
    module_name = getattr(obj, "__module__", None)
    if module_name is None:
        return
    root = module_name.split(".")[0]
    pending = _lazy_module_init.pop(root, None)
    if pending is None:
        return
    for init_fn in pending:
        init_fn()
def is_allowed(obj) -> bool:
    """
    Decide whether ``obj`` may be placed in the graph like ``torch.add``.

    The disallow-list is checked before the allow-list. Objects in neither
    list are allowed only if they are torch.ops packets, overloads or
    namespaces.
    """
    _maybe_init_lazy_module(obj)
    obj_id = id(obj)
    if obj_id in _disallowed_function_ids:
        return False
    if obj_id in _allowed_function_ids:
        return True
    # torch.ops is populated lazily, so its members may be missing from
    # _allowed_function_ids; fall back to a type test for them.
    lazily_populated_op_types = (
        torch._ops.OpOverloadPacket,
        torch._ops.OpOverload,
        torch._ops._OpNamespace,
    )
    return isinstance(obj, lazily_populated_op_types)
def is_user_defined_allowed(obj) -> bool:
    """Return True if ``obj``'s id is in the user-defined allow-list."""
    _maybe_init_lazy_module(obj)
    obj_id = id(obj)
    return obj_id in _allowed_user_defined_function_ids
def is_forbidden(obj) -> bool:
    """Return ``obj``'s ``_dynamo_forbidden`` marker, or False when it has none."""
    _maybe_init_lazy_module(obj)
    marker = getattr(obj, "_dynamo_forbidden", False)
    return marker
def torch_get_name(obj, default) -> str:
    """Return the dotted name recorded for a torch.* object, else ``default``."""
    obj_id = id(obj)
    return _allowed_function_ids.get_name(obj_id, default)
def is_builtin_callable(obj) -> bool:
    """Return True if ``obj``'s id is registered in ``_builtin_function_ids``."""
    obj_id = id(obj)
    return obj_id in _builtin_function_ids
def is_builtin_constant(obj) -> bool:
    """Return True if ``obj`` is one of the public non-callable ``builtins``."""
    obj_id = id(obj)
    return obj_id in _builtin_constant_ids
def is_numpy(obj) -> bool:
    """
    Return True for numpy arrays, numpy scalars and tracked numpy callables.

    Always False when numpy is not installed.
    """
    if np is None:
        return False
    if isinstance(obj, (np.ndarray, np.generic)):
        return True
    return id(obj) in _numpy_function_ids