[FX] torch.fx.symbolic_trace patching improvements and `math.*` support (#50793)

Summary:
This contains some improvements and refactoring to how patching is done in `torch.fx.symbolic_trace`.

1) Functions from `math.*` are now supported without needing to call `torch.fx.wrap()`.  `wrap()` actually errors on some of these functions because they are written in C and don't have `__code__`, which forces use of the string form of `wrap()`.  `math` usage is relatively common; for example, [BERT uses math.sqrt here](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/attention/single.py#L16).  Both `math.sqrt()` and `from math import sqrt` (which copies the function into the calling module's namespace) are supported.  When a module is called, FX now searches that module's global scope for functions to patch.
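
For illustration, something like the attention pattern below now traces out of the box. This is a minimal sketch mirroring the linked BERT code and the new `test_sqrt` unit test; the `ScaledDotProduct` module itself is just an example, not code from this PR:

```python
import math

import torch
from torch.fx import symbolic_trace

class ScaledDotProduct(torch.nn.Module):
    def forward(self, q, k):
        # math.sqrt receives a Proxy produced by k.size(-1); the tracer now
        # records it as a call_function node instead of raising a TypeError.
        return torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))

traced = symbolic_trace(ScaledDotProduct())  # no torch.fx.wrap("sqrt") needed
print(traced.code)
```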

2) [Guarded behind `FX_PATCH_GETITEM=1` in the environment] Fixes tracing of [PositionalEmbedding from BERT](https://github.com/pytorch/benchmark/blob/6f79061bd145eeaa9b4a75847939901fd245ddf9/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py#L24), which previously failed with `TypeError: slice indices must be integers or None or have an __index__ method` because a `Proxy` is passed into `Tensor.__getitem__`.  See https://github.com/pytorch/pytorch/issues/50710 for why this is disabled by default.
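
A minimal sketch of the guarded behavior (the `PositionalSlice` module is a hypothetical stand-in for the BERT code above; the flag is read when `torch.fx.symbolic_trace` is imported, so it must already be set in the environment):

```python
# Run as: FX_PATCH_GETITEM=1 python repro.py
import torch
from torch.fx import symbolic_trace

class PositionalSlice(torch.nn.Module):
    """Simplified, hypothetical stand-in for BERT's PositionalEmbedding."""
    def __init__(self):
        super().__init__()
        self.register_buffer("pe", torch.randn(8, 8))

    def forward(self, x):
        # x.size(0) is a Proxy during tracing; with Tensor.__getitem__ patched,
        # this slice is recorded as a call_method node instead of raising the
        # TypeError quoted above.
        return self.pe[:, :x.size(0)]

traced = symbolic_trace(PositionalSlice())
print(traced.code)
```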

3) Support for automatically wrapping functions that have been copied into another module's namespace via an import like `from foo import wrapped_function`.  This also isn't exposed in `torch.fx.wrap()`, but it is used to implement the `math.*` support above.
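
As a sketch of what this enables (the `LengthScale` module is illustrative), an alias imported into another module's namespace is wrapped the same way as `math.sqrt` itself:

```python
from math import sqrt  # copied into this module's globals; no fx.wrap() call

import torch
from torch.fx import symbolic_trace

class LengthScale(torch.nn.Module):
    def forward(self, x):
        # The alias is matched by object identity (it is the same object as
        # math.sqrt), so it is auto-wrapped when this module's globals are
        # scanned during tracing.
        return x / sqrt(x.size(0))

print(symbolic_trace(LengthScale()).code)
```

Internally this is driven by the new `autowrap_modules` argument on `Tracer.__init__()` (defaulting to `(math, )`), so additional modules' functions can be auto-wrapped by constructing a `Tracer` with them.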

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50793

Test Plan: Added unit tests exercising each feature

Reviewed By: jamesr66a

Differential Revision: D25999788

Pulled By: jansel

fbshipit-source-id: f1ce11a69b7d97f26c9e2741c6acf9c513a84467
diff --git a/test/test_fx.py b/test/test_fx.py
index e91cf4f..c0a5af4 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -1,13 +1,18 @@
+import builtins
+import contextlib
+import copy
+import functools
+import math
+import numbers
+import operator
+import os
+import pickle
+import sys
 import torch
 import unittest
-import operator
-import numbers
-import pickle
-import copy
-import sys
-import functools
-import contextlib
+from math import sqrt
 from pathlib import Path
+from torch.multiprocessing import Process
 from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph, wrap
 from torch.fx.experimental import shape_prop
 from torch.fx.immutable_collections import immutable_dict, immutable_list
@@ -58,6 +63,12 @@
 def wrapped_via_decorator(a):
     return a + 1
 
+
+real_wrapped_via_decorator = wrapped_via_decorator
+real_a_lifed_leaf = a_lifted_leaf
+real_a_lifed_leaf2 = a_lifted_leaf2
+_sqrt = sqrt
+
 wrap('wrapper_fn')
 
 def wrapper_fn(x):
@@ -242,6 +253,7 @@
         m = symbolic_trace(to_trace)
         self.assertIn('a_lifted_leaf', m.code)
         self.assertEqual(27, m(2))
+        self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
 
     def test_wrap_fn_directly(self):
         self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
@@ -252,6 +264,7 @@
         m = symbolic_trace(to_trace)
         self.assertIn('a_lifted_leaf2', m.code)
         self.assertEqual(27, m(2))
+        self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
 
     def test_wrapped_via_decorator(self):
         self.assertEqual(wrapped_via_decorator(0), 1)
@@ -262,6 +275,8 @@
         m = symbolic_trace(to_trace)
         self.assertIn('wrapped_via_decorator', m.code)
         self.assertEqual(m(0), 1)
+        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
+        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
 
     def test_graph_edit_with_proxy(self):
         class M(torch.nn.Module):
@@ -755,6 +770,26 @@
         traced2 = symbolic_trace(FXLenTest2())
         inp = torch.rand(3, 4)
         self.assertEqual(traced2(inp), inp + 3.0)
+        self.assertIs(len, builtins.len)
+
+    def test_sqrt(self):
+        class Sqrt1(torch.nn.Module):
+            def forward(self, x):
+                return sqrt(x.size(0))
+
+        class Sqrt2(torch.nn.Module):
+            def forward(self, x):
+                return math.sqrt(x.size(0))
+
+        class Sqrt3(torch.nn.Module):
+            def forward(self, x):
+                return x + math.sqrt(2) + sqrt(2)
+
+        self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
+        self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
+        self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
+        self.assertIs(sqrt, _sqrt)
+        self.assertIs(math.sqrt, _sqrt)
 
     def test_torch_custom_ops(self):
         class M(torch.nn.Module):
@@ -1309,6 +1344,42 @@
         scripted = torch.jit.script(traced)
         self.assertIn("-> List[str]", scripted.code)
 
+    def getitem_inner(self):
+        class GetItemBase(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer('pe', torch.randn(8, 8))
+
+        class GetItem1(GetItemBase):
+            def forward(self, x):
+                return self.pe[:, :x.size(0)]
+
+        class GetItem2(GetItemBase):
+            def forward(self, x):
+                return self.pe[x.size(0)]
+
+        class GetItem3(GetItemBase):
+            def forward(self, x):
+                return self.pe[4]  # fx creates `self._tensor_constant0` here
+
+        self.checkGraphModule(GetItem1(), [torch.zeros(4)])
+        self.checkGraphModule(GetItem2(), [torch.zeros(4)])
+        self.checkGraphModule(GetItem3(), [torch.zeros(4)])
+
+    @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
+                         "Will be checked in test_getitem_subproc")
+    def test_getitem(self):
+        self.getitem_inner()
+
+    def test_getitem_subproc(self):
+        # need to run this test in a subproc to work around:
+        #   https://github.com/pytorch/pytorch/issues/50710
+        proc = Process(target=run_getitem_target)
+        proc.start()
+        proc.join()
+        self.assertEqual(proc.exitcode, 0)
+
+
     def test_user_friendly_call_provenance_with_function(self):
         def fn(x):
             return wrapper_fn(x)
@@ -1367,5 +1438,15 @@
             i += 1
         self.assertEqual(i, 3)
 
+
+def run_getitem_target():
+    from torch.fx.symbolic_trace import _wrapped_methods_to_patch
+    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
+    try:
+        TestFX().getitem_inner()
+    finally:
+        _wrapped_methods_to_patch.pop()
+
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index df24d05..ec36735 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -1,7 +1,11 @@
 import builtins
+import functools
 import inspect
-from types import CodeType, FunctionType
+import math
+import os
+from types import CodeType, FunctionType, ModuleType
 from typing import Any, Dict, NamedTuple, Optional, Set, Tuple, List, Callable, Union
+from itertools import chain
 import torch
 from torch._C import ScriptObject  # type: ignore
 
@@ -12,6 +16,11 @@
 
 HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
 
+# These need to run in global scope to handle nested calls correctly
+_orig_module_call : Callable = torch.nn.Module.__call__
+_orig_module_getattr : Callable = torch.nn.Module.__getattr__
+
+
 def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
     co = fn.__code__
     co_flags = co.co_flags & ~HAS_VARSTUFF
@@ -49,9 +58,30 @@
     process. The different behaviors that can be overridden are described
     in the docstrings of the methods on this class.
     """
-    def __init__(self):
+    def __init__(self, autowrap_modules : Tuple[ModuleType] = (math, )):
+        """
+        Construct a Tracer object.
+
+        Args:
+
+            autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`,
+                Python modules whose functions should be wrapped automatically
+                without needing to use fx.wrap().
+        """
+
         super().__init__()
 
+        # Functions we will eagerly wrap when we see them while tracing
+        # this captures both `math.sqrt()` and `from math import sqrt` automatically
+        self._autowrap_function_ids: Set[int] = {
+            id(value) for name, value in chain(*[m.__dict__.items() for m in autowrap_modules])
+            if not name.startswith("_") and callable(value)}
+
+        # Python modules to apply autowrap to at the start, in addition to
+        # modules we see while tracing
+        self._autowrap_search: List[ModuleType] = list(autowrap_modules)
+
+
     def create_arg(self, a: Any) -> 'Argument':
         """
         A method to specify the behavior of tracing when preparing values to
@@ -284,14 +314,16 @@
 
         assert isinstance(fn, FunctionType)
 
+        fn_globals = fn.__globals__  # run before it gets patched
         fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module))
 
         parameter_proxy_cache : Dict[str, Proxy] = {}  # Reduce number of get_attr calls
 
         # Method dispatch on parameters is not recorded unless it's directly used.
         # Thus, we need to insert a proxy when __getattr__ requests a parameter.
+        @functools.wraps(_orig_module_getattr)
         def module_getattr_wrapper(mod, attr):
-            attr_val = orig_getattr(mod, attr)
+            attr_val = _orig_module_getattr(mod, attr)
             if isinstance(attr_val, torch.nn.Parameter):
                 for n, p in self.root.named_parameters():
                     if attr_val is p:
@@ -300,36 +332,65 @@
                         return parameter_proxy_cache[n]
             return attr_val
 
+        @functools.wraps(_orig_module_call)
         def module_call_wrapper(mod, *args, **kwargs):
             def forward(*args, **kwargs):
-                return orig_call(mod, *args, **kwargs)
+                return _orig_module_call(mod, *args, **kwargs)
 
+            _autowrap_check(patcher, getattr(getattr(mod, "forward", mod), "__globals__", {}),
+                            self._autowrap_function_ids)
             return self.call_module(mod, forward, args, kwargs)
 
-        orig_call = torch.nn.Module.__call__
-        orig_getattr = torch.nn.Module.__getattr__
-        orig_fns : List[PatchedFn] = []
-
-        try:
-            # Seems to be a mypy limitation: https://github.com/python/mypy/issues/2427
-            torch.nn.Module.__getattr__ = module_getattr_wrapper  # type: ignore
-            torch.nn.Module.__call__ = module_call_wrapper
-
-            _patch_wrapped_functions(orig_fns)
+        with _Patcher() as patcher:
+            # allow duplicate patches to support the case of nested calls
+            patcher.patch_method(torch.nn.Module, "__getattr__", module_getattr_wrapper, deduplicate=False)
+            patcher.patch_method(torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False)
+            _patch_wrapped_functions(patcher)
+            _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
+            for module in self._autowrap_search:
+                _autowrap_check(patcher, module.__dict__, self._autowrap_function_ids)
 
             self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
                              type_expr=fn.__annotations__.get('return', None))
-        finally:
-            _unpatch_wrapped_functions(orig_fns)
-            torch.nn.Module.__call__ = orig_call
-            torch.nn.Module.__getattr__ = orig_getattr  # type: ignore
+
         return self.graph
 
+
 # List of pairs of (global dict, function name) functions
 # to patch for the purposes of the wrap() API.
 _wrapped_fns_to_patch : List[Tuple[dict, str]] = []
 
+# List of methods on classes to wrap (class type, function name)
+# this currently only works for Tensor.* methods that aren't traced properly
+_wrapped_methods_to_patch : List[Tuple[type, str]] = []
+
+if os.environ.get("FX_PATCH_GETITEM") == "1":
+    # This change is needed to trace models like PositionalEmbedding from BERT:
+    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py  # noqa
+    # but causes issues in quantization documented here:
+    # https://github.com/pytorch/pytorch/issues/50710
+    # once that is fixed we can make this the default behavior.
+    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
+
+
+def _find_proxy(*objects_to_search):
+    """
+    Recursively search a data structure for a Proxy() and return it,
+    return None if not found.
+    """
+    proxy = None
+
+    def find_proxy(x):
+        nonlocal proxy
+        if isinstance(x, Proxy):
+            proxy = x
+
+    map_aggregate(objects_to_search, find_proxy)
+    return proxy
+
+
 def _create_wrapped_func(orig_fn):
+    @functools.wraps(orig_fn)
     def wrapped(*args, **kwargs):
         """
         Given an closed-over ``orig_function`` to invoke, search the args and kwargs for
@@ -337,77 +398,136 @@
         call to this leaf function directly. Otherwise, just return the results of
         this function call, as this function is not being traced.
         """
-        proxy = None
-
-        def find_proxy(x):
-            nonlocal proxy
-            if isinstance(x, Proxy):
-                proxy = x
-
-        map_aggregate(args, find_proxy)
-        map_aggregate(kwargs, find_proxy)
-
+        proxy = _find_proxy(args, kwargs)
         if proxy is not None:
             return proxy.tracer.create_proxy('call_function', orig_fn, args, kwargs)
-        else:
-            return orig_fn(*args, **kwargs)
+        return orig_fn(*args, **kwargs)
 
     return wrapped
 
-class PatchedFn(NamedTuple):
-    frame_dict : Dict[str, Any]
+
+def _create_wrapped_method(cls, name):
+    orig_fn = getattr(cls, name)
+
+    @functools.wraps(orig_fn)
+    def wrapped(*args, **kwargs):
+        """
+        Search the args and kwargs for a Proxy object. If there is one,
+        emit a ``call_method`` node to preserve the call to this method
+        directly. Otherwise, just return the results of this function
+        call, as this function is not being traced.
+        """
+        proxy = _find_proxy(args, kwargs)
+        if proxy is not None:
+            return proxy.tracer.create_proxy('call_method', name, args, kwargs)
+        return orig_fn(*args, **kwargs)
+
+    return wrapped
+
+
+class _PatchedFn(NamedTuple):
+    frame_dict : Any
     fn_name : str
     orig_fn : Any
 
-# isinstance(orig_fn, NoneSentinel) if the original global namespace
-# did not contain this function at the time of patching. This can
-# occur, for example, when patching a builtin function
-class PatchedFnNoneSentinel:
-    pass
+    def revert(self):
+        raise NotImplementedError()
 
-def _patch_wrapped_functions(orig_fns : List[PatchedFn]):
+
+class _PatchedFnSetItem(_PatchedFn):
+    def revert(self):
+        self.frame_dict[self.fn_name] = self.orig_fn
+
+
+class _PatchedFnDel(_PatchedFn):
+    def revert(self):
+        del self.frame_dict[self.fn_name]
+
+
+class _PatchedFnSetAttr(_PatchedFn):
+    def revert(self):
+        setattr(self.frame_dict, self.fn_name, self.orig_fn)
+
+
+class _Patcher(object):
+    def __init__(self):
+        super(_Patcher, self).__init__()
+        self.patches_made : List[_PatchedFn] = []
+        self.visited : Set[int] = set()
+
+    def patch(self, frame_dict : Dict[str, Any], name : str, new_fn : Callable,
+              deduplicate : bool = True):
+        """
+        Replace frame_dict[name] with new_fn until we exit the context manager.
+        """
+        new_fn.__fx_already_patched = deduplicate  # type: ignore
+        if name not in frame_dict and hasattr(builtins, name):
+            self.patches_made.append(_PatchedFnDel(frame_dict, name, None))
+        elif getattr(frame_dict[name], "__fx_already_patched", False):
+            return  # already patched, no need to do it again
+        else:
+            self.patches_made.append(_PatchedFnSetItem(frame_dict, name, frame_dict[name]))
+        frame_dict[name] = new_fn
+
+    def patch_method(self, cls: type, name : str, new_fn : Callable,
+                     deduplicate : bool = True):
+        """
+        Replace cls.name with new_fn until we exit the context manager.
+        """
+        new_fn.__fx_already_patched = deduplicate  # type: ignore
+        orig_fn = getattr(cls, name)
+        if getattr(orig_fn, "__fx_already_patched", False):
+            return  # already patched, no need to do it again
+        self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn))
+        setattr(cls, name, new_fn)
+
+    def visit_once(self, thing: Any):
+        """ Return True on the first call to with thing, otherwise false """
+        idx = id(thing)
+        if idx in self.visited:
+            return False
+        self.visited.add(idx)
+        return True
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Undo all the changes made via self.patch() and self.patch_method()
+        """
+        while self.patches_made:
+            # unpatch in reverse order to handle duplicates correctly
+            self.patches_made.pop().revert()
+        self.visited.clear()
+
+
+def _patch_wrapped_functions(patcher : _Patcher):
     """
     Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap
-    the listed global functions in the `_create_wrapped_func` wrapper. Returns
-    a list of PatchedFn, which is a record specifiying a single function
-    entry that was patched and contains the original function for unpatching
-
-    Note orig_fns is taken by reference and updated as we go to facilitate
-    reverting patching if this function itself throws an exception.
+    the listed global functions in the `_create_wrapped_func` wrapper.
     """
-    # Set to deduplicate entries. Wrapping a function multiple times would
-    # be an error, since it would cause a `call_function` node for the
-    # wrapper to be emitted rather than the actual underlying function
-    #
-    # Use id(frame_dict) as a hashable identity here since none of the
-    # frame dicts should be destroyed during symtracing
-    processed_entries : Set[Tuple[int, str]] = set()
-
     for frame_dict, name in _wrapped_fns_to_patch:
-        if (id(frame_dict), name) in processed_entries:
-            continue
         if name not in frame_dict and hasattr(builtins, name):
             orig_fn = getattr(builtins, name)
-            orig_fns.append(PatchedFn(frame_dict, name, PatchedFnNoneSentinel()))
         else:
             orig_fn = frame_dict[name]
-            orig_fns.append(PatchedFn(frame_dict, name, orig_fn))
+        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))
 
-        frame_dict[name] = _create_wrapped_func(orig_fn)
+    for cls, name in _wrapped_methods_to_patch:
+        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))
 
-        processed_entries.add((id(frame_dict), name))
 
-def _unpatch_wrapped_functions(orig_fns : List[PatchedFn]):
+def _autowrap_check(patcher : _Patcher, frame_dict : Dict[str, Any], function_ids : Set[int]):
     """
-    Given the ``orig_fns`` dict that ``_patch_wrapped_functions``,
-    replace all of the global functions with the original global functions
-    that were there before symbolic tracing.
+    Some functions, like `math.sqrt`, are common enough that we want to automatically wrap them as we see them.
+    This function searches a scope for them and patches them if found.
     """
-    for frame_dict, fn_name, orig_fn in orig_fns:
-        if isinstance(orig_fn, PatchedFnNoneSentinel):
-            del frame_dict[fn_name]
-        else:
-            frame_dict[fn_name] = orig_fn
+    if patcher.visit_once(frame_dict):
+        for name, value in frame_dict.items():
+            if not name.startswith("_") and callable(value) and id(value) in function_ids:
+                patcher.patch(frame_dict, name, _create_wrapped_func(value))
+
 
 def wrap(fn_or_name : Union[str, Callable]):
     """
@@ -442,11 +562,7 @@
         fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
             graph when it's called
     """
-    if callable(fn_or_name):
-        fn_name = fn_or_name.__code__.co_name
-    elif isinstance(fn_or_name, str):
-        fn_name = fn_or_name
-    else:
+    if not callable(fn_or_name) and not isinstance(fn_or_name, str):
         raise RuntimeError('Unsupported type for global function! Must be either a callable or '
                            'string name')
 
@@ -464,6 +580,8 @@
     if f.f_code.co_name != '<module>':
         raise NotImplementedError('wrap must be called at the top level of a module')
 
+    # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
+    # semantics would be slightly different, but it would add support for `from x import wrapped_function`
     _wrapped_fns_to_patch.append((f.f_globals, fn_name))
     return fn_or_name