Write checkpoint info to XDB at the end of an epoch
Summary: This diff ensures that the checkpoint metadata is written out to the db at the end of every epoch. This will allow us to automatically resume a workflow from the last checkpointed epoch if it fails.
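
As a rough illustration of the interface, here is a minimal sketch of a `metadata_writer`. The JSON-file writer below is hypothetical (the class name, file format, and paths are examples, and it is not the internal XDB writer); the only requirement is a `write()` method that accepts the keyword arguments passed by `write_checkpoint_metadata`:

import json


class FileCheckpointMetadataWriter(object):
    """Hypothetical writer that persists checkpoint metadata to a JSON file."""

    def __init__(self, metadata_path):
        self._metadata_path = metadata_path

    def write(self, epoch, db_type, db_prefix, path_type, path_prefix,
              node_names):
        # Keyword names match the call sites added in this diff.
        with open(self._metadata_path, 'w') as f:
            json.dump({
                'epoch': epoch,
                'db_type': db_type,
                'db_prefix': db_prefix,
                'path_type': path_type,
                'path_prefix': path_prefix,
                'node_names': node_names,
            }, f)

Hypothetical wiring: pass the writer when constructing the manager, e.g. `CheckpointManager(db_prefix, node_name, db_type='minidb', metadata_writer=FileCheckpointMetadataWriter('/tmp/checkpoint_metadata.json'))`; JobRunner's calls to `write_checkpoint_metadata(epoch)` will then persist the info after each save.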
Reviewed By: aartibasant
Differential Revision: D6234832
fbshipit-source-id: f09a4de118f2eac25f663556476ac6313925fdf3
diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py
index 961dc15..980b203 100644
--- a/caffe2/python/checkpoint.py
+++ b/caffe2/python/checkpoint.py
@@ -134,16 +134,37 @@
Controls saving and loading of workspaces on every epoch boundary of a job.
If a CheckpointManager instance is passed to JobRunner, then JobRunner will
call `init`, `read` and `save` at different moments in between epoch runs.
+
+ Args:
+        db_prefix: The prefix used to construct the full db name. Since `absolute_path`
+ is set to True, this will be used as db_name in SaveOp.
+ node_name: Name of the node where this checkpoint_manager is used.
+        db_type: Type of database to use for storing the checkpoint.
+ metadata_writer: An optional writer capable of writing checkpoint info
+            to the storage of choice.
"""
- def __init__(self, db_prefix, node_name, db_type):
+ def __init__(self, db_prefix, node_name, db_type, metadata_writer=None):
self._db_prefix = db_prefix
self._node_name = node_name
self._db_type = db_type
+ self._metadata_writer = metadata_writer
# make sure these blobs are the first in the checkpoint file.
self._net = core.Net('!!checkpoint_mngr')
self._blob_names = self._net.AddExternalInput('blob_names')
        self._names_output = None

+    """
+    Initializes the checkpoint manager. Determines all blobs that need to be
+    saved, or loads them from a checkpoint.
+
+    Args:
+        nodes: An array of nodes where this checkpoint manager is running. Should
+            only contain a single node.
+        retrieve_from_epoch: Set to a number to load blobs from this epoch.
+        path_prefix: Used to construct the db name or path where checkpoint files
+            are stored.
+        path_type: Indicates the type of path where checkpoint files are stored.
+    """
def init(
self,
nodes=None,
@@ -159,6 +180,10 @@
"""
assert nodes is None or len(nodes) == 1, (
'CheckpointManager only supports single node.')
+
+ self._path_prefix = path_prefix
+ self._path_type = path_type
+
with Task(outputs=[self._blob_names]) as task:
if retrieve_from_epoch is None:
ops.GetAllBlobNames(
@@ -260,17 +285,36 @@
db_type=self._db_type, absolute_path=True)
return task
+ def write_checkpoint_metadata(self, epoch):
+ if self._metadata_writer is not None:
+ self._metadata_writer.write(
+ epoch=epoch,
+ db_type=self._db_type,
+ db_prefix=self._db_prefix,
+ path_type=self._path_type,
+ path_prefix=self._path_prefix,
+ node_names=[self._node_name],
+ )
+
class MultiNodeCheckpointManager(object):
"""
    Coordinates checkpointing across multiple nodes.
Each of `init`, `load` and `save` will build TaskGroups which will
trigger checkpointing on each of the nodes involved in a distributed job.
+
+ Args:
+        db_prefix: The prefix used to construct the full db name. Since `absolute_path`
+ is set to True, this will be used as db_name in SaveOp.
+        db_type: Type of database to use for storing the checkpoint.
+ metadata_writer: An optional writer capable of writing checkpoint info
+            to the storage of choice.
"""
- def __init__(self, db_prefix, db_type):
+ def __init__(self, db_prefix, db_type, metadata_writer=None):
self._node_managers = None
self._db_prefix = db_prefix
self._db_type = db_type
+ self._metadata_writer = metadata_writer
def _task_group(self, func, *args, **kw):
assert self._node_managers is not None, 'init must be called first.'
@@ -280,6 +324,14 @@
func(manager, *args, **kw)
return task_group
+ """
+ Args:
+ nodes: An array of nodes where this checkpoint manager is running.
+ retrieve_from_epoch: Set to a number to load blobs from this epoch.
+        path_prefix: Used to construct the db name or path where checkpoint files are
+ stored.
+        path_type: Indicates the type of path where checkpoint files are stored.
+ """
def init(
self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
):
@@ -287,6 +339,9 @@
assert [node for node, _ in self._node_managers] == nodes
return
self._node_managers = []
+ self._path_prefix = path_prefix
+ self._path_type = path_type
+ self._node_names = [str(node) for node in nodes]
for node in nodes:
with Node(node):
manager = CheckpointManager(
@@ -362,6 +417,17 @@
def save(self, epoch):
return self._task_group(CheckpointManager.save, epoch)
+ def write_checkpoint_metadata(self, epoch):
+ if self._metadata_writer is not None:
+ self._metadata_writer.write(
+ epoch=epoch,
+ db_type=self._db_type,
+ db_prefix=self._db_prefix,
+ path_type=self._path_type,
+ path_prefix=self._path_prefix,
+ node_names=self._node_names,
+ )
+
class UploadTaskGroupBuilder(object):
"""A simple class to upload checkpoints."""
@@ -438,6 +504,7 @@
if from_scratch:
logger.info('Saving first checkpoints ...')
session.run(self.checkpoint_manager.save(0))
+ self.checkpoint_manager.write_checkpoint_metadata(0)
logger.info('First checkpoints saved')
else:
logger.info('Loading checkpoints for epoch {} ...'.format(
@@ -458,6 +525,7 @@
if self.checkpoint_manager:
logger.info('Saving checkpoints for epoch {}'.format(epoch))
session.run(self.checkpoint_manager.save(epoch))
+ self.checkpoint_manager.write_checkpoint_metadata(epoch)
logger.info('Checkpoints saved')
if any(stop_signals):