Revert D24665950: Create prototype for AST rewriter

Test Plan: revert-hammer

Differential Revision:
D24665950 (https://github.com/pytorch/pytorch/commit/54feb00bbde9f6220721259f4a352288ab610c03)

Original commit changeset: b72110436126

fbshipit-source-id: 961412df006acd33c91a745c809832d5c6494c76
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 54b140e..d5d3dae 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -2,14 +2,8 @@
 from torch.fx.symbolic_trace import symbolic_trace
 from torch.fx.experimental import GraphManipulation
 from torch.fx.experimental.Partitioner import Partitioner, Device, PartitionerConfig
-from torch.fx.experimental.rewriter import RewritingTracer
-from torch.fx.graph_module import GraphModule
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.jit_utils import JitTestCase
-from typing import Union, Callable
-
-def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
-    return GraphModule(root if isinstance(root, torch.nn.Module) else torch.nn.Module(), RewritingTracer().trace(root))
 
 class TestFXExperimental(JitTestCase):
     def test_find_single_partition(self):
@@ -168,110 +162,5 @@
         self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
         assert len(module_with_submodules.graph.nodes) == 24
 
-    def test_call_to_assert_no_msg(self):
-
-        class M(torch.nn.Module):
-            def forward(self, a, b):
-                assert a == b
-                return a + b
-        m = M()
-        traced = symbolic_trace_with_rewrite(m)
-
-        # Make sure the graph is well-formed
-        traced.graph.lint(traced)
-
-        # Check the IR to make sure there's a call_function node with target == "Assert"
-        self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
-        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
-        traced(3, 3)
-        with self.assertRaisesRegex(AssertionError, ""):
-            traced(3, 5)
-
-        # Confirm that the output is correct
-        self.assertEqual(traced(3, 3), m(3, 3))
-
-    def test_call_to_assert_with_msg(self):
-
-        class M(torch.nn.Module):
-            def forward(self, a, b):
-                assert a == b, "test message"
-                return a + b
-        m = M()
-        traced = symbolic_trace_with_rewrite(m)
-
-        # Make sure the graph is well-formed
-        traced.graph.lint(traced)
-
-        # Check the IR to make sure there's a call_function node with target == "Assert"
-        self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
-        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
-        traced(3, 3)
-        with self.assertRaisesRegex(AssertionError, "test message"):
-            traced(3, 5)
-
-        # Confirm that the output is correct
-        self.assertEqual(traced(3, 3), m(3, 3))
-
-    def test_call_to_assert_with_empty_msg(self):
-
-        class M(torch.nn.Module):
-            def forward(self, a, b):
-                assert a == b, ""
-                return a + b
-        m = M()
-        traced = symbolic_trace_with_rewrite(m)
-
-        # Make sure the graph is well-formed
-        traced.graph.lint(traced)
-
-        # Check the IR to make sure there's a call_function node with target == "Assert"
-        self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
-        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
-        traced(3, 3)
-        with self.assertRaisesRegex(AssertionError, ""):
-            traced(3, 5)
-
-        # Confirm that the output is correct
-        self.assertEqual(traced(3, 3), m(3, 3))
-
-    def test_call_to_assert_with_multiline_message(self):
-
-        class M(torch.nn.Module):
-            def forward(self, a, b):
-                error_msg = """
-An error message with
-terrible spacing
-                """
-                assert a == b, error_msg
-                return a + b
-        m = M()
-        traced = symbolic_trace_with_rewrite(m)
-
-        # Make sure the graph is well-formed
-        traced.graph.lint(traced)
-
-        # Check the IR to make sure there's a call_function node with target == "Assert"
-        self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
-        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
-        error_msg = """
-An error message with
-terrible spacing
-    """
-        traced(3, 3)
-        with self.assertRaisesRegex(AssertionError, error_msg):
-            traced(3, 5)
-
-        # Confirm that the output is correct
-        self.assertEqual(traced(3, 3), m(3, 3))
-
-    def test_traceable_function_with_nonstandard_name(self):
-        def foo(x):
-            return torch.relu(x)
-        traced = symbolic_trace_with_rewrite(foo)
-
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py
deleted file mode 100644
index 9981c88..0000000
--- a/torch/fx/experimental/rewriter.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import ast
-import inspect
-import textwrap
-import copy
-from types import FunctionType
-from typing import Union, Callable
-from torch.fx.symbolic_trace import Tracer
-from torch.fx.graph import Graph
-from torch.jit.frontend import normalize_source_lines
-import torch
-
-class AST_Rewriter(ast.NodeTransformer):
-    """
-    Take a FunctionType object representing a `forward` method, then
-    perform an AST rewrite to swap out nodes that are not symbolically
-    traceable with a callsite to the FX alternative.
-
-    To support swapping out an AST node, define a new `visit` method on
-    that node. For more details, see:
-    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
-    """
-
-    def rewrite(self, fn: FunctionType):
-
-        # Normalize the source lines
-        sourcelines, _ = inspect.getsourcelines(fn)
-        sourcelines = normalize_source_lines(sourcelines)
-        source = ''.join(sourcelines)
-        normalized_str = textwrap.dedent(source)
-
-        # Rewrite the original AST
-        source_ast = ast.parse(normalized_str)
-        dest_ast = ast.fix_missing_locations(self.visit(source_ast))
-
-        # Pull out the compiled fucntion from the newly-created Module
-        code = compile(dest_ast, "", "exec")
-        globals_dict = copy.copy(fn.__globals__)
-        keys_before = set(globals_dict.keys())
-        exec(code, globals_dict)
-        new_keys = list(set(globals_dict.keys()) - keys_before)
-        assert len(new_keys) == 1
-        fn_compiled = globals_dict[new_keys[0]]
-
-        # Return the correct FunctionType object
-        return fn_compiled
-
-    def visit_Assert(self, node):
-        """
-        Swap out the Assert node (Python's `assert`) with a callsite to the
-        symbolically-traceable torch.Assert function
-        """
-        # Create the Call node
-        call_node = ast.parse('torch.Assert()', mode='eval').body
-        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
-        call_node.args = [node.test, msg]
-
-        # Ensure that the new node conforms to the Python AST grammar
-        expr_wrapper = ast.Expr(value=call_node)
-
-        # Return the new Call node to signify that we want to use it as
-        # a replacement for the original Assert node
-        return ast.copy_location(expr_wrapper, node)
-
-
-class RewritingTracer(Tracer):
-    def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
-        return super().trace(_rewrite(root))
-
-
-def _rewrite(fn : Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
-    if isinstance(fn, torch.nn.Module):
-        # Rewrite this module's forward() and all of its recursive children's
-        # forward. Return the new rewritten module hierarchy.
-        def rewrite_module(m : torch.nn.Module):
-            class RewrittenModule(torch.nn.Module):
-                def __init__(self, orig):
-                    super().__init__()
-                    self.__dict__ = copy.copy(orig.__dict__)
-            RewrittenModule.forward = AST_Rewriter().rewrite(m.forward)
-            new_m = RewrittenModule(m)
-            for name, child in new_m.named_children():
-                new_m[name] = rewrite_module(child)
-            return new_m
-        return rewrite_module(fn)
-    else:
-        # Rewrite this single free function
-        return AST_Rewriter().rewrite(fn)
diff --git a/torch/fx/node.py b/torch/fx/node.py
index e434879..118b32f 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -20,7 +20,6 @@
     BaseArgumentTypes
 ]]
 
-
 class Node:
     def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
                  args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index f723c62..2865c97 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -6,7 +6,7 @@
 from .node import Argument
 from .graph import Graph
 from .graph_module import GraphModule
-from .proxy import Proxy, TracerBase
+from .proxy import TracerBase
 
 HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
 
@@ -206,10 +206,6 @@
             torch.nn.Module.__getattr__ = orig_getattr  # type: ignore
         return self.graph
 
-    def _proxy_placeholder(self, name: str, type_expr: Optional[Any] = None) -> Proxy:
-        return Proxy(self.create_node('placeholder', name, (), {}, type_expr=type_expr), self)
-
-
 # Symbolic tracing API
 #
 # Given an `nn.Module` or function instance `root`, this function will return a `GraphModule`