# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Saves and restore variables inside traced @tf.functions."""
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.saved_model import registration
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_hook
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import nest
class _SingleDeviceSaver(object):
"""Saves and restores checkpoints from the current device."""
__slots__ = ["_saveable_objects"]
def __init__(self, saveable_objects):
"""Specify a list of `SaveableObject`s to save and restore.
Args:
saveable_objects: A list of `SaveableObject`s.
"""
saveable_objects = list(saveable_objects)
for saveable in saveable_objects:
if not isinstance(saveable, saveable_object.SaveableObject):
raise ValueError(f"Expected a list of SaveableObjects, got {saveable}.")
self._saveable_objects = saveable_objects
def save(self, file_prefix, options=None):
"""Save the saveable objects to a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix to
save under.
options: Optional `CheckpointOptions` object.
Returns:
An `Operation`, or None when executing eagerly.
"""
options = options or checkpoint_options.CheckpointOptions()
tensor_names = []
tensors = []
tensor_slices = []
for saveable in self._saveable_objects:
for spec in saveable.specs:
tensor = spec.tensor
# A tensor value of `None` indicates that this SaveableObject gets
# recorded in the object graph, but that no value is saved in the
# checkpoint.
if tensor is not None:
tensor_names.append(spec.name)
tensors.append(tensor)
tensor_slices.append(spec.slice_spec)
save_device = options.experimental_io_device or "cpu:0"
with ops.device(save_device):
return io_ops.save_v2(file_prefix, tensor_names, tensor_slices, tensors)
def restore(self, file_prefix, options=None):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix for
files to read from.
options: Optional `CheckpointOptions` object.
Returns:
A dictionary mapping from SaveableObject names to restore operations.
"""
options = options or checkpoint_options.CheckpointOptions()
restore_specs = []
tensor_structure = []
for saveable in self._saveable_objects:
saveable_tensor_structure = []
tensor_structure.append(saveable_tensor_structure)
for spec in saveable.specs:
saveable_tensor_structure.append(spec.name)
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
restore_device = options.experimental_io_device or "cpu:0"
with ops.device(restore_device):
restored_tensors = io_ops.restore_v2(
file_prefix, tensor_names, tensor_slices, tensor_dtypes)
structured_restored_tensors = nest.pack_sequence_as(
tensor_structure, restored_tensors)
restore_ops = {}
for saveable, restored_tensors in zip(self._saveable_objects,
structured_restored_tensors):
restore_ops[saveable.name] = saveable.restore(
restored_tensors, restored_shapes=None)
return restore_ops
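
# Illustrative usage of `_SingleDeviceSaver` (a minimal sketch; it assumes an
# existing `tf.Variable` `v`, wrapped via
# `saveable_object_util.saveable_objects_for_op`, which is imported above):
#
#   saveables = list(
#       saveable_object_util.saveable_objects_for_op(v, name="v"))
#   saver = _SingleDeviceSaver(saveables)
#   save_op = saver.save(constant_op.constant("/tmp/ckpt/prefix"))
#   restore_ops = saver.restore(constant_op.constant("/tmp/ckpt/prefix"))
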
def sharded_filename(filename_tensor, shard, num_shards):
"""Append sharding information to a filename.
Args:
filename_tensor: A string tensor.
shard: Integer. The shard for the filename.
num_shards: An int Tensor for the number of shards.
Returns:
A string tensor.
"""
return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
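
# For example, `sharded_filename("/tmp/ckpt_temp/part", 0, num_shards)` with
# `num_shards == 2` yields "/tmp/ckpt_temp/part-00000-of-00002"; the shard
# index and shard count are zero-padded to five digits by the underlying op.
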
def registered_saver_filename(filename_tensor, saver_name):
return string_ops.string_join(
[filename_tensor, constant_op.constant(f"-{saver_name}")])
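
# For example, a registered saver named "Custom" saving against the prefix
# "/tmp/ckpt" is handed the filename "/tmp/ckpt-Custom".
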
def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures):
"""Converts the function to a python or tf.function with a single file arg."""
if call_with_mapped_captures is None:
def mapped_fn(file_prefix):
return fn(trackables=trackables, file_prefix=file_prefix)
return mapped_fn
else:
tf_fn = def_function.function(fn, autograph=False)
concrete = tf_fn.get_concrete_function(
trackables=trackables,
file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))
def mapped_fn(file_prefix):
return call_with_mapped_captures(concrete, [file_prefix])
return mapped_fn
def _get_mapped_registered_restore_fn(fn, trackables,
call_with_mapped_captures):
"""Converts the function to a python or tf.function with a single file arg."""
if call_with_mapped_captures is None:
def mapped_fn(merged_prefix):
return fn(trackables=trackables, merged_prefix=merged_prefix)
return mapped_fn
else:
tf_fn = def_function.function(fn, autograph=False)
concrete = tf_fn.get_concrete_function(
trackables=trackables,
merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))
def mapped_fn(merged_prefix):
return call_with_mapped_captures(concrete, [merged_prefix])
return mapped_fn
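
# Both helpers above return a `mapped_fn` taking a single file-prefix tensor.
# A minimal sketch of the two call paths (`my_save_fn` and `obj` are
# hypothetical, not part of this module):
#
#   # Without mapped captures, the registered function is called directly:
#   mapped = _get_mapped_registered_save_fn(my_save_fn, {"obj": obj}, None)
#   mapped(constant_op.constant("/tmp/ckpt"))
#
#   # With `call_with_mapped_captures`, the function is first traced as a
#   # `tf.function`, and the callable invokes the resulting concrete function
#   # with its captured inputs remapped.
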
class MultiDeviceSaver(object):
"""Saves checkpoints directly from multiple devices.
Note that this is a low-level utility which stores Tensors in the keys
specified by `SaveableObject`s. Higher-level utilities for object-based
checkpointing are built on top of it.
"""
def __init__(self,
saveable_objects,
registered_savers=None,
call_with_mapped_captures=None):
"""Specify a list of `SaveableObject`s to save and restore.
Args:
saveable_objects: A list of `SaveableObject`s.
Objects extending `SaveableObject` will be saved and restored, and
objects extending `SaveableHook` will be called into at save and
restore time.
registered_savers: A dictionary mapping `registration.RegisteredSaver`
namedtuples to a dictionary of named Trackables. The keys of the
Trackable dictionary are string names that uniquely identify the
Trackable in the checkpoint.
      call_with_mapped_captures: Optional callable used to invoke traced
        functions whose captured inputs must be remapped (for example, when
        saving from within a wrapped graph). When set, each registered save
        and restore function is traced as a `tf.function` and called through
        this callable; otherwise the functions are called directly.
"""
self._before_save_callbacks = []
self._after_restore_callbacks = []
saveable_objects = list(saveable_objects)
saveables_by_device = {}
for saveable in saveable_objects:
is_saveable = isinstance(saveable, saveable_object.SaveableObject)
is_hook = isinstance(saveable, saveable_hook.SaveableHook)
if not is_saveable and not is_hook:
raise ValueError(
f"Expected a dictionary of SaveableObjects, got {saveable}.")
if is_hook:
self._before_save_callbacks.append(saveable.before_save)
self._after_restore_callbacks.append(saveable.after_restore)
if is_saveable:
host_device = saveable_object_util.set_cpu0(saveable.device)
saveables_by_device.setdefault(host_device, []).append(saveable)
self._single_device_savers = {
device: _SingleDeviceSaver(saveables)
for device, saveables in saveables_by_device.items()}
self._registered_savers = {}
if registered_savers:
for registered_name, trackables in registered_savers.items():
save_fn = _get_mapped_registered_save_fn(
registration.get_save_function(registered_name),
trackables, call_with_mapped_captures)
restore_fn = _get_mapped_registered_restore_fn(
registration.get_restore_function(registered_name),
trackables, call_with_mapped_captures)
self._registered_savers[registered_name] = (save_fn, restore_fn)
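
  # Illustrative `registered_savers` usage (a minimal sketch; "Custom" stands
  # for a saver name previously registered through the
  # `tensorflow.python.saved_model.registration` API, and `obj` for one of
  # the Trackables it applies to -- both assumptions, not part of this
  # module):
  #
  #   saver = MultiDeviceSaver(
  #       saveables, registered_savers={"Custom": {"obj": obj}})
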
def to_proto(self):
"""Serializes to a SaverDef referencing the current graph."""
filename_tensor = array_ops.placeholder(
shape=[], dtype=dtypes.string, name="saver_filename")
save_tensor = self._traced_save(filename_tensor)
restore_op = self._traced_restore(filename_tensor).op
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
save_tensor_name=save_tensor.name,
restore_op_name=restore_op.name,
version=saver_pb2.SaverDef.V2)
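
  # A `SaverDef` produced by `to_proto` can be driven from a graph-mode
  # session by feeding the filename placeholder (a minimal sketch, assuming a
  # `session` for the graph and an already-written checkpoint):
  #
  #   saver_def = saver.to_proto()
  #   session.run(saver_def.restore_op_name,
  #               {saver_def.filename_tensor_name: "/tmp/ckpt"})
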
@def_function.function(
input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
autograph=False)
def _traced_save(self, file_prefix):
save_op = self.save(file_prefix)
with ops.device("cpu:0"):
with ops.control_dependencies([save_op]):
return array_ops.identity(file_prefix)
@def_function.function(
input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
autograph=False)
def _traced_restore(self, file_prefix):
restore_ops = self.restore(file_prefix)
with ops.device("cpu:0"):
with ops.control_dependencies(restore_ops.values()):
return array_ops.identity(file_prefix)
def save(self, file_prefix, options=None):
"""Save the saveable objects to a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix to
save under.
options: Optional `CheckpointOptions` object.
Returns:
An `Operation`, or None when executing eagerly.
"""
options = options or checkpoint_options.CheckpointOptions()
for callback in self._before_save_callbacks:
callback()
# IMPLEMENTATION DETAILS: most clients should skip.
#
# Suffix for any well-formed "checkpoint_prefix", when sharded.
# Transformations:
# * Users pass in "save_path" in save() and restore(). Say "myckpt".
# * checkpoint_prefix gets fed <save_path><sharded_suffix>.
#
# Example:
# During runtime, a temporary directory is first created, which contains
# files
#
# <train dir>/myckpt_temp/
# part-?????-of-?????{.index, .data-00000-of-00001}
#
# Before .save() finishes, they will be (hopefully, atomically) renamed to
#
# <train dir>/
# myckpt{.index, .data-?????-of-?????}
#
    # Filesystems with eventual consistency (such as S3) don't need a
    # temporary location. Using a temporary directory in those cases might
    # leave files unavailable while they are being copied.
#
# Users only need to interact with the user-specified prefix, which is
# "<train dir>/myckpt" in this case. Save() and Restore() work with the
# prefix directly, instead of any physical pathname. (On failure and
# subsequent restore, an outdated and orphaned temporary directory can be
# safely removed.)
with ops.device("CPU"):
sharded_suffix = array_ops.where(
string_ops.regex_full_match(file_prefix, "^s3://.*"),
constant_op.constant(".part"),
constant_op.constant("_temp/part"))
tmp_checkpoint_prefix = string_ops.string_join(
[file_prefix, sharded_suffix])
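      # For example, a `file_prefix` of "/tmp/train/myckpt" produces the
      # temporary prefix "/tmp/train/myckpt_temp/part", while an S3 path such
      # as "s3://bucket/myckpt" produces "s3://bucket/myckpt.part", keeping
      # the temporary files alongside the final ones as described above.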
registered_paths = {
saver_name: registered_saver_filename(file_prefix, saver_name)
for saver_name in self._registered_savers
}
def save_fn():
saved_prefixes = []
# Save with the registered savers.
for saver_name, (save_fn, _) in self._registered_savers.items():
maybe_saved_prefixes = save_fn(registered_paths[saver_name])
if maybe_saved_prefixes is not None:
flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
if not all(
tensor_util.is_tf_type(x) and x.dtype == dtypes.string
for x in flattened_saved_prefixes):
raise ValueError(
"Registered saver can only return `None` or "
f"string type tensors. Got {maybe_saved_prefixes}.")
saved_prefixes.extend(flattened_saved_prefixes)
# (Default saver) Save with single device savers.
num_shards = len(self._single_device_savers)
sharded_saves = []
num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
last_device = None
for shard, (device, saver) in enumerate(
sorted(self._single_device_savers.items())):
last_device = device
with ops.device(saveable_object_util.set_cpu0(device)):
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
num_shards_tensor)
saved_prefixes.append(shard_prefix)
with ops.device(device):
# _SingleDeviceSaver will use the CPU device when necessary, but
# initial read operations should be placed on the SaveableObject's
# device.
sharded_saves.append(saver.save(shard_prefix, options))
with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified; otherwise co-locate the merge
        # op with the last device used.
merge_device = (
options.experimental_io_device or
saveable_object_util.set_cpu0(last_device))
with ops.device(merge_device):
          # The V2 format write path includes a metadata merge step. Once the
          # merge completes, it attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
return gen_io_ops.merge_v2_checkpoints(
saved_prefixes, file_prefix, delete_old_dirs=True)
    # Since this will cause a function retrace on each save, limit it to the
    # cases where it is needed: eager execution with multiple single-device
    # savers. The retrace is needed to ensure we pick up the latest values
    # of options like experimental_io_device.
if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Wrap the save in a tf.function so the sharded saves and the final
      # merge execute as a single traced call.
@def_function.function(jit_compile=False)
def tf_function_save():
save_fn()
tf_function_save()
else:
return save_fn()
def restore(self, file_prefix, options=None):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix for
files to read from.
options: Optional `CheckpointOptions` object.
Returns:
When not run eagerly or when saving on a single device, returns a
dictionary mapping from SaveableObject names to restore operations;
otherwise, returns an empty dict.
"""
options = options or checkpoint_options.CheckpointOptions()
def restore_fn():
restore_ops = {}
# Sort by device name to avoid propagating non-deterministic dictionary
# ordering in some Python versions.
for device, saver in sorted(self._single_device_savers.items()):
with ops.device(device):
restore_ops.update(saver.restore(file_prefix, options))
for _, (_, restore_fn) in self._registered_savers.items():
restore_fn(file_prefix)
return restore_ops
    # Since this will cause a function retrace on each restore, limit it to
    # the cases where it is needed: eager execution with multiple
    # single-device savers. The retrace is needed to ensure we pick up the
    # latest values of options like experimental_io_device.
if context.executing_eagerly() and len(self._single_device_savers) > 1:
@def_function.function(jit_compile=False)
def tf_function_restore():
restore_fn()
return {}
restore_ops = tf_function_restore()
else:
restore_ops = restore_fn()
for callback in self._after_restore_callbacks:
callback()
return restore_ops
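
# Illustrative end-to-end usage of `MultiDeviceSaver` (a minimal sketch;
# `saveables` built as in the `_SingleDeviceSaver` example above):
#
#   saver = MultiDeviceSaver(saveables)
#   saver.save("/tmp/train/myckpt")     # writes myckpt.index, myckpt.data-*
#   saver.restore("/tmp/train/myckpt")  # reads the merged checkpoint back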