Handle cycles when flattening a tf.Module with paths.

PiperOrigin-RevId: 391329358
Change-Id: I2d4c342dbb33cb65a4339ea7bb63e4283b4a3a3f
diff --git a/tensorflow/python/module/module.py b/tensorflow/python/module/module.py
index ad46af0..2c14b7a 100644
--- a/tensorflow/python/module/module.py
+++ b/tensorflow/python/module/module.py
@@ -352,14 +352,55 @@
                     with_path,
                     expand_composites,
                     module_path=(),
-                    seen=None):
-  """Implementation of `flatten`."""
+                    seen=None,
+                    recursion_stack=None):
+  """Implementation of `flatten`.
+
+  Args:
+    module: Current module to process.
+    recursive: Whether to recurse into child modules or not.
+    predicate: (Optional) If set then only values matching predicate are
+      yielded. A value of `None` (the default) means no items will be
+      filtered.
+    attribute_traversal_key: (Optional) Method to rekey object attributes
+      before they are sorted. Contract is the same as `key` argument to
+      builtin `sorted` and only applies to object properties.
+    attributes_to_ignore: object attributes to be ignored.
+    with_path: (Optional) Whether to include the path to the object as well
+      as the object itself. If `with_path` is `True` then leaves will not be
+      de-duplicated (e.g. if the same leaf instance is reachable via multiple
+      modules then it will be yielded multiple times with different paths).
+    expand_composites: If true, then composite tensors are expanded into their
+      component tensors.
+    module_path: The path to the current module as a tuple.
+    seen: A set containing all leaf IDs seen so far.
+    recursion_stack: A list containing all module IDs associated with the
+      current call stack.
+
+  Yields:
+    Matched leaves with the optional corresponding paths of the current module
+    and optionally all its submodules.
+  """
+  module_id = id(module)
   if seen is None:
-    seen = set([id(module)])
+    seen = set([module_id])
 
   module_dict = vars(module)
   submodules = []
 
+  if recursion_stack is None:
+    recursion_stack = []
+
+  # When calling `_flatten_module` with `with_path=False`, the global lookup
+  # table `seen` guarantees the uniqueness of the matched objects.
+  # In the case of `with_path=True`, the same matched object might be
+  # reachable via multiple paths, so we don't stop traversing according to
+  # `seen` to make sure all these paths are returned.
+  # When there are cycles connecting submodules, we break cycles by avoiding
+  # following back edges (links pointing to a node in `recursion_stack`).
+  if module_id in recursion_stack:
+    recursive = False
+
   for key in sorted(module_dict, key=attribute_traversal_key):
     if key in attributes_to_ignore:
       continue
@@ -377,7 +418,6 @@
     for leaf_path, leaf in leaves:
       leaf_path = (key,) + leaf_path
 
-      # TODO(tomhennigan) Handle cycles for `with_path=True` (e.g. `a.a = a`).
       if not with_path:
         leaf_id = id(leaf)
         if leaf_id in seen:
@@ -394,6 +434,8 @@
         # Walk direct properties first then recurse.
         submodules.append((module_path + leaf_path, leaf))
 
+  recursion_stack.append(module_id)
+
   for submodule_path, submodule in submodules:
     subvalues = _flatten_module(
         submodule,
@@ -404,8 +446,11 @@
         with_path=with_path,
         expand_composites=expand_composites,
         module_path=submodule_path,
-        seen=seen)
+        seen=seen,
+        recursion_stack=recursion_stack)
 
     for subvalue in subvalues:
       # Predicate is already tested for these values.
       yield subvalue
+
+  recursion_stack.pop()
diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py
index 77ac2a9..1a2a0aa 100644
--- a/tensorflow/python/module/module_test.py
+++ b/tensorflow/python/module/module_test.py
@@ -537,6 +537,28 @@
                       ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"],
                       ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},)
 
+  def test_cycles_with_path(self):
+    mod = module.Module()
+    mod.w = variables.Variable(1.)
+    mod.encoder = module.Module()
+    mod.encoder.w = [({"k": mod.w}, {"k": mod.w})]
+    mod.decoder = mod.encoder
+
+    # This adds two cycles: through mod.encoder.mod and mod.decoder.mod.
+    mod.decoder.mod = mod
+
+    state_dict = dict(
+        mod._flatten(with_path=True, predicate=module._is_variable))
+
+    self.assertEqual(state_dict,
+                     {("w",): mod.w,
+                      ("encoder", "mod", "w"): mod.encoder.mod.w,
+                      ("decoder", "mod", "w"): mod.decoder.mod.w,
+                      ("encoder", "w", 0, 0, "k"): mod.encoder.w[0][0]["k"],
+                      ("encoder", "w", 0, 1, "k"): mod.encoder.w[0][1]["k"],
+                      ("decoder", "w", 0, 0, "k"): mod.decoder.w[0][0]["k"],
+                      ("decoder", "w", 0, 1, "k"): mod.decoder.w[0][1]["k"]},)
+
   def test_raises_error_with_path(self):
     if six.PY2:
       class NonOrderable(object):