from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
import logging
from caffe2.python import core, context
from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@context.define_context()
class Job(object):
"""
    A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the
    `exit_group`, all of which are run by a JobRunner.
The `init_group` will be run only once at startup. Its role is to
initialize globally persistent blobs such as model weights, accumulators
and data file lists.
The `epoch_group` will be run in a loop after init_group. The loop will
exit when any of the stop signals added with `add_stop_signal` is True
at the end of an epoch.
    The `exit_group` will be run only once at the very end of the job, when one
    of the stopping criteria for `epoch_group` has been met. The role of this
    group is to save the results of training at the end of the job.
Jobs are context-driven, so that Tasks can be added to the active Job
without having to explicitly pass the job object around.
Example of usage:
def build_reader(partitions):
with Job.current().init_group:
reader = HiveReader(init_reader, ..., partitions)
Task(step=init_reader)
with Job.current().epoch_group:
limited_reader = ReaderWithLimit(reader, num_iter=10000)
data_queue = pipe(limited_reader, num_threads=8)
Job.current().add_stop_signal(limited_reader.data_finished())
return data_queue
def build_hogwild_trainer(reader, model):
with Job.current().init_group:
Task(step=model.param_init_net)
with Job.current().epoch_group:
pipe(reader, processor=model, num_threads=8)
with Job.current().exit_group:
Task(step=model.save_model_net)
with Job() as job:
reader = build_reader(partitions)
model = build_model(params)
build_hogwild_trainer(reader, model)
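    To run the job, pass it to a JobRunner together with a Session (a sketch;
    `session` stands for any caffe2.python.session.Session implementation,
    e.g. a LocalSession, and is not defined in this example):

        runner = JobRunner(job)
        runner(session)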
"""
def __init__(self,
init_group=None, epoch_group=None,
exit_group=None, stop_signals=None,
nodes_to_checkpoint=None):
self.init_group = init_group or TaskGroup(
workspace_type=WorkspaceType.GLOBAL)
self.epoch_group = epoch_group or TaskGroup()
self.exit_group = exit_group or TaskGroup()
self.stop_signals = stop_signals or []
self._nodes_to_checkpoint = nodes_to_checkpoint
def nodes_to_checkpoint(self):
if self._nodes_to_checkpoint:
return self._nodes_to_checkpoint
else:
return self.init_group.used_nodes()
def compile(self, session_class):
return Job(
init_group=session_class.compile(self.init_group),
epoch_group=session_class.compile(self.epoch_group),
exit_group=session_class.compile(self.exit_group),
stop_signals=self.stop_signals,
nodes_to_checkpoint=self.nodes_to_checkpoint())
def __enter__(self):
self.epoch_group.__enter__()
return self
def __exit__(self, *args):
self.epoch_group.__exit__()
def add_stop_signal(self, output):
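        # A plain BlobReference is wrapped into a Task so that its value can
        # be fetched as a TaskOutput at the end of each epoch.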
if isinstance(output, core.BlobReference):
t = Task(outputs=[output], group=self.epoch_group)
output = t.outputs()[0]
assert isinstance(output, TaskOutput)
self.stop_signals.append(output)
class CheckpointManager(object):
"""
    Controls saving and loading of workspaces at every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.
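    Example (a sketch; the db path and the `minidb` db_type are illustrative,
    and `job`/`session` are assumed to exist as in the Job docstring):

        checkpoint = CheckpointManager(db='/tmp/ckpt/model', db_type='minidb')
        JobRunner(job, checkpoint_manager=checkpoint)(session)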
"""
def __init__(self, db, db_type):
self._db = db
self._db_type = db_type
# make sure these blobs are the first in the checkpoint file.
self._net = core.Net('!!checkpoint_mngr')
self._blob_names = self._net.AddExternalInput('blob_names')
self._names_output = None
def init(self, nodes=None, retrieve_from_epoch=None):
"""
Build a Task that will be run once after the job's `init_group` is run.
This task will determine which blobs need to be checkpointed.
If retrieve_from_epoch is not None, then the checkpoint metadata is
retrieved from a previously saved checkpoint.
"""
assert nodes is None or len(nodes) == 1, (
'CheckpointManager only supports single node.')
net = core.Net('get_blob_list')
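        # Either enumerate the blobs currently present in the workspace, or
        # recover the blob list that was stored with a previous checkpoint.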
if retrieve_from_epoch is None:
net.GetAllBlobNames(
[],
self._blob_names,
include_shared=False)
else:
net.Load(
[], self._blob_names,
db=self._dbname(retrieve_from_epoch),
db_type=self._db_type,
absolute_path=True)
task = Task(step=net, outputs=[self._blob_names])
self._names_output = task.outputs()[0]
return task
def blob_list(self):
assert self._names_output
return self._names_output.fetch().tolist()
def _dbname(self, epoch):
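        # e.g. db='/tmp/ckpt/model' (illustrative), epoch=3
        #   -> '/tmp/ckpt/model.000003'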
return '%s.%06d' % (self._db, epoch)
def load(self, epoch):
"""
Build a Task that will be run by JobRunner when the job is to be
resumed from a given epoch. This task will run a Load op that will
load and deserialize all relevant blobs from a persistent storage.
"""
net = core.Net('get_blob_list')
net.Load(
[],
self.blob_list(),
db=self._dbname(epoch),
db_type=self._db_type,
absolute_path=True)
return Task(step=net)
def save(self, epoch):
"""
Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save op to serialize and persist
        blobs present in the global workspace.
"""
net = core.Net('checkpoint_save')
net.Save(
self.blob_list(), [], db=self._dbname(epoch),
db_type=self._db_type, absolute_path=True)
return Task(step=net)
class MultiNodeCheckpointManager(object):
"""
    Coordinates checkpointing across multiple nodes.
Each of `init`, `load` and `save` will build TaskGroups which will
trigger checkpointing on each of the nodes involved in a distributed job.
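    Example (a sketch; the prefix and db_type are illustrative). Each node
    `n` will write its checkpoint for epoch `e` to `<db_prefix>/<n>.<e>`:

        checkpoint = MultiNodeCheckpointManager(
            db_prefix='/tmp/ckpt', db_type='minidb')
        JobRunner(job, checkpoint_manager=checkpoint)(session)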
"""
def __init__(
self, db_prefix, db_type, node_manager_class=CheckpointManager):
self._node_manager_class = node_manager_class
self._node_managers = None
self._db_prefix = db_prefix
self._db_type = db_type
def _task_group(self, func, *args, **kw):
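        # Build a TaskGroup that applies `func` to every node's manager, with
        # each call pinned to its corresponding node.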
assert self._node_managers is not None, 'init must be called first.'
with TaskGroup(WorkspaceType.GLOBAL) as task_group:
for node, manager in self._node_managers:
with Node(node):
func(manager, *args, **kw)
return task_group
def init(self, nodes, retrieve_from_epoch=None):
if self._node_managers is not None:
assert [node for node, _ in self._node_managers] == nodes
return
self._node_managers = []
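        # Each node gets its own single-node manager, checkpointing to
        # <db_prefix>/<node>.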
for node in nodes:
with Node(node):
manager = self._node_manager_class(
db=os.path.join(self._db_prefix, node),
db_type=self._db_type)
self._node_managers.append((node, manager))
return self._task_group(
self._node_manager_class.init,
nodes=[node],
retrieve_from_epoch=retrieve_from_epoch)
def load(self, epoch):
return self._task_group(self._node_manager_class.load, epoch)
def save(self, epoch):
return self._task_group(self._node_manager_class.save, epoch)
class JobRunner(object):
"""
    Implements the runtime logic for jobs with checkpointing at the level of
    epochs. Can be used to run either single-host or distributed jobs. A
    JobRunner is a callable to be called once from the client, passing a
    Session as argument. This call will block until the Job execution is
    complete.
If a checkpoint_manager is passed, checkpoints will be taken after
initialization and after each epoch execution. If, in addition,
`resume_from_epoch` is an epoch number, the corresponding checkpoint will
be loaded and job execution will continue from the given epoch. In
this case, the job's init_group will not be run.
Refer to checkpoint_test.py for an example.
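    Typical usage (a sketch; `session` is any caffe2.python.session.Session
    and `checkpoint` is an optional (MultiNode)CheckpointManager):

        runner = JobRunner(job, checkpoint_manager=checkpoint,
                           resume_from_epoch=None)
        last_epoch = runner(session)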
"""
def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None):
self.resume_from_epoch = resume_from_epoch
self.checkpoint = checkpoint_manager
self.job = job
def __call__(self, client):
from_scratch = self.resume_from_epoch is None
if from_scratch:
client.run(self.job.init_group)
if self.checkpoint:
logger.info('Preparing checkpoint ...')
client.run(self.checkpoint.init(
self.job.nodes_to_checkpoint(),
retrieve_from_epoch=self.resume_from_epoch))
if from_scratch:
logger.info('Saving first checkpoint ...')
client.run(self.checkpoint.save(0))
logger.info('First checkpoint saved.')
else:
logger.info('Loading checkpoint for epoch {} ...'.format(
self.resume_from_epoch))
client.run(self.checkpoint.load(self.resume_from_epoch))
logger.info('Checkpoint loaded.')
epoch = 1 if from_scratch else self.resume_from_epoch + 1
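        # Run epochs until any stop signal fires, checkpointing after each
        # epoch when a checkpoint manager is configured.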
while True:
logger.info('Starting epoch %d.' % epoch)
client.run(self.job.epoch_group)
logger.info('Ran epoch %d.' % epoch)
stop_signals = [o.fetch() for o in self.job.stop_signals]
if self.checkpoint:
logger.info('Saving checkpoint ...')
client.run(self.checkpoint.save(epoch))
logger.info('Checkpoint saved.')
if any(stop_signals):
logger.info('Stopping.')
break
epoch += 1
client.run(self.job.exit_group)
return epoch
def epoch_limiter(num_epochs):
"""
Creates a task that will output True when a given
number of epochs has finished.
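    Example (a sketch): stop the epoch loop after 10 epochs:

        with Job() as job:
            # ... add reader / trainer tasks here ...
            epoch_limiter(num_epochs=10)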
"""
with Job.current().init_group:
init_net = core.Net('epoch_counter_init')
counter = init_net.CreateCounter([], init_count=num_epochs - 1)
Task(step=init_net)
epoch_net = core.Net('epoch_countdown')
finished = epoch_net.CountDown(counter)
output = Task(step=epoch_net, outputs=finished).outputs()[0]
Job.current().add_stop_signal(output)