Migrate usages of deprecated Trackable methods to `_trackable_children`.
PiperOrigin-RevId: 421881039
Change-Id: I7ccce930abad62bda6962a10c2ac1e04c3019d09
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
index fa6a51b..31e4548 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2_test.py
@@ -45,10 +45,10 @@
bias_initializer=init_ops.constant_initializer(0.5))
g, m_new = base_cell(x, m)
wrapper_object = wrapper_type(base_cell)
- (name, dep), = wrapper_object._checkpoint_dependencies
+ children = wrapper_object._trackable_children()
wrapper_object.get_config() # Should not throw an error
- self.assertIs(dep, base_cell)
- self.assertEqual("cell", name)
+ self.assertIn("cell", children)
+ self.assertIs(children["cell"], base_cell)
g_res, m_new_res = wrapper_object(x, m)
self.evaluate([variables_lib.global_variables_initializer()])
@@ -89,10 +89,10 @@
m = array_ops.zeros([1, 3])
cell = rnn_cell_impl.GRUCell(3)
wrapped_cell = wrapper_type(cell, "/cpu:0")
- (name, dep), = wrapped_cell._checkpoint_dependencies
+ children = wrapped_cell._trackable_children()
wrapped_cell.get_config() # Should not throw an error
- self.assertIs(dep, cell)
- self.assertEqual("cell", name)
+ self.assertIn("cell", children)
+ self.assertIs(children["cell"], cell)
outputs, _ = wrapped_cell(x, m)
self.assertIn("cpu:0", outputs.device.lower())