Allow fx.Graph.owning_module to be used as attribute. (#86822)
Summary:
The current behavior of owning_module setter is difficult to understand: it changes the owning_module to None if owners is not 0 but increments the owners count. If the owning_module is None, the owners count should be 0 as none of them is accessible. On the other hand, if the owners count increases, the owning_module should be a collection (e.g. a list).
This diff changes owning_module to be a normal attribute. The semantic is that graph can have **at most one** owning module and can be assigned to new module.
The alternative is to use a list to represent the owning_modules of a graph but it breaks backward compatibility and the exact use cases of having multiple owning_modules are not clear.
Test Plan: Test with CI.
Differential Revision: D40200624
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86822
Approved by: https://github.com/tugsbayasgalan
diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py
index 56ffbca..a969803 100644
--- a/torch/fx/experimental/const_fold.py
+++ b/torch/fx/experimental/const_fold.py
@@ -24,11 +24,6 @@
fx_const_folded_attrs_name: str = None,
device_for_folded_attrs: str = "cuda",
):
- # In init, we set graph's owning module to root which will make graph's
- # owning module be None because graph already have a owning module. We
- # need owning module to run DCE. To work around we set the number of
- # graph's owners to 0.
- graph._owners = 0
super().__init__(root, graph)
self.const_subgraph_module = (
None
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 271c43e..9397050 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -697,7 +697,6 @@
self._insert = self._root.prepend
self._len = 0
self._graph_namespace = _Namespace()
- self._owners = 0
self._owning_module = owning_module
self._tracer_cls = tracer_cls
self._tracer_extras = tracer_extras
@@ -705,18 +704,11 @@
@property
def owning_module(self):
- """
- Return the module that owns this ``GraphModule``, if there is one,
- ``None`` if there is no owning module or if there are multiple owning
- modules.
- """
return self._owning_module
@owning_module.setter
def owning_module(self, mod: Optional["GraphModule"]):
- if mod:
- self._owning_module = mod if not self._owners else None
- self._owners += 1
+ self._owning_module = mod
@property
def nodes(self) -> _node_list: