Correctly set distributed variable's update_uid if initialized from a checkpoint using `CheckpointInitialValueCallable`.
PiperOrigin-RevId: 369679126
Change-Id: Ib5c449dfdf02a95f804ad6132359017d3b38d49e
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 7da2232..69f6027 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -192,6 +192,7 @@
import collections
import copy
import enum # pylint: disable=g-bad-import-order
+import functools
import threading
import weakref
@@ -2105,6 +2106,16 @@
trackable.CheckpointInitialValueCallable):
checkpoint_restore_uid = kwargs[
"initial_value"].checkpoint_position.restore_uid
+ elif (isinstance(kwargs["initial_value"], functools.partial) and
+ isinstance(kwargs["initial_value"].func,
+ trackable.CheckpointInitialValueCallable)):
+ # Some libraries (e.g, Keras) create partial function out of initializer
+ # to bind shape/dtype, for example:
+ # initial_val = functools.partial(initializer, shape, dtype=dtype)
+ # Therefore to get the restore_uid we need to examine the "func" of
+ # the partial function.
+ checkpoint_restore_uid = kwargs[
+ "initial_value"].func.checkpoint_position.restore_uid
else:
checkpoint_restore_uid = None