Revert D24665950: Create prototype for AST rewriter
Test Plan: revert-hammer
Differential Revision:
D24665950 (https://github.com/pytorch/pytorch/commit/54feb00bbde9f6220721259f4a352288ab610c03)
Original commit changeset: b72110436126
fbshipit-source-id: 961412df006acd33c91a745c809832d5c6494c76
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 54b140e..d5d3dae 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -2,14 +2,8 @@
from torch.fx.symbolic_trace import symbolic_trace
from torch.fx.experimental import GraphManipulation
from torch.fx.experimental.Partitioner import Partitioner, Device, PartitionerConfig
-from torch.fx.experimental.rewriter import RewritingTracer
-from torch.fx.graph_module import GraphModule
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase
-from typing import Union, Callable
-
-def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
- return GraphModule(root if isinstance(root, torch.nn.Module) else torch.nn.Module(), RewritingTracer().trace(root))
class TestFXExperimental(JitTestCase):
def test_find_single_partition(self):
@@ -168,110 +162,5 @@
self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
assert len(module_with_submodules.graph.nodes) == 24
- def test_call_to_assert_no_msg(self):
-
- class M(torch.nn.Module):
- def forward(self, a, b):
- assert a == b
- return a + b
- m = M()
- traced = symbolic_trace_with_rewrite(m)
-
- # Make sure the graph is well-formed
- traced.graph.lint(traced)
-
- # Check the IR to make sure there's a call_function node with target == "Assert"
- self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
- # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
- traced(3, 3)
- with self.assertRaisesRegex(AssertionError, ""):
- traced(3, 5)
-
- # Confirm that the output is correct
- self.assertEqual(traced(3, 3), m(3, 3))
-
- def test_call_to_assert_with_msg(self):
-
- class M(torch.nn.Module):
- def forward(self, a, b):
- assert a == b, "test message"
- return a + b
- m = M()
- traced = symbolic_trace_with_rewrite(m)
-
- # Make sure the graph is well-formed
- traced.graph.lint(traced)
-
- # Check the IR to make sure there's a call_function node with target == "Assert"
- self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
- # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
- traced(3, 3)
- with self.assertRaisesRegex(AssertionError, "test message"):
- traced(3, 5)
-
- # Confirm that the output is correct
- self.assertEqual(traced(3, 3), m(3, 3))
-
- def test_call_to_assert_with_empty_msg(self):
-
- class M(torch.nn.Module):
- def forward(self, a, b):
- assert a == b, ""
- return a + b
- m = M()
- traced = symbolic_trace_with_rewrite(m)
-
- # Make sure the graph is well-formed
- traced.graph.lint(traced)
-
- # Check the IR to make sure there's a call_function node with target == "Assert"
- self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
- # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
- traced(3, 3)
- with self.assertRaisesRegex(AssertionError, ""):
- traced(3, 5)
-
- # Confirm that the output is correct
- self.assertEqual(traced(3, 3), m(3, 3))
-
- def test_call_to_assert_with_multiline_message(self):
-
- class M(torch.nn.Module):
- def forward(self, a, b):
- error_msg = """
-An error message with
-terrible spacing
- """
- assert a == b, error_msg
- return a + b
- m = M()
- traced = symbolic_trace_with_rewrite(m)
-
- # Make sure the graph is well-formed
- traced.graph.lint(traced)
-
- # Check the IR to make sure there's a call_function node with target == "Assert"
- self.assertTrue(any(node.op == "call_function" and node.target == torch.Assert for node in traced.graph.nodes))
-
- # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
- error_msg = """
-An error message with
-terrible spacing
- """
- traced(3, 3)
- with self.assertRaisesRegex(AssertionError, error_msg):
- traced(3, 5)
-
- # Confirm that the output is correct
- self.assertEqual(traced(3, 3), m(3, 3))
-
- def test_traceable_function_with_nonstandard_name(self):
- def foo(x):
- return torch.relu(x)
- traced = symbolic_trace_with_rewrite(foo)
-
if __name__ == '__main__':
run_tests()
diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py
deleted file mode 100644
index 9981c88..0000000
--- a/torch/fx/experimental/rewriter.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import ast
-import inspect
-import textwrap
-import copy
-from types import FunctionType
-from typing import Union, Callable
-from torch.fx.symbolic_trace import Tracer
-from torch.fx.graph import Graph
-from torch.jit.frontend import normalize_source_lines
-import torch
-
-class AST_Rewriter(ast.NodeTransformer):
- """
- Take a FunctionType object representing a `forward` method, then
- perform an AST rewrite to swap out nodes that are not symbolically
- traceable with a callsite to the FX alternative.
-
- To support swapping out an AST node, define a new `visit` method on
- that node. For more details, see:
- https://docs.python.org/3/library/ast.html#ast.NodeTransformer
- """
-
- def rewrite(self, fn: FunctionType):
-
- # Normalize the source lines
- sourcelines, _ = inspect.getsourcelines(fn)
- sourcelines = normalize_source_lines(sourcelines)
- source = ''.join(sourcelines)
- normalized_str = textwrap.dedent(source)
-
- # Rewrite the original AST
- source_ast = ast.parse(normalized_str)
- dest_ast = ast.fix_missing_locations(self.visit(source_ast))
-
- # Pull out the compiled fucntion from the newly-created Module
- code = compile(dest_ast, "", "exec")
- globals_dict = copy.copy(fn.__globals__)
- keys_before = set(globals_dict.keys())
- exec(code, globals_dict)
- new_keys = list(set(globals_dict.keys()) - keys_before)
- assert len(new_keys) == 1
- fn_compiled = globals_dict[new_keys[0]]
-
- # Return the correct FunctionType object
- return fn_compiled
-
- def visit_Assert(self, node):
- """
- Swap out the Assert node (Python's `assert`) with a callsite to the
- symbolically-traceable torch.Assert function
- """
- # Create the Call node
- call_node = ast.parse('torch.Assert()', mode='eval').body
- msg = node.msg if node.msg else ast.Constant(value="", kind=None)
- call_node.args = [node.test, msg]
-
- # Ensure that the new node conforms to the Python AST grammar
- expr_wrapper = ast.Expr(value=call_node)
-
- # Return the new Call node to signify that we want to use it as
- # a replacement for the original Assert node
- return ast.copy_location(expr_wrapper, node)
-
-
-class RewritingTracer(Tracer):
- def trace(self, root: Union[torch.nn.Module, Callable]) -> Graph:
- return super().trace(_rewrite(root))
-
-
-def _rewrite(fn : Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
- if isinstance(fn, torch.nn.Module):
- # Rewrite this module's forward() and all of its recursive children's
- # forward. Return the new rewritten module hierarchy.
- def rewrite_module(m : torch.nn.Module):
- class RewrittenModule(torch.nn.Module):
- def __init__(self, orig):
- super().__init__()
- self.__dict__ = copy.copy(orig.__dict__)
- RewrittenModule.forward = AST_Rewriter().rewrite(m.forward)
- new_m = RewrittenModule(m)
- for name, child in new_m.named_children():
- new_m[name] = rewrite_module(child)
- return new_m
- return rewrite_module(fn)
- else:
- # Rewrite this single free function
- return AST_Rewriter().rewrite(fn)
diff --git a/torch/fx/node.py b/torch/fx/node.py
index e434879..118b32f 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -20,7 +20,6 @@
BaseArgumentTypes
]]
-
class Node:
def __init__(self, graph: 'Graph', name: str, op: str, target: Target,
args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py
index f723c62..2865c97 100644
--- a/torch/fx/symbolic_trace.py
+++ b/torch/fx/symbolic_trace.py
@@ -6,7 +6,7 @@
from .node import Argument
from .graph import Graph
from .graph_module import GraphModule
-from .proxy import Proxy, TracerBase
+from .proxy import TracerBase
HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS
@@ -206,10 +206,6 @@
torch.nn.Module.__getattr__ = orig_getattr # type: ignore
return self.graph
- def _proxy_placeholder(self, name: str, type_expr: Optional[Any] = None) -> Proxy:
- return Proxy(self.create_node('placeholder', name, (), {}, type_expr=type_expr), self)
-
-
# Symbolic tracing API
#
# Given an `nn.Module` or function instance `root`, this function will return a `GraphModule`