Write checkpoint info to XDB at the end of an epoch

Summary: This diff ensures that the checkpoint metadata is written out to the db at the end of every epoch. This will allow us to automatically resume from an epoch if a workflow fails.
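
The diff only assumes that `metadata_writer` exposes a `write()` method accepting the keyword arguments passed in `write_checkpoint_metadata`; the concrete XDB writer is not part of this diff. A minimal sketch of a compatible writer (hypothetical names, logging instead of XDB) could look like:

    import logging

    logger = logging.getLogger(__name__)

    class LoggingCheckpointMetadataWriter(object):
        """Hypothetical example; a real writer would persist to XDB."""

        def write(self, epoch, db_type, db_prefix, path_type, path_prefix,
                  node_names):
            # Record everything needed to resume a failed workflow from
            # this epoch.
            logger.info(
                'checkpoint metadata: epoch=%s db_type=%s db_prefix=%s '
                'path_type=%s path_prefix=%s node_names=%s',
                epoch, db_type, db_prefix, path_type, path_prefix, node_names)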

Reviewed By: aartibasant

Differential Revision: D6234832

fbshipit-source-id: f09a4de118f2eac25f663556476ac6313925fdf3
diff --git a/caffe2/python/checkpoint.py b/caffe2/python/checkpoint.py
index 961dc15..980b203 100644
--- a/caffe2/python/checkpoint.py
+++ b/caffe2/python/checkpoint.py
@@ -134,16 +134,37 @@
     Controls saving and loading of workspaces on every epoch boundary of a job.
     If a CheckpointManager instance is passed to JobRunner, then JobRunner will
     call `init`, `read` and `save` at different moments in between epoch runs.
+
+    Args:
+        db_prefix: The prefix used to construct the full db name. Since `absolute_path`
+            is set to True, this will be used as db_name in SaveOp.
+        node_name: Name of the node where this checkpoint_manager is used.
+        db_type: Type of database to use for storing checkpoints.
+        metadata_writer: An optional writer capable of writing checkpoint info
+            to the storage of choice.
     """
-    def __init__(self, db_prefix, node_name, db_type):
+    def __init__(self, db_prefix, node_name, db_type, metadata_writer=None):
         self._db_prefix = db_prefix
         self._node_name = node_name
         self._db_type = db_type
+        self._metadata_writer = metadata_writer
         # make sure these blobs are the first in the checkpoint file.
         self._net = core.Net('!!checkpoint_mngr')
         self._blob_names = self._net.AddExternalInput('blob_names')
         self._names_output = None
 
+    """
+    Initialize the checkpoint manager. Determines all blobs that need to be
+    saved, or loads blobs from a checkpoint if `retrieve_from_epoch` is set.
+
+    Args:
+        nodes: An array of nodes where this checkpoint manager is running. Should
+            only contain a single node.
+        retrieve_from_epoch: Set to a number to load blobs from this epoch.
+        path_prefix: Used to construct the db name or path where checkpoint files
+            are stored.
+        path_type: Indicates the type of path where checkpoint files are stored.
+    """
     def init(
         self,
         nodes=None,
@@ -159,6 +180,10 @@
         """
         assert nodes is None or len(nodes) == 1, (
             'CheckpointManager only supports single node.')
+
+        self._path_prefix = path_prefix
+        self._path_type = path_type
+
         with Task(outputs=[self._blob_names]) as task:
             if retrieve_from_epoch is None:
                 ops.GetAllBlobNames(
@@ -260,17 +285,36 @@
                 db_type=self._db_type, absolute_path=True)
         return task
 
+    def write_checkpoint_metadata(self, epoch):
+        if self._metadata_writer is not None:
+            self._metadata_writer.write(
+                epoch=epoch,
+                db_type=self._db_type,
+                db_prefix=self._db_prefix,
+                path_type=self._path_type,
+                path_prefix=self._path_prefix,
+                node_names=[self._node_name],
+            )
+
 
 class MultiNodeCheckpointManager(object):
     """
     Coordinates checkpointing across multiple nodes.
     Each of `init`, `load` and `save` will build TaskGroups which will
     trigger checkpointing on each of the nodes involved in a distributed job.
+
+    Args:
+        db_prefix: The prefix used to construct the full db name. Since `absolute_path`
+            is set to True, this will be used as db_name in SaveOp.
+        db_type: Type of database to use for storing checkpoints.
+        metadata_writer: An optional writer capable of writing checkpoint info
+            to the storage of choice.
     """
-    def __init__(self, db_prefix, db_type):
+    def __init__(self, db_prefix, db_type, metadata_writer=None):
         self._node_managers = None
         self._db_prefix = db_prefix
         self._db_type = db_type
+        self._metadata_writer = metadata_writer
 
     def _task_group(self, func, *args, **kw):
         assert self._node_managers is not None, 'init must be called first.'
@@ -280,6 +324,14 @@
                     func(manager, *args, **kw)
             return task_group
 
+    """
+    Args:
+        nodes: An array of nodes where this checkpoint manager is running.
+        retrieve_from_epoch: Set to a number to load blobs from this epoch.
+        path_prefix: Used to construct the db name or path where checkpoint files
+            are stored.
+        path_type: Indicates the type of path where checkpoint files are stored.
+    """
     def init(
         self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
     ):
@@ -287,6 +339,9 @@
             assert [node for node, _ in self._node_managers] == nodes
             return
         self._node_managers = []
+        self._path_prefix = path_prefix
+        self._path_type = path_type
+        self._node_names = [str(node) for node in nodes]
         for node in nodes:
             with Node(node):
                 manager = CheckpointManager(
@@ -362,6 +417,17 @@
     def save(self, epoch):
         return self._task_group(CheckpointManager.save, epoch)
 
+    def write_checkpoint_metadata(self, epoch):
+        if self._metadata_writer is not None:
+            self._metadata_writer.write(
+                epoch=epoch,
+                db_type=self._db_type,
+                db_prefix=self._db_prefix,
+                path_type=self._path_type,
+                path_prefix=self._path_prefix,
+                node_names=self._node_names,
+            )
+
 
 class UploadTaskGroupBuilder(object):
     """A simple class to upload checkpoints."""
@@ -438,6 +504,7 @@
             if from_scratch:
                 logger.info('Saving first checkpoints ...')
                 session.run(self.checkpoint_manager.save(0))
+                self.checkpoint_manager.write_checkpoint_metadata(0)
                 logger.info('First checkpoints saved')
             else:
                 logger.info('Loading checkpoints for epoch {} ...'.format(
@@ -458,6 +525,7 @@
             if self.checkpoint_manager:
                 logger.info('Saving checkpoints for epoch {}'.format(epoch))
                 session.run(self.checkpoint_manager.save(epoch))
+                self.checkpoint_manager.write_checkpoint_metadata(epoch)
                 logger.info('Checkpoints saved')
 
             if any(stop_signals):
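
Usage sketch (hypothetical values, building on the example writer above): pass the writer into CheckpointManager or MultiNodeCheckpointManager, and JobRunner will call write_checkpoint_metadata after each save, as shown in the diff.

    from caffe2.python.checkpoint import CheckpointManager

    # All values below are illustrative, not taken from this diff.
    checkpoint_manager = CheckpointManager(
        db_prefix='/tmp/checkpoints/my_job',
        node_name='trainer_0',
        db_type='leveldb',
        metadata_writer=LoggingCheckpointMetadataWriter(),
    )

    # JobRunner calls checkpoint_manager.save(epoch) followed by
    # checkpoint_manager.write_checkpoint_metadata(epoch) at every epoch
    # boundary.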