Make "map_resources" overridable by subclass of `Trackable`.
This allows moving the implementation of `map_resources` from `tf.saved_model.save` to subclasses of `Trackable`, e.g., `Variable` and `DistributedVariable`.
This is a non-functional change.
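
For illustration only, a minimal sketch of the pattern this change enables, modeled on the `tracking.CapturableResource` hunk below; the `MyTableResource` class and its `_create_resource` body are hypothetical and not part of this change:

import copy

from tensorflow.python.framework import ops
from tensorflow.python.training.tracking import base


class MyTableResource(base.Trackable):
  """Hypothetical Trackable that owns a single resource tensor."""

  def __init__(self):
    self._resource_device = "CPU"
    self._resource_handle = self._create_resource()

  def _create_resource(self):
    # Placeholder: a real implementation would return a resource handle
    # tensor, e.g. from gen_lookup_ops.hash_table_v2(...).
    raise NotImplementedError

  def _map_resources(self):
    """For implementing `Trackable`."""
    # Re-create the handle op in the current default graph (the SavedModel
    # graph during export) and tell the saver how to substitute the eager
    # handle with the new one.
    new_obj = copy.copy(self)
    with ops.device(self._resource_device):
      new_obj._resource_handle = new_obj._create_resource()  # pylint: disable=protected-access
    obj_map = {self: new_obj}
    resource_map = {self._resource_handle: new_obj._resource_handle}
    return obj_map, resource_map

`tf.saved_model.save` then merges each node's (object_map, resource_map) into the graph-wide maps it builds in `map_resources`, as shown in the save.py hunk below.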
PiperOrigin-RevId: 317198449
Change-Id: I4aa48d4974b6547b5de8ac0f5c38f3da29d364bc
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 96559a9..7208807 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -744,14 +744,12 @@
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_util",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tf_export",
"//tensorflow/python:type_spec",
- "//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
- "//tensorflow/python/eager:tape",
"//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
"//tensorflow/python/training/tracking:base",
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index d0ed27c..60b2ea4 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -32,6 +32,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training.saving import saveable_object
@@ -793,6 +794,17 @@
return ops.convert_to_tensor(
self._get(), dtype=dtype, name=name, as_ref=as_ref)
+ def _map_resources(self):
+ """For implementing `Trackable`."""
+ new_obj = resource_variable_ops.copy_to_graph_uninitialized(self._primary)
+ obj_map, resource_map = {}, {}
+ for v in self._values:
+ obj_map[v] = new_obj
+ resource_map[v.handle] = new_obj.handle
+ obj_map[self] = new_obj
+ resource_map[self] = new_obj.handle
+ return obj_map, resource_map
+
class _DistributedVariableSaveable(saveable_object.SaveableObject):
"""Class for defining how to restore a DistributedVariable."""
diff --git a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
index 7d0abe3..57e8ced 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/autocast_variable.py
@@ -285,6 +285,13 @@
# models with normal variables, and vice versa.
return self._variable._gather_saveables_for_checkpoint() # pylint:disable=protected-access
+ def _map_resources(self):
+ # By delegating this method to the wrapped variable, SavedModels with
+ # AutoCastVariables are identical to SavedModels with normal variables.
+ obj_map, resource_map = self._variable._map_resources() # pylint:disable=protected-access
+ obj_map[self] = obj_map[self._variable]
+ return obj_map, resource_map
+
# TODO(reedwm): Maybe encode the fact the variable is an AutoCastVariable in
# to_proto().
def to_proto(self, export_scope=None):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 25f6347..cb235fc 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -633,6 +633,13 @@
return gen_state_ops.resource_count_up_to(self.handle, limit=limit,
T=self.dtype)
+ def _map_resources(self):
+ """For implementing `Trackable`."""
+ new_variable = copy_to_graph_uninitialized(self)
+ obj_map = {self: new_variable}
+ resource_map = {self._handle: new_variable.handle}
+ return obj_map, resource_map
+
def _read_variable_op(self):
variable_accessed(self)
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 5844c80..802ce1d 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -19,14 +19,12 @@
from __future__ import print_function
import collections
-import copy
import os
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
-from tensorflow.python.distribute import distribute_utils as ds_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
@@ -241,7 +239,7 @@
Creates resource handle ops in the current default graph, whereas
`accessible_objects` will be from an eager context. Resource mapping adds
resource handle ops to the main GraphDef of a SavedModel, which allows the
- C++ loader API to interact with variables.
+ C++ loader API to interact with resources.
Returns:
A tuple of (object_map, resource_map, asset_info):
@@ -265,33 +263,15 @@
asset_index={})
for node_id, obj in enumerate(self.nodes):
- if isinstance(obj, tracking.CapturableResource):
- new_obj = object_map[obj] = copy.copy(obj)
- # pylint: disable=protected-access
- with ops.device(obj._resource_device):
- new_resource = new_obj._create_resource()
- new_obj._resource_handle = new_resource
- # pylint: enable=protected-access
- resource_map[obj.resource_handle] = new_resource
- self.captured_tensor_node_ids[obj.resource_handle] = node_id
- elif (ds_utils.is_distributed_variable(obj) or
- resource_variable_ops.is_resource_variable(obj)):
- obj_to_copy = obj._primary if ds_utils.is_distributed_variable( # pylint: disable=protected-access
- obj) else obj
- new_variable = resource_variable_ops.copy_to_graph_uninitialized(
- obj_to_copy)
- if ds_utils.is_distributed_variable(obj):
- self.captured_tensor_node_ids[obj] = node_id
- for v in obj.values:
- object_map[v] = new_variable
- resource_map[v.handle] = new_variable.handle
- self.captured_tensor_node_ids[v.handle] = node_id
- object_map[obj] = new_variable
- resource_map[obj.handle] = new_variable.handle
- self.captured_tensor_node_ids[obj.handle] = node_id
- elif isinstance(obj, tracking.Asset):
+ if isinstance(obj, tracking.Asset):
_process_asset(obj, asset_info, resource_map)
self.captured_tensor_node_ids[obj.asset_path] = node_id
+ elif isinstance(obj, base.Trackable):
+ node_object_map, node_resource_map = obj._map_resources() # pylint: disable=protected-access
+ for capturable in node_resource_map.keys():
+ self.captured_tensor_node_ids[capturable] = node_id
+ object_map.update(node_object_map)
+ resource_map.update(node_resource_map)
# Note: some concrete functions can have been realized when tracing other
# functions, and might closure-capture tensors from their parent functions.
diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index e3cd982..ea76ad8 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -1021,3 +1021,21 @@
"""
del serialization_cache
return dict()
+
+ def _map_resources(self):
+ """Makes new resource handle ops corresponding to existing resource tensors.
+
+ Internal subclasses can override this to inform model saving how to add new
+ resource handle ops to the main GraphDef of a SavedModel (TF 1.x style
+ graph), which allows session-based APIs (e.g., the C++ loader API) to
+ interact with resources owned by this object.
+
+ Returns:
+ A tuple of (object_map, resource_map):
+ object_map: A dictionary mapping from objects that hold existing
+ resource tensors to replacement objects created to hold the new
+ resource tensors.
+ resource_map: A dictionary mapping from existing resource tensors to
+ newly created resource tensors.
+ """
+ return {}, {}
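
As a sanity check of the contract documented above, the default implementation keeps the saver's new delegation loop safe for every `Trackable`: objects that own no resources simply contribute nothing. A small hedged example, assuming this patch is applied (the `PlainTrackable` name is illustrative only):

from tensorflow.python.training.tracking import base

class PlainTrackable(base.Trackable):
  pass

obj_map, resource_map = PlainTrackable()._map_resources()  # pylint: disable=protected-access
assert obj_map == {} and resource_map == {}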
diff --git a/tensorflow/python/training/tracking/tracking.py b/tensorflow/python/training/tracking/tracking.py
index 553f0ec..fb2735e 100644
--- a/tensorflow/python/training/tracking/tracking.py
+++ b/tensorflow/python/training/tracking/tracking.py
@@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function
+import copy
import functools
import weakref
@@ -243,6 +244,18 @@
self._resource_handle = self._create_resource()
return self._resource_handle
+ def _map_resources(self):
+ """For implementing `Trackable`."""
+ new_obj = copy.copy(self)
+ # pylint: disable=protected-access
+ with ops.device(self._resource_device):
+ new_resource = new_obj._create_resource()
+ new_obj._resource_handle = new_resource
+ # pylint: enable=protected-access
+ obj_map = {self: new_obj}
+ resource_map = {self.resource_handle: new_resource}
+ return obj_map, resource_map
+
def _list_functions_for_serialization(self, unused_functions):
@def_function.function(input_signature=[], autograph=False)
def _creator():