[fx] store Tracer class on Graph and GraphModule for package deserialization (#62497)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62497

Previously named: add support for custom tracer in __reduce_package__

Stores the Tracer class on each Graph created by a Tracer, and copies that Tracer class into the GraphModule's state, so that when a GraphModule is packaged by torch.package it can be reconstructed with the same Tracer and under the same GraphModule class name.
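
For illustration, a minimal sketch of the round trip this enables (not part of this diff; `mypkg.tracing`, `MyTracer`, and `MyModule` are hypothetical names, and both classes must live in an importable module, since torch.package cannot package classes defined in `__main__`):

```python
from io import BytesIO

import torch
from torch.fx import GraphModule
from torch.package import PackageExporter, PackageImporter

# Hypothetical importable module defining:
#   class MyTracer(torch.fx.Tracer): ...   (e.g. overriding is_leaf_module)
#   class MyModule(torch.nn.Module): ...
from mypkg.tracing import MyModule, MyTracer

module = MyModule()
graph = MyTracer().trace(module)
assert graph._tracer_cls is MyTracer  # the Tracer class is now stored on the Graph

gm = GraphModule(module, graph)
assert gm._tracer_cls is MyTracer  # ...and copied into the GraphModule's state

f = BytesIO()
with PackageExporter(f) as pe:
    pe.intern("**")
    pe.save_pickle("model", "model.pkl", gm)
f.seek(0)

loaded = PackageImporter(f).load_pickle("model", "model.pkl")
# On load, the module is re-traced with a subclass of MyTracer (recovered from
# '_tracer_cls' in the pickled state) and rebuilt under the original class name.
```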

Reviewed By: suo

Differential Revision: D30019214

fbshipit-source-id: eca09424ad30feb93524d481268b066ea55b892a
diff --git a/test/package/package_a/test_all_leaf_modules_tracer.py b/test/package/package_a/test_all_leaf_modules_tracer.py
new file mode 100644
index 0000000..ca8d8a0
--- /dev/null
+++ b/test/package/package_a/test_all_leaf_modules_tracer.py
@@ -0,0 +1,8 @@
+from torch.fx import Tracer
+
+
+class TestAllLeafModulesTracer(Tracer):
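+    # Treat every submodule as a leaf so calls to submodules are recorded as
+    # call_module nodes instead of being traced through.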
+    def is_leaf_module(self, m, qualname):
+        return True
diff --git a/test/package/test_package_fx.py b/test/package/test_package_fx.py
index 7ad05a1..7f31014 100644
--- a/test/package/test_package_fx.py
+++ b/test/package/test_package_fx.py
@@ -121,6 +121,45 @@
         packaged_dependency = pi.import_module("package_a.subpackage")
         self.assertTrue(packaged_dependency is not package_a.subpackage)
 
+    def test_package_fx_custom_tracer(self):
+        from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
+        from package_a.test_module import SimpleTest, ModWithTwoSubmodsAndTensor
+
+        class SpecialGraphModule(torch.fx.GraphModule):
+            def __init__(self, root, graph, info):
+                super().__init__(root, graph)
+                self.info = info
+
+        sub_module = SimpleTest()
+        module = ModWithTwoSubmodsAndTensor(
+            torch.ones(3),
+            sub_module,
+            sub_module,
+        )
+        tracer = TestAllLeafModulesTracer()
+        graph = tracer.trace(module)
+
+        self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)
+
+        gm = SpecialGraphModule(module, graph, "secret")
+        self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)
+
+        f = BytesIO()
+        with PackageExporter(f) as pe:
+            pe.intern("**")
+            pe.save_pickle("model", "model.pkl", gm)
+        f.seek(0)
+
+        pi = PackageImporter(f)
+        loaded_gm = pi.load_pickle("model", "model.pkl")
+        self.assertEqual(
+            type(loaded_gm).__name__, SpecialGraphModule.__name__
+        )
+        self.assertEqual(loaded_gm.info, "secret")
+
+        input_x = torch.randn(3)
+        self.assertTrue(torch.allclose(loaded_gm(input_x), gm(input_x)))
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py
index 56f925f..1f63a44 100644
--- a/torch/fx/_symbolic_trace.py
+++ b/torch/fx/_symbolic_trace.py
@@ -532,7 +532,9 @@
         else:
             self.root = torch.nn.Module()
             fn = root
-        self.graph = Graph()
+
+        tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
+        self.graph = Graph(tracer_cls=tracer_cls)
 
         # When we encounter a Tensor value that's not a parameter, we look if it
         # is some other attribute on the model. Construct a dict mapping Tensor
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index b7960f8..ba74740 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -16,6 +16,7 @@
 
 if TYPE_CHECKING:
     from .graph_module import GraphModule  # noqa: F401
+    from ._symbolic_trace import Tracer   # noqa: F401
 
 
 # Mapping of builtins to their `typing` equivalent.
@@ -282,7 +283,7 @@
 
     For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
     """
-    def __init__(self, owning_module: Optional["GraphModule"] = None):
+    def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None):
         """
         Construct an empty Graph.
         """
@@ -293,6 +294,7 @@
         self._graph_namespace = _Namespace()
         self._owners = 0
         self._owning_module = owning_module
+        self._tracer_cls = tracer_cls
         self._pytree_info: Optional[_PyTreeInfo] = None
 
     @property
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 09d1b6d..c9d8655 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -71,26 +71,25 @@
     return _deserialize_graph_module(forward, body)
 
 
-def reduce_package_graph_module(importer: PackageImporter,
-                                body: Dict[Any, Any],
-                                generated_module_name: str) -> torch.nn.Module:
+def reduce_package_graph_module(
+    importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
+) -> torch.nn.Module:
     forward = importer.import_module(generated_module_name).forward
     return _deserialize_graph_module(forward, body)
 
 
-def reduce_deploy_graph_module(importer: PackageImporter,
-                               body: Dict[Any, Any],
-                               import_block: str,
-                               tracer_cls: Type) -> torch.nn.Module:
+def reduce_deploy_graph_module(
+    importer: PackageImporter, body: Dict[Any, Any], import_block: str
+) -> torch.nn.Module:
     ns = dict()
     ns["__builtins__"] = importer.patched_builtins
     fn_src = body.get('_code')
     assert fn_src is not None
     forward = _forward_from_src(import_block + fn_src, ns)
-    return _deserialize_graph_module(forward, body, tracer_cls)
+    return _deserialize_graph_module(forward, body)
 
 
-def _deserialize_graph_module(forward, body: Dict[Any, Any], tracer_cls: Type = None) -> torch.nn.Module:
+def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
     """
     Deserialize a GraphModule given the dictionary of the original module,
     using the code to reconstruct the graph. We delete the actual graph before
@@ -107,10 +106,13 @@
     # Try to retrieve the forward source in a backward-compatible way
     CodeOnlyModule.forward = forward
 
+    tracer_cls = body.get('_tracer_cls')
     if tracer_cls is None:
         from ._symbolic_trace import Tracer
         tracer_cls = Tracer
 
+    graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
+
     # This is a workaround for a mypy linter issue related to
     # passing base class as an argument - https://github.com/python/mypy/issues/5865.
     cls_tracer : Any = tracer_cls
@@ -122,7 +124,22 @@
             return True
 
     com = CodeOnlyModule(body)
-    return GraphModule(com, KeepModules().trace(com))
+
+    graph = KeepModules().trace(com)
+
+    # Manually set Tracer class on the reconstructed Graph, to avoid
+    # referencing the private local subclass KeepModules.
+    graph._tracer_cls = tracer_cls
+    gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
+
+    # The GraphModule constructor only retains attributes referenced by the graph.
+    # Here the goal is to return a GraphModule as close as possible to the one
+    # that was put into the package, so if any additional attributes were present
+    # in `body`, keep them as well.
+    for k, v in body.items():
+        if not hasattr(gm, k):
+            setattr(gm, k, v)
+    return gm
 
 # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
 # This installs empty Modules where none exist yet if they are subpaths of target
@@ -251,6 +268,10 @@
 
         self.graph = graph
 
+        # Store the Tracer class responsible for creating this Graph separately on the
+        # GraphModule, because torch.package serializes a GraphModule without retaining the Graph.
+        self._tracer_cls = self.graph._tracer_cls
+
     # TorchScript breaks trying to compile the graph setter because of the
     # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
     #
@@ -548,12 +569,10 @@
 
-    # Passing Tracer as argument allows subclasses extending fx.GraphModule
-    # define their own Tracer (extending fx.Tracer).
-    def __reduce_deploy__(self, importer: Importer, tracer_cls: Type = None):
+    def __reduce_deploy__(self, importer: Importer):
         dict_without_graph = self.__dict__.copy()
         python_code = self.recompile()
         import_block = _format_import_block(python_code.globals, importer)
         del dict_without_graph['_graph']
-        return (reduce_deploy_graph_module, (dict_without_graph, import_block, tracer_cls))
+        return (reduce_deploy_graph_module, (dict_without_graph, import_block))
 
     def __reduce_package__(self, exporter: PackageExporter):
         generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
@@ -563,6 +584,8 @@
         exporter.save_source_string(generated_module_name, module_code)
 
         dict_without_graph = self.__dict__.copy()
+        dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
         del dict_without_graph['_graph']
         return (reduce_package_graph_module, (dict_without_graph, generated_module_name))