[export] Move attrs to properties and add BC decorator (#106170)
@SherlockNoMad mentioned that it is not backward-compatibility (BC) safe to access these attributes directly, so I moved them to `@property` fields and added a `@compatibility` decorator. For now, only `graph_module` and `graph` are marked `is_backward_compatible=True`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106170
Approved by: https://github.com/SherlockNoMad
diff --git a/torch/_export/exported_program.py b/torch/_export/exported_program.py
index a9078a5..9d280e6 100644
--- a/torch/_export/exported_program.py
+++ b/torch/_export/exported_program.py
@@ -7,6 +7,7 @@
import torch
import torch.fx._pytree as fx_pytree
+from torch.fx._compatibility import compatibility
import torch.utils._pytree as pytree
from torch import fx
from torch._functorch.aot_autograd import FQN, GraphInputName, GraphOutputName
@@ -103,13 +104,56 @@
):
# Remove codegen related things from the graph. It should just be a flat graph.
graph._codegen = torch.fx.graph.CodeGen()
- self.graph_module = torch.fx.GraphModule(root, graph)
+ self._graph_module = torch.fx.GraphModule(root, graph)
- self.graph_signature: ExportGraphSignature = graph_signature
- self.call_spec: CallSpec = call_spec
- self.state_dict: Dict[str, Any] = state_dict
- self.range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
- self.equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints
+ self._graph_signature: ExportGraphSignature = graph_signature
+ self._call_spec: CallSpec = call_spec
+ self._state_dict: Dict[str, Any] = state_dict
+ self._range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
+ self._equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints
+
+ @property
+ @compatibility(is_backward_compatible=True)
+ def graph_module(self):
+ return self._graph_module
+
+ @graph_module.setter
+ def graph_module(self, gm: torch.fx.GraphModule) -> None:
+ """
+ Set the underlying ``GraphModule`` for this ``ExportedProgram``.
+ """
+ assert isinstance(gm, torch.fx.GraphModule), f'Expected a GraphModule instance, but got {type(gm)}'
+ self._graph_module = gm
+
+ @property
+ @compatibility(is_backward_compatible=True)
+ def graph(self):
+ return self.graph_module.graph
+
+ @property
+ @compatibility(is_backward_compatible=False)
+ def graph_signature(self):
+ return self._graph_signature
+
+ @property
+ @compatibility(is_backward_compatible=False)
+ def state_dict(self):
+ return self._state_dict
+
+ @property
+ @compatibility(is_backward_compatible=False)
+ def call_spec(self):
+ return self._call_spec
+
+ @property
+ @compatibility(is_backward_compatible=False)
+ def range_constraints(self):
+ return self._range_constraints
+
+ @property
+ @compatibility(is_backward_compatible=False)
+ def equality_constraints(self):
+ return self._equality_constraints
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if self.call_spec.in_spec is not None:
@@ -189,11 +233,6 @@
)
return new_ep
-
- @property
- def graph(self):
- return self.graph_module.graph
-
def transform(self, *passes: PassType) -> "ExportedProgram":
pm = PassManager(list(passes))
res = pm(self.graph_module)
diff --git a/torch/_export/serde/upgrade.py b/torch/_export/serde/upgrade.py
index d9afaf0..b33b3e0 100644
--- a/torch/_export/serde/upgrade.py
+++ b/torch/_export/serde/upgrade.py
@@ -197,6 +197,6 @@
# NB: we have to retrace the graph_module instead of ep because of some failure. Also, we need to turn of
# _add_runtime_assertions because dynamo is not happy with sym_size.int.
exported_program = export(upgraded_program.graph_module, inputs, {})
- exported_program.call_spec = upgraded_program.call_spec
+ exported_program._call_spec = upgraded_program.call_spec
return exported_program