[fx] store Tracer class on Graph and GraphModule for package deserialization (#62497)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62497
Previously named: add support for custom tracer in __reduce_package__
Stores the Tracer class on each Graph created by a Tracer, and copies that class into the GraphModule's state, so that when a GraphModule is packaged by torch.package it can be reconstructed with the same Tracer and the same GraphModule class name.
Reviewed By: suo
Differential Revision: D30019214
fbshipit-source-id: eca09424ad30feb93524d481268b066ea55b892a
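
At a glance, the new behavior (a minimal sketch; `MyTracer` is a hypothetical subclass, and the full torch.package round trip is exercised in the new test below):

```python
import torch
from torch.fx import GraphModule, Tracer

class MyTracer(Tracer):  # hypothetical custom tracer
    def is_leaf_module(self, m, qualname):
        return True  # treat every submodule as an opaque leaf

module = torch.nn.Sequential(torch.nn.Linear(3, 3))
graph = MyTracer().trace(module)
assert graph._tracer_cls is MyTracer  # stamped on the Graph by Tracer.trace

gm = GraphModule(module, graph)
assert gm._tracer_cls is MyTracer  # copied into the GraphModule's state
```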
diff --git a/test/package/package_a/test_all_leaf_modules_tracer.py b/test/package/package_a/test_all_leaf_modules_tracer.py
new file mode 100644
index 0000000..ca8d8a0
--- /dev/null
+++ b/test/package/package_a/test_all_leaf_modules_tracer.py
@@ -0,0 +1,6 @@
+from torch.fx import Tracer
+
+
+class TestAllLeafModulesTracer(Tracer):
+ def is_leaf_module(self, m, qualname):
+ return True
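
`Tracer.is_leaf_module` decides whether a submodule is recorded as a single `call_module` node or traced through, so returning `True` unconditionally keeps every submodule opaque. A standalone illustration of the effect (the `Inner`/`Outer` modules are hypothetical, not part of this diff):

```python
import torch
from torch.fx import Tracer

class Inner(torch.nn.Module):
    def forward(self, x):
        return x + 1

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x):
        return self.inner(x) * 2

class AllLeaves(Tracer):  # same idea as TestAllLeafModulesTracer
    def is_leaf_module(self, m, qualname):
        return True

graph = AllLeaves().trace(Outer())
# Inner is kept as a call_module node instead of being inlined as `x + 1`,
# which the default Tracer would do for a user-defined module.
assert any(n.op == "call_module" and n.target == "inner" for n in graph.nodes)
```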
diff --git a/test/package/test_package_fx.py b/test/package/test_package_fx.py
index 7ad05a1..7f31014 100644
--- a/test/package/test_package_fx.py
+++ b/test/package/test_package_fx.py
@@ -121,6 +121,45 @@
packaged_dependency = pi.import_module("package_a.subpackage")
self.assertTrue(packaged_dependency is not package_a.subpackage)
+ def test_package_fx_custom_tracer(self):
+ from package_a.test_all_leaf_modules_tracer import TestAllLeafModulesTracer
+ from package_a.test_module import SimpleTest, ModWithTwoSubmodsAndTensor
+
+ class SpecialGraphModule(torch.fx.GraphModule):
+ def __init__(self, root, graph, info):
+ super().__init__(root, graph)
+ self.info = info
+
+ sub_module = SimpleTest()
+ module = ModWithTwoSubmodsAndTensor(
+ torch.ones(3),
+ sub_module,
+ sub_module,
+ )
+ tracer = TestAllLeafModulesTracer()
+ graph = tracer.trace(module)
+
+ self.assertEqual(graph._tracer_cls, TestAllLeafModulesTracer)
+
+ gm = SpecialGraphModule(module, graph, "secret")
+ self.assertEqual(gm._tracer_cls, TestAllLeafModulesTracer)
+
+ f = BytesIO()
+ with PackageExporter(f) as pe:
+ pe.intern("**")
+ pe.save_pickle("model", "model.pkl", gm)
+ f.seek(0)
+
+ pi = PackageImporter(f)
+ loaded_gm = pi.load_pickle("model", "model.pkl")
+ self.assertEqual(
+ type(loaded_gm).__name__, type(gm).__name__
+ )
+ self.assertEqual(loaded_gm.info, "secret")
+
+ input_x = torch.randn(3)
+ self.assertTrue(torch.allclose(loaded_gm(input_x), gm(input_x)))
+
if __name__ == "__main__":
run_tests()
diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py
index 56f925f..1f63a44 100644
--- a/torch/fx/_symbolic_trace.py
+++ b/torch/fx/_symbolic_trace.py
@@ -532,7 +532,9 @@
else:
self.root = torch.nn.Module()
fn = root
- self.graph = Graph()
+
+ tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
+ self.graph = Graph(tracer_cls=tracer_cls)
# When we encounter a Tensor value that's not a parameter, we look if it
# is some other attribute on the model. Construct a dict mapping Tensor
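
A note on the `getattr(self, '__class__', None)` line above: `__class__` resolves to the most-derived class of the instance, so a subclassed tracer records itself without any extra code. A plain-Python illustration of the semantics this relies on:

```python
class Base:
    def capture(self):
        # Mirrors `tracer_cls = getattr(self, '__class__', None)` above.
        return self.__class__

class Derived(Base):
    pass

assert Derived().capture() is Derived  # the subclass, not Base
```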
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index b7960f8..ba74740 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -16,6 +16,7 @@
if TYPE_CHECKING:
from .graph_module import GraphModule # noqa: F401
+ from ._symbolic_trace import Tracer # noqa: F401
# Mapping of builtins to their `typing` equivalent.
@@ -282,7 +283,7 @@
For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
"""
- def __init__(self, owning_module: Optional["GraphModule"] = None):
+ def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None):
"""
Construct an empty Graph.
"""
@@ -293,6 +294,7 @@
self._graph_namespace = _Namespace()
self._owners = 0
self._owning_module = owning_module
+ self._tracer_cls = tracer_cls
self._pytree_info: Optional[_PyTreeInfo] = None
@property
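
The new keyword can also be supplied when constructing a Graph by hand, though normally `Tracer.trace` fills it in; a small sketch:

```python
from torch.fx import Graph, Tracer

g = Graph(tracer_cls=Tracer)
assert g._tracer_cls is Tracer
```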
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 09d1b6d..c9d8655 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -71,26 +71,25 @@
return _deserialize_graph_module(forward, body)
-def reduce_package_graph_module(importer: PackageImporter,
- body: Dict[Any, Any],
- generated_module_name: str) -> torch.nn.Module:
+def reduce_package_graph_module(
+ importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
+) -> torch.nn.Module:
forward = importer.import_module(generated_module_name).forward
return _deserialize_graph_module(forward, body)
-def reduce_deploy_graph_module(importer: PackageImporter,
- body: Dict[Any, Any],
- import_block: str,
- tracer_cls: Type) -> torch.nn.Module:
+def reduce_deploy_graph_module(
+ importer: PackageImporter, body: Dict[Any, Any], import_block: str
+) -> torch.nn.Module:
ns = dict()
ns["__builtins__"] = importer.patched_builtins
fn_src = body.get('_code')
assert fn_src is not None
forward = _forward_from_src(import_block + fn_src, ns)
- return _deserialize_graph_module(forward, body, tracer_cls)
+ return _deserialize_graph_module(forward, body)
-def _deserialize_graph_module(forward, body: Dict[Any, Any], tracer_cls: Type = None) -> torch.nn.Module:
+def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
"""
Deserialize a GraphModule given the dictionary of the original module,
using the code to reconstruct the graph. We delete the actual graph before
@@ -107,10 +106,13 @@
# Try to retrieve the forward source in a backward-compatible way
CodeOnlyModule.forward = forward
+ tracer_cls = body.get('_tracer_cls')
if tracer_cls is None:
from ._symbolic_trace import Tracer
tracer_cls = Tracer
+ graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
+
# This is a workaround for a mypy linter issue related to
# passing base class as an argument - https://github.com/python/mypy/issues/5865.
cls_tracer : Any = tracer_cls
@@ -122,7 +124,22 @@
return True
com = CodeOnlyModule(body)
- return GraphModule(com, KeepModules().trace(com))
+
+ graph = KeepModules().trace(com)
+
+ # Manually set Tracer class on the reconstructed Graph, to avoid
+ # referencing the private local subclass KeepModules.
+ graph._tracer_cls = tracer_cls
+ gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
+
+ # The GraphModule constructor only retains attributes referenced by the graph.
+ # In this case, our goal is to return a GraphModule that is as close as
+ # possible to the one that was put into the package. If any additional
+ # attributes were present in body, we should keep them.
+ for k, v in body.items():
+ if not hasattr(gm, k):
+ setattr(gm, k, v)
+ return gm
# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
@@ -251,6 +268,10 @@
self.graph = graph
+ # Store the Tracer class responsible for creating this Graph separately,
+ # because torch.package serializes a GraphModule without retaining its Graph.
+ self._tracer_cls = self.graph._tracer_cls
+
# TorchScript breaks trying to compile the graph setter because of the
# continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
#
@@ -548,12 +569,12 @@
# Passing Tracer as an argument allows subclasses extending fx.GraphModule
# to define their own Tracer (extending fx.Tracer).
- def __reduce_deploy__(self, importer: Importer, tracer_cls: Type = None):
+ def __reduce_deploy__(self, importer: Importer):
dict_without_graph = self.__dict__.copy()
python_code = self.recompile()
import_block = _format_import_block(python_code.globals, importer)
del dict_without_graph['_graph']
- return (reduce_deploy_graph_module, (dict_without_graph, import_block, tracer_cls))
+ return (reduce_deploy_graph_module, (dict_without_graph, import_block))
def __reduce_package__(self, exporter: PackageExporter):
generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
@@ -563,6 +584,7 @@
exporter.save_source_string(generated_module_name, module_code)
dict_without_graph = self.__dict__.copy()
+ dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
del dict_without_graph['_graph']
return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
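
For background, `__reduce_package__` follows the same shape as pickle's `__reduce__` protocol: return a callable plus the arguments needed to rebuild the object (torch.package prepends the importer when it invokes the callable). A generic, standard-library illustration of that pattern (the `Rebuildable` class is hypothetical):

```python
import pickle

class Rebuildable:
    def __init__(self, payload):
        self.payload = payload

    def __reduce__(self):
        # Same shape as __reduce_package__: (reconstruction fn, its args).
        return (Rebuildable, (self.payload,))

obj = pickle.loads(pickle.dumps(Rebuildable("secret")))
assert obj.payload == "secret"
```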