blob: 053ecfaad7a744c288f0579e07f6388968badf37 [file] [log] [blame]
import functools
import importlib
from .utils import hashable
from .variables import TorchCtxManagerClassVariable
"""
Map of torch objects to their tracing rules (Dynamo variables).
* TorchVariable: The functions should be put into the FX graph or can be constant folded. E.g.,
- torch.add: should be put into the FX graph.
- torch.is_floating_point: constant folded.
* TorchCtxManagerClassVariable: The context manager classes are supported by Dynamo. E.g., torch.no_grad
* SkipFilesVariable: The objects should be skipped from tracing.
* UserFunctionVariable: The functions should be inlined.
We explicitly list torch objects which should be wrapped as TorchCtxManagerClassVariable.
The initial list comes from the heuristic in test/dynamo/test_trace_rules.py:generate_allow_list.
For developers: If you add/remove a torch level API, it may trigger failures from
test/dynamo/test_trace_rules.py:test_torch_name_rule_map. To fix the failures:
If you are adding a new torch level API or Dynamo implementation:
* Add the name with TorchCtxManagerClassVariable to this map
if you are adding Dynamo implementation for that context manager.
* Remove the object name from test/dynamo/test_trace_rules.ignored_torch_name_rule_set if it's there.
If you are removing an existing torch level API:
* Remove the entry representing the API from this map or from
test/dynamo/test_trace_rules.ignored_torch_name_rule_set, depending on where it is.
TODO: Add torch object names mapping to TorchVariable for in graph and constant fold functions.
TODO: We will consolidate the skipfiles.check rules into trace_rules.lookup later.
TODO: We will support explicitly listing objects treated as skip/inline after the skipfiles.check
and trace_rules.lookup consolidation is done. The explicit listing of skip/inline objects will then have
a higher priority, which can be used to override the skipfiles.check rules in some cases.
"""
# Fully-qualified torch object name -> Dynamo variable class used to trace it.
# Each key is resolved to the live object by get_torch_obj_rule_map() through
# load_object(), which splits on the last "." and imports everything before
# it as a module; a key must therefore have the form
# "<importable module path>.<attribute name>".
# Every entry currently maps to TorchCtxManagerClassVariable, and the keys are
# listed in alphabetical order.
torch_name_rule_map = {
    "torch._C.DisableTorchFunctionSubclass": TorchCtxManagerClassVariable,
    "torch.amp.autocast_mode.autocast": TorchCtxManagerClassVariable,
    "torch.autograd.grad_mode.enable_grad": TorchCtxManagerClassVariable,
    "torch.autograd.grad_mode.inference_mode": TorchCtxManagerClassVariable,
    "torch.autograd.grad_mode.no_grad": TorchCtxManagerClassVariable,
    "torch.autograd.grad_mode.set_grad_enabled": TorchCtxManagerClassVariable,
    "torch.autograd.profiler.profile": TorchCtxManagerClassVariable,
    "torch.autograd.profiler.record_function": TorchCtxManagerClassVariable,
    "torch.cpu.amp.autocast_mode.autocast": TorchCtxManagerClassVariable,
    "torch.cuda.amp.autocast_mode.autocast": TorchCtxManagerClassVariable,
    "torch.profiler.profiler.profile": TorchCtxManagerClassVariable,
}
@functools.lru_cache(None)
def get_torch_obj_rule_map():
    """Build the object-keyed counterpart of ``torch_name_rule_map``.

    Each dotted name in ``torch_name_rule_map`` is resolved to the object it
    names via ``load_object`` and used as the key for the same rule value.

    The function takes no arguments and is wrapped in ``functools.lru_cache``,
    so the names are resolved on the first call only and the same dict object
    is returned on every later call.

    Returns:
        dict: resolved object -> rule value from ``torch_name_rule_map``.

    Raises:
        AssertionError: if two names resolve to keys that compare equal,
            since the second would silently overwrite the first.
    """
    rules_by_object = {}
    for qualified_name, rule in torch_name_rule_map.items():
        target = load_object(qualified_name)
        # Two names resolving to one object would make the rule ambiguous.
        assert target not in rules_by_object
        rules_by_object[target] = rule
    return rules_by_object
def load_object(name):
    """Import and return the object named by a dotted path.

    The path is split at its last ".": everything before it is imported as a
    module with ``importlib.import_module`` and the final component is read
    from that module with ``getattr``.

    Args:
        name: dotted path of the form "<module path>.<attribute>",
            e.g. "os.path.join".

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: if ``name`` contains no "." (the split yields one part).
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module_path, attr_name = name.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attr_name)
def lookup(obj):
    """Return the tracing rule registered for ``obj``, if any.

    Args:
        obj: the object to look up in the map produced by
            ``get_torch_obj_rule_map``.

    Returns:
        The rule value stored for ``obj``, or ``None`` when ``hashable(obj)``
        is false or when ``obj`` has no entry in the map.
    """
    # The rule map is a dict keyed by object, so only objects that pass the
    # hashable() check are looked up.
    if hashable(obj):
        return get_torch_obj_rule_map().get(obj)
    return None