Handle cycles when flattening a tf.Module with paths.
PiperOrigin-RevId: 391329358
Change-Id: I2d4c342dbb33cb65a4339ea7bb63e4283b4a3a3f
diff --git a/tensorflow/python/module/module.py b/tensorflow/python/module/module.py
index ad46af0..2c14b7a 100644
--- a/tensorflow/python/module/module.py
+++ b/tensorflow/python/module/module.py
@@ -352,14 +352,55 @@
with_path,
expand_composites,
module_path=(),
- seen=None):
- """Implementation of `flatten`."""
+ seen=None,
+ recursion_stack=None):
+ """Implementation of `flatten`.
+
+ Args:
+ module: Current module to process.
+ recursive: Whether to recurse into child modules or not.
+ predicate: (Optional) If set then only values matching predicate are
+ yielded. A value of `None` (the default) means no items will be
+ filtered.
+ attribute_traversal_key: (Optional) Method to rekey object attributes
+ before they are sorted. Contract is the same as `key` argument to
+ builtin `sorted` and only applies to object properties.
+ attributes_to_ignore: object attributes to ignored.
+ with_path: (Optional) Whether to include the path to the object as well
+ as the object itself. If `with_path` is `True` then leaves will not be
+ de-duplicated (e.g. if the same leaf instance is reachable via multiple
+ modules then it will be yielded multiple times with different paths).
+ expand_composites: If true, then composite tensors are expanded into their
+ component tensors.
+ module_path: The path to the current module as a tuple.
+ seen: A set containing all leaf IDs seen so far.
+ recursion_stack: A list containing all module IDs associated with the
+ current call stack.
+
+ Yields:
+ Matched leaves with the optional corresponding paths of the current module
+ and optionally all its submodules.
+ """
+ module_id = id(module)
if seen is None:
- seen = set([id(module)])
+ seen = set([module_id])
module_dict = vars(module)
submodules = []
+ if recursion_stack is None:
+ recursion_stack = []
+
+ # When calling `_flatten_module` with `with_path=False`, the global lookup
+ # table `seen` guarantees the uniqueness of the matched objects.
+ # In the case of `with_path=True`, there might be multiple paths associated
+ # with the same predicate, so we don't stop traversing according to `seen`
+ # to make sure all these paths are returned.
+ # When there are cycles connecting submodules, we break cycles by avoiding
+ # following back edges (links pointing to a node in `recursion_stack`).
+ if module_id in recursion_stack:
+ recursive = False
+
for key in sorted(module_dict, key=attribute_traversal_key):
if key in attributes_to_ignore:
continue
@@ -377,7 +418,6 @@
for leaf_path, leaf in leaves:
leaf_path = (key,) + leaf_path
- # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
if not with_path:
leaf_id = id(leaf)
if leaf_id in seen:
@@ -394,6 +434,8 @@
# Walk direct properties first then recurse.
submodules.append((module_path + leaf_path, leaf))
+ recursion_stack.append(module_id)
+
for submodule_path, submodule in submodules:
subvalues = _flatten_module(
submodule,
@@ -404,8 +446,11 @@
with_path=with_path,
expand_composites=expand_composites,
module_path=submodule_path,
- seen=seen)
+ seen=seen,
+ recursion_stack=recursion_stack)
for subvalue in subvalues:
# Predicate is already tested for these values.
yield subvalue
+
+ recursion_stack.pop()
diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py
index 77ac2a9..1a2a0aa 100644
--- a/tensorflow/python/module/module_test.py
+++ b/tensorflow/python/module/module_test.py
@@ -537,6 +537,28 @@
("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"],
("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},)
+ def test_cycles_with_path(self):
+ mod = module.Module()
+ mod.w = variables.Variable(1.)
+ mod.encoder = module.Module()
+ mod.encoder.w = [({"k": mod.w}, {"k": mod.w})]
+ mod.decoder = mod.encoder
+
+ # This introduces two cycles: on mod.encoder.mod and mod.decoder.mod.
+ mod.decoder.mod = mod
+
+ state_dict = dict(
+ mod._flatten(with_path=True, predicate=module._is_variable))
+
+ self.assertEqual(state_dict,
+ {("w",): mod.w,
+ ("encoder", "mod", "w"): mod.encoder.mod.w,
+ ("decoder", "mod", "w"): mod.decoder.mod.w,
+ ("encoder", "w", 0, 0, "k"): mod.encoder.w[0][0]["k"],
+ ("encoder", "w", 0, 1, "k"): mod.encoder.w[0][1]["k"],
+ ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"],
+ ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},)
+
def test_raises_error_with_path(self):
if six.PY2:
class NonOrderable(object):