Correctly set distributed variable's update_uid if initialized from a checkpoint using `CheckpointInitialValueCallable`.

PiperOrigin-RevId: 369679126
Change-Id: Ib5c449dfdf02a95f804ad6132359017d3b38d49e
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 7da2232..69f6027 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -192,6 +192,7 @@
 import collections
 import copy
 import enum  # pylint: disable=g-bad-import-order
+import functools
 import threading
 import weakref
 
@@ -2105,6 +2106,16 @@
                       trackable.CheckpointInitialValueCallable):
         checkpoint_restore_uid = kwargs[
             "initial_value"].checkpoint_position.restore_uid
+      elif (isinstance(kwargs["initial_value"], functools.partial) and
+            isinstance(kwargs["initial_value"].func,
+                       trackable.CheckpointInitialValueCallable)):
+        # Some libraries (e.g, Keras) create partial function out of initializer
+        # to bind shape/dtype, for example:
+        #  initial_val = functools.partial(initializer, shape, dtype=dtype)
+        # Therefore to get the restore_uid we need to examine the "func" of
+        # the partial function.
+        checkpoint_restore_uid = kwargs[
+            "initial_value"].func.checkpoint_position.restore_uid
       else:
         checkpoint_restore_uid = None