[FX] torch.fx.symbolic_trace patching improvements and `math.*` support (#50793)
Summary:
This contains some improvements and refactoring to how patching is done in `torch.fx.symbolic_trace`.
1) Functions from `math.*` are now supported without needing to call `torch.fx.wrap()`. `wrap()` actually errors on some of these functions because they are written in C and don't have `__code__`, so only the string form of `wrap()` could be used for them. `math` usage is relatively common, for example [BERT uses math.sqrt here](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/attention/single.py#L16). Both `math.sqrt()` and `from math import sqrt` (which copies the function into the module namespace) are supported. When modules are called, FX now searches the module's global scope to find functions to patch (see the first sketch after this list).
2) [Guarded behind the environment variable `FX_PATCH_GETITEM=1`] Fixes tracing of [PositionalEmbedding from BERT](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py#L24), which previously failed with `TypeError: slice indices must be integers or None or have an __index__ method` because a `Proxy()` gets passed into `Tensor.__getitem__` (see the second sketch after this list). See https://github.com/pytorch/pytorch/issues/50710 for why this is disabled by default.
3) Support for automatically wrapping functions that may have been copied into a different module's scope via an import like `from foo import wrapped_function`. This isn't yet exposed through `torch.fx.wrap()`, but it is used to implement the `math.*` support above (the first sketch below also exercises this path).
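For illustration, here is a minimal sketch of what 1) and 3) enable. The modules are hypothetical (they are not part of this PR's tests) and simply show that `math.sqrt` and an imported `sqrt` now trace without any `torch.fx.wrap()` call:

```python
import math
from math import sqrt  # copies math.sqrt into this module's namespace

import torch
from torch.fx import symbolic_trace


class ScaledScore(torch.nn.Module):
    # Hypothetical module, loosely modeled on the BERT attention example above.
    def forward(self, q, k):
        return torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))


class UsesImportedSqrt(torch.nn.Module):
    def forward(self, x):
        # `sqrt` here is the copy created by `from math import sqrt` above.
        return x * sqrt(x.size(0))


# Both should trace after this change; the sqrt calls are recorded as
# call_function nodes because the tracer searches the forward()'s global
# scope (and the `math` module) for functions to auto-wrap.
print(symbolic_trace(ScaledScore()).code)
print(symbolic_trace(UsesImportedSqrt()).code)
```

And a sketch of opting in to 2), mirroring the `GetItem1` test added below (again, the module itself is made up). The script must be run with `FX_PATCH_GETITEM=1` set in the environment, since the flag is read when `torch.fx.symbolic_trace` is imported:

```python
import torch
from torch.fx import symbolic_trace


class PositionalSlice(torch.nn.Module):
    # Hypothetical module following the PositionalEmbedding pattern from BERT.
    def __init__(self):
        super().__init__()
        self.register_buffer("pe", torch.randn(512, 64))

    def forward(self, x):
        # x.size(0) is a Proxy here; without FX_PATCH_GETITEM=1 this raises
        # "TypeError: slice indices must be integers or None ..."
        return self.pe[:, :x.size(0)]


print(symbolic_trace(PositionalSlice()).code)
```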
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50793
Test Plan: Added unittests to check each feature
Reviewed By: jamesr66a
Differential Revision: D25999788
Pulled By: jansel
fbshipit-source-id: f1ce11a69b7d97f26c9e2741c6acf9c513a84467
diff --git a/test/test_fx.py b/test/test_fx.py
index e91cf4f..c0a5af4 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -1,13 +1,18 @@
+import builtins
+import contextlib
+import copy
+import functools
+import math
+import numbers
+import operator
+import os
+import pickle
+import sys
import torch
import unittest
-import operator
-import numbers
-import pickle
-import copy
-import sys
-import functools
-import contextlib
+from math import sqrt
from pathlib import Path
+from torch.multiprocessing import Process
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph, wrap
from torch.fx.experimental import shape_prop
from torch.fx.immutable_collections import immutable_dict, immutable_list
@@ -58,6 +63,12 @@
def wrapped_via_decorator(a):
return a + 1
+
+real_wrapped_via_decorator = wrapped_via_decorator
+real_a_lifted_leaf = a_lifted_leaf
+real_a_lifted_leaf2 = a_lifted_leaf2
+_sqrt = sqrt
+
wrap('wrapper_fn')
def wrapper_fn(x):
@@ -242,6 +253,7 @@
m = symbolic_trace(to_trace)
self.assertIn('a_lifted_leaf', m.code)
self.assertEqual(27, m(2))
+        self.assertIs(a_lifted_leaf, real_a_lifted_leaf)
def test_wrap_fn_directly(self):
self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
@@ -252,6 +264,7 @@
m = symbolic_trace(to_trace)
self.assertIn('a_lifted_leaf2', m.code)
self.assertEqual(27, m(2))
+        self.assertIs(a_lifted_leaf2, real_a_lifted_leaf2)
def test_wrapped_via_decorator(self):
self.assertEqual(wrapped_via_decorator(0), 1)
@@ -262,6 +275,8 @@
m = symbolic_trace(to_trace)
self.assertIn('wrapped_via_decorator', m.code)
self.assertEqual(m(0), 1)
+ self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
+ self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
def test_graph_edit_with_proxy(self):
class M(torch.nn.Module):
@@ -755,6 +770,26 @@
traced2 = symbolic_trace(FXLenTest2())
inp = torch.rand(3, 4)
self.assertEqual(traced2(inp), inp + 3.0)
+ self.assertIs(len, builtins.len)
+
+ def test_sqrt(self):
+ class Sqrt1(torch.nn.Module):
+ def forward(self, x):
+ return sqrt(x.size(0))
+
+ class Sqrt2(torch.nn.Module):
+ def forward(self, x):
+ return math.sqrt(x.size(0))
+
+ class Sqrt3(torch.nn.Module):
+ def forward(self, x):
+ return x + math.sqrt(2) + sqrt(2)
+
+ self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
+ self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
+ self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
+ self.assertIs(sqrt, _sqrt)
+ self.assertIs(math.sqrt, _sqrt)
def test_torch_custom_ops(self):
class M(torch.nn.Module):
@@ -1309,6 +1344,42 @@
scripted = torch.jit.script(traced)
self.assertIn("-> List[str]", scripted.code)
+ def getitem_inner(self):
+ class GetItemBase(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer('pe', torch.randn(8, 8))
+
+ class GetItem1(GetItemBase):
+ def forward(self, x):
+ return self.pe[:, :x.size(0)]
+
+ class GetItem2(GetItemBase):
+ def forward(self, x):
+ return self.pe[x.size(0)]
+
+ class GetItem3(GetItemBase):
+ def forward(self, x):
+ return self.pe[4] # fx creates `self._tensor_constant0` here
+
+ self.checkGraphModule(GetItem1(), [torch.zeros(4)])
+ self.checkGraphModule(GetItem2(), [torch.zeros(4)])
+ self.checkGraphModule(GetItem3(), [torch.zeros(4)])
+
+ @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
+ "Will be checked in test_getitem_subproc")
+ def test_getitem(self):
+ self.getitem_inner()
+
+ def test_getitem_subproc(self):
+ # need to run this test in a subproc to work around:
+ # https://github.com/pytorch/pytorch/issues/50710
+ proc = Process(target=run_getitem_target)
+ proc.start()
+ proc.join()
+ self.assertEqual(proc.exitcode, 0)
+
+
def test_user_friendly_call_provenance_with_function(self):
def fn(x):
return wrapper_fn(x)
@@ -1367,5 +1438,15 @@
i += 1
self.assertEqual(i, 3)
+
+def run_getitem_target():
+ from torch.fx.symbolic_trace import _wrapped_methods_to_patch
+ _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
+ try:
+ TestFX().getitem_inner()
+ finally:
+ _wrapped_methods_to_patch.pop()
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index df24d05..ec36735 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -1,7 +1,11 @@
import builtins
+import functools
import inspect
-from types import CodeType, FunctionType
+import math
+import os
+from types import CodeType, FunctionType, ModuleType
from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union
+from itertools import chain
import torch
from torch._C import ScriptObject # type: ignore
@@ -12,6 +16,11 @@
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
+# These need to run in global scope to handle nested calls correctly
+_orig_module_call : Callable = torch.nn.Module.__call__
+_orig_module_getattr : Callable = torch.nn.Module.__getattr__
+
+
def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
co = fn.__code__
co_flags = co.co_flags & ~HAS_VARSTUFF
@@ -49,9 +58,30 @@
process. The different behaviors that can be overridden are described
in the docstrings of the methods on this class.
"""
- def __init__(self):
+ def __init__(self, autowrap_modules : Tuple[ModuleType] = (math, )):
+ """
+ Construct a Tracer object.
+
+ Args:
+
+        autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
+ Python modules whose functions should be wrapped automatically
+ without needing to use fx.wrap().
+ """
+
super().__init__()
+        # Functions we will eagerly wrap when we see them while tracing;
+        # this captures both `math.sqrt()` and `from math import sqrt` automatically
+ self._autowrap_function_ids: Set[int] = {
+ id(value) for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
+ if not name.startswith("_") and callable(value)}
+
+ # Python modules to apply autowrap to at the start, in addition to
+ # modules we see while tracing
+ self._autowrap_search: List[ModuleType] = list(autowrap_modules)
+
+
def create_arg(self, a: Any) -> 'Argument':
"""
A method to specify the behavior of tracing when preparing values to
@@ -284,14 +314,16 @@
assert isinstance(fn, FunctionType)
+ fn_globals = fn.__globals__ # run before it gets patched
fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module))
parameter_proxy_cache : Dict[str, Proxy] = {} # Reduce number of get_attr calls
# Method dispatch on parameters is not recorded unless it's directly used.
# Thus, we need to insert a proxy when __getattr__ requests a parameter.
+ @functools.wraps(_orig_module_getattr)
def module_getattr_wrapper(mod, attr):
- attr_val = orig_getattr(mod, attr)
+ attr_val = _orig_module_getattr(mod, attr)
if isinstance(attr_val, torch.nn.Parameter):
for n, p in self.root.named_parameters():
if attr_val is p:
@@ -300,36 +332,65 @@
return parameter_proxy_cache[n]
return attr_val
+ @functools.wraps(_orig_module_call)
def module_call_wrapper(mod, *args, **kwargs):
def forward(*args, **kwargs):
- return orig_call(mod, *args, **kwargs)
+ return _orig_module_call(mod, *args, **kwargs)
+ _autowrap_check(patcher, getattr(getattr(mod, "forward", mod), "__globals__", {}),
+ self._autowrap_function_ids)
return self.call_module(mod, forward, args, kwargs)
- orig_call = torch.nn.Module.__call__
- orig_getattr = torch.nn.Module.__getattr__
- orig_fns : List[PatchedFn] = []
-
- try:
- # Seems to be a mypy limitation: https://github.com/python/mypy/issues/2427
- torch.nn.Module.__getattr__ = module_getattr_wrapper # type: ignore
- torch.nn.Module.__call__ = module_call_wrapper
-
- _patch_wrapped_functions(orig_fns)
+ with _Patcher() as patcher:
+ # allow duplicate patches to support the case of nested calls
+ patcher.patch_method(torch.nn.Module, "__getattr__", module_getattr_wrapper, deduplicate=False)
+ patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False)
+ _patch_wrapped_functions(patcher)
+ _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
+ for module in self._autowrap_search:
+ _autowrap_check(patcher, module.__dict__, self._autowrap_function_ids)
self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
type_expr=fn.__annotations__.get('return', None))
- finally:
- _unpatch_wrapped_functions(orig_fns)
- torch.nn.Module.__call__ = orig_call
- torch.nn.Module.__getattr__ = orig_getattr # type: ignore
+
return self.graph
+
# List of pairs of (global dict, function name) functions
# to patch for the purposes of the wrap() API.
_wrapped_fns_to_patch : List[Tuple[dict, str]] = []
+# List of methods on classes to wrap (class type, function name);
+# this currently only works for Tensor.* methods that aren't traced properly
+_wrapped_methods_to_patch : List[Tuple[type, str]] = []
+
+if os.environ.get("FX_PATCH_GETITEM") == "1":
+ # This change is needed to trace models like PositionalEmbedding from BERT:
+ # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py # noqa
+ # but causes issues in quantization documented here:
+ # https://github.com/pytorch/pytorch/issues/50710
+ # once that is fixed we can make this the default behavior.
+ _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
+
+
+def _find_proxy(*objects_to_search):
+ """
+ Recursively search a data structure for a Proxy() and return it,
+ return None if not found.
+ """
+ proxy = None
+
+ def find_proxy(x):
+ nonlocal proxy
+ if isinstance(x, Proxy):
+ proxy = x
+
+ map_aggregate(objects_to_search, find_proxy)
+ return proxy
+
+
def _create_wrapped_func(orig_fn):
+ @functools.wraps(orig_fn)
def wrapped(*args, **kwargs):
"""
Given an closed-over ``orig_function`` to invoke, search the args and kwargs for
@@ -337,77 +398,136 @@
call to this leaf function directly. Otherwise, just return the results of
this function call, as this function is not being traced.
"""
- proxy = None
-
- def find_proxy(x):
- nonlocal proxy
- if isinstance(x, Proxy):
- proxy = x
-
- map_aggregate(args, find_proxy)
- map_aggregate(kwargs, find_proxy)
-
+ proxy = _find_proxy(args, kwargs)
if proxy is not None:
return proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs)
- else:
- return orig_fn(*args, **kwargs)
+ return orig_fn(*args, **kwargs)
return wrapped
-class PatchedFn(NamedTuple):
- frame_dict : Dict[str, Any]
+
+def _create_wrapped_method(cls, name):
+ orig_fn = getattr(cls, name)
+
+ @functools.wraps(orig_fn)
+ def wrapped(*args, **kwargs):
+ """
+ Search the args and kwargs for a Proxy object. If there is one,
+ emit a ``call_method`` node to preserve the call to this method
+ directly. Otherwise, just return the results of this function
+ call, as this function is not being traced.
+ """
+ proxy = _find_proxy(args, kwargs)
+ if proxy is not None:
+ return proxy.tracer.create_proxy('call_method', name, args, kwargs)
+ return orig_fn(*args, **kwargs)
+
+ return wrapped
+
+
+class _PatchedFn(NamedTuple):
+ frame_dict : Any
fn_name : str
orig_fn : Any
-# isinstance(orig_fn, NoneSentinel) if the original global namespace
-# did not contain this function at the time of patching. This can
-# occur, for example, when patching a builtin function
-class PatchedFnNoneSentinel:
- pass
+ def revert(self):
+ raise NotImplementedError()
-def _patch_wrapped_functions(orig_fns : List[PatchedFn]):
+
+class _PatchedFnSetItem(_PatchedFn):
+ def revert(self):
+ self.frame_dict[self.fn_name] = self.orig_fn
+
+
+class _PatchedFnDel(_PatchedFn):
+ def revert(self):
+ del self.frame_dict[self.fn_name]
+
+
+class _PatchedFnSetAttr(_PatchedFn):
+ def revert(self):
+ setattr(self.frame_dict, self.fn_name, self.orig_fn)
+
+
+class _Patcher(object):
+ def __init__(self):
+ super(_Patcher, self).__init__()
+ self.patches_made : List[_PatchedFn] = []
+ self.visited : Set[int] = set()
+
+ def patch(self, frame_dict : Dict[str, Any], name : str, new_fn : Callable,
+ deduplicate : bool = True):
+ """
+ Replace frame_dict[name] with new_fn until we exit the context manager.
+ """
+ new_fn.__fx_already_patched = deduplicate # type: ignore
+ if name not in frame_dict and hasattr(builtins, name):
+ self.patches_made.append(_PatchedFnDel(frame_dict, name, None))
+ elif getattr(frame_dict[name], "__fx_already_patched", False):
+ return # already patched, no need to do it again
+ else:
+ self.patches_made.append(_PatchedFnSetItem(frame_dict, name, frame_dict[name]))
+ frame_dict[name] = new_fn
+
+ def patch_method(self, cls: type, name : str, new_fn : Callable,
+ deduplicate : bool = True):
+ """
+        Replace cls.name with new_fn until we exit the context manager.
+ """
+ new_fn.__fx_already_patched = deduplicate # type: ignore
+ orig_fn = getattr(cls, name)
+ if getattr(orig_fn, "__fx_already_patched", False):
+ return # already patched, no need to do it again
+ self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn))
+ setattr(cls, name, new_fn)
+
+ def visit_once(self, thing: Any):
+ """ Return True on the first call to with thing, otherwise false """
+ idx = id(thing)
+ if idx in self.visited:
+ return False
+ self.visited.add(idx)
+ return True
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Undo all the changes made via self.patch() and self.patch_method()
+ """
+ while self.patches_made:
+ # unpatch in reverse order to handle duplicates correctly
+ self.patches_made.pop().revert()
+ self.visited.clear()
+
+
+def _patch_wrapped_functions(patcher : _Patcher):
"""
Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap
- the listed global functions in the `_create_wrapped_func` wrapper. Returns
- a list of PatchedFn, which is a record specifiying a single function
- entry that was patched and contains the original function for unpatching
-
- Note orig_fns is taken by reference and updated as we go to facilitate
- reverting patching if this function itself throws an exception.
+ the listed global functions in the `_create_wrapped_func` wrapper.
"""
- # Set to deduplicate entries. Wrapping a function multiple times would
- # be an error, since it would cause a `call_function` node for the
- # wrapper to be emitted rather than the actual underlying function
- #
- # Use id(frame_dict) as a hashable identity here since none of the
- # frame dicts should be destroyed during symtracing
- processed_entries : Set[Tuple[int, str]] = set()
-
for frame_dict, name in _wrapped_fns_to_patch:
- if (id(frame_dict), name) in processed_entries:
- continue
if name not in frame_dict and hasattr(builtins, name):
orig_fn = getattr(builtins, name)
- orig_fns.append(PatchedFn(frame_dict, name, PatchedFnNoneSentinel()))
else:
orig_fn = frame_dict[name]
- orig_fns.append(PatchedFn(frame_dict, name, orig_fn))
+ patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
- frame_dict[name] = _create_wrapped_func(orig_fn)
+ for cls, name in _wrapped_methods_to_patch:
+ patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
- processed_entries.add((id(frame_dict), name))
-def _unpatch_wrapped_functions(orig_fns : List[PatchedFn]):
+def _autowrap_check(patcher : _Patcher, frame_dict : Dict[str, Any], function_ids : Set[int]):
"""
- Given the ``orig_fns`` dict that ``_patch_wrapped_functions``,
- replace all of the global functions with the original global functions
- that were there before symbolic tracing.
+    Some functions, like `math.sqrt`, are common enough that we want to wrap them automatically as we see them.
+ This method searches a scope for them and patches them if found.
"""
- for frame_dict, fn_name, orig_fn in orig_fns:
- if isinstance(orig_fn, PatchedFnNoneSentinel):
- del frame_dict[fn_name]
- else:
- frame_dict[fn_name] = orig_fn
+ if patcher.visit_once(frame_dict):
+ for name, value in frame_dict.items():
+ if not name.startswith("_") and callable(value) and id(value) in function_ids:
+ patcher.patch(frame_dict, name, _create_wrapped_func(value))
+
def wrap(fn_or_name : Union[str, Callable]):
"""
@@ -442,11 +562,7 @@
fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
graph when it's called
"""
- if callable(fn_or_name):
- fn_name = fn_or_name.__code__.co_name
- elif isinstance(fn_or_name, str):
- fn_name = fn_or_name
- else:
+ if not callable(fn_or_name) and not isinstance(fn_or_name, str):
raise RuntimeError('Unsupported type for global function! Must be either a callable or '
'string name')
@@ -464,6 +580,8 @@
if f.f_code.co_name != '<module>':
raise NotImplementedError('wrap must be called at the top level of a module')
+    # consider implementing a Callable version of this via _autowrap_function_ids / _autowrap_search
+    # semantics would be slightly different, but it would add support for `from x import wrapped_function`
_wrapped_fns_to_patch.append((f.f_globals, fn_name))
return fn_or_name