Remove use of deprecated `_checkpoint_dependencies`.
PiperOrigin-RevId: 439431826
diff --git a/tensorflow/python/training/tracking/data_structures_test.py b/tensorflow/python/training/tracking/data_structures_test.py
index 76ab714..76b637b 100644
--- a/tensorflow/python/training/tracking/data_structures_test.py
+++ b/tensorflow/python/training/tracking/data_structures_test.py
@@ -182,7 +182,7 @@
m.l = [1, 2]
m.l.insert(0, 0)
self.assertEqual(m.l, [0, 1, 2])
- self.assertEqual(m.l._checkpoint_dependencies, [])
+ self.assertEqual(m.l._trackable_children(), {})
def testFunctionCaching(self):
@def_function.function
@@ -272,7 +272,7 @@
l[:] = 2, 8, 9, 0
self.assertEqual(l, [2, 8, 9, 0])
l._maybe_initialize_trackable() # pylint: disable=protected-access
- self.assertEqual(len(l._checkpoint_dependencies), 0) # pylint: disable=protected-access
+ self.assertEqual(len(l._trackable_children()), 0) # pylint: disable=protected-access
def testSetSlice_cannotSaveIfTrackableModified(self):
v1 = resource_variable_ops.ResourceVariable(1.)
@@ -301,14 +301,14 @@
def testIMulPositive(self):
v = variables.Variable(1.)
l = data_structures.ListWrapper([1, 2, 3, 4, v])
- self.assertEqual([("4", v)], l._checkpoint_dependencies)
+ self.assertDictEqual({"4": v}, l._trackable_children())
root = util.Checkpoint(l=l)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
path = root.save(prefix)
v.assign(5.)
l *= 2
self.assertEqual(l, [1, 2, 3, 4, v, 1, 2, 3, 4, v])
- self.assertEqual([("4", v), ("9", v)], l._checkpoint_dependencies)
+ self.assertDictEqual({"4": v, "9": v}, l._trackable_children())
root.restore(path)
self.assertAllClose(1., v.numpy())
@@ -607,7 +607,7 @@
self.assertIs(v, m.nt.x)
self.assertIs(v, m.nt[0])
self.assertIs(
- v, m._checkpoint_dependencies[0].ref._checkpoint_dependencies[0].ref)
+ v, m._trackable_children()["nt"]._trackable_children()["x"])
self.assertEqual(2, m.nt.y)
def testNamedTupleConflictingAttributes(self):
@@ -637,10 +637,10 @@
m.nt = nt
self.assertEqual(3, m.nt.y)
self.assertIs(v, m.nt.x)
- self.assertIs(
- v, m._checkpoint_dependencies[0].ref._checkpoint_dependencies[0].ref)
- self.assertEqual("x", m.nt._checkpoint_dependencies[0].name)
- self.assertEqual("0", m.nt._checkpoint_dependencies[1].name)
+ self.assertIn(v,
+ m._trackable_children()["nt"]._trackable_children().values())
+ self.assertIn("x", m.nt._trackable_children())
+ self.assertIn("0", m.nt._trackable_children())
self.assertEqual(5, self.evaluate(m.nt.summed))
def testUnnamedSubclassing(self):
@@ -655,8 +655,8 @@
unt = UnnamedSubclass([v, 2])
m = module.Module()
m.unt = unt
- self.assertEqual("0", m.unt._checkpoint_dependencies[0].name)
- self.assertLen(m.unt._checkpoint_dependencies, 1)
+ self.assertIn("0", m.unt._trackable_children())
+ self.assertLen(m.unt._trackable_children(), 1)
self.assertEqual(4, self.evaluate(m.unt.summed))
nest.assert_same_structure(
[m.unt], nest.map_structure(lambda x: x, [m.unt]))
@@ -744,9 +744,9 @@
def testLoopAssignedModule(self):
m = module.Module()
m.s = (m,)
- self.assertLen(m._checkpoint_dependencies, 1)
- self.assertIs(m.s, m._checkpoint_dependencies[0].ref)
- self.assertIs("s", m._checkpoint_dependencies[0].name)
+ self.assertLen(m._trackable_children(), 1)
+ self.assertIn("s", m._trackable_children())
+ self.assertIs(m.s, m._trackable_children()["s"])
self.assertEqual((), m.trainable_variables)