from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import logging
from caffe2.python import core, context
from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@context.define_context()
class Job(object):
    """
    A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the
    `exit_group`, which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after init_group. The loop will
    exit when any of the stop signals added with `add_stop_signal` is True
    at the end of an epoch.

    The `exit_group` will be run only once at the very end of the job, when
    one of the stopping criteria for `epoch_group` was met. The role of this
    group is to save the results of training at the end of the job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example of usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
                Job.current().add_stop_signal(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 exit_group=None, stop_signals=None,
                 nodes_to_checkpoint=None):
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_signals = stop_signals or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
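        """
        Returns the list of nodes whose workspaces need to be checkpointed.
        Defaults to the nodes used by this job's `init_group`.
        """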
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        else:
            return self.init_group.used_nodes()

    def compile(self, session_class):
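        """
        Returns a copy of this Job in which each TaskGroup has been compiled
        with `session_class.compile`, so that it can be run by a session of
        that class.
        """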
        return Job(
            init_group=session_class.compile(self.init_group),
            epoch_group=session_class.compile(self.epoch_group),
            exit_group=session_class.compile(self.exit_group),
            stop_signals=self.stop_signals,
            nodes_to_checkpoint=self.nodes_to_checkpoint())

    def __enter__(self):
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_signal(self, output):
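        """
        Registers a stop signal for the job. `output` may be a TaskOutput or
        a BlobReference; a BlobReference is wrapped in a Task of the
        `epoch_group` so that its value can be fetched after each epoch. The
        epoch loop stops once any registered signal is True.
        """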
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_signals.append(output)


class CheckpointManager(object):
    """
    Controls saving and loading of workspaces at every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.
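
    Example of usage (a minimal sketch; the db path, node name and `session`
    object are illustrative, and these calls are normally issued by a
    JobRunner rather than by hand):

        checkpoint = CheckpointManager(db='/tmp/ckpt/model', db_type='minidb')
        session.run(checkpoint.init(nodes=['local']))
        session.run(checkpoint.save(epoch=0))   # persist the global workspace
        session.run(checkpoint.load(epoch=0))   # restore it when resuming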
    """
    def __init__(self, db, db_type):
        self._db = db
        self._db_type = db_type
        # make sure these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput('blob_names')
        self._names_output = None

    def init(self, nodes=None, retrieve_from_epoch=None):
        """
        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')
        net = core.Net('get_blob_list')
        if retrieve_from_epoch is None:
            net.GetAllBlobNames(
                [],
                self._blob_names,
                include_shared=False)
        else:
            net.Load(
                [], self._blob_names,
                db=self._dbname(retrieve_from_epoch),
                db_type=self._db_type,
                absolute_path=True)
        task = Task(step=net, outputs=[self._blob_names])
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
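        """
        Returns the list of blob names to checkpoint, as computed by the task
        built by `init`. Can only be called after that task has run.
        """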
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _dbname(self, epoch):
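        """
        Returns the db name used for a given epoch, e.g. 'path/to/db.000003'
        for epoch 3.
        """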
        return '%s.%06d' % (self._db, epoch)

    def load(self, epoch):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from persistent storage.
        """
        net = core.Net('checkpoint_load')
        net.Load(
            [],
            self.blob_list(),
            db=self._dbname(epoch),
            db_type=self._db_type,
            absolute_path=True)
        return Task(step=net)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save op to serialize and persist
        blobs present in the global workspace.
        """
        net = core.Net('checkpoint_save')
        net.Save(
            self.blob_list(), [], db=self._dbname(epoch),
            db_type=self._db_type, absolute_path=True)
        return Task(step=net)


class MultiNodeCheckpointManager(object):
    """
    Coordinates checkpointing across multiple nodes.
    Each of `init`, `load` and `save` will build TaskGroups which will
    trigger checkpointing on each of the nodes involved in a distributed job.
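
    Example of usage (a minimal sketch; the node names, db prefix and
    `session` object are illustrative):

        checkpoint = MultiNodeCheckpointManager(
            db_prefix='/checkpoints/my_job', db_type='minidb')
        session.run(checkpoint.init(['trainer_0', 'trainer_1']))
        session.run(checkpoint.save(epoch=1))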
    """
    def __init__(
            self, db_prefix, db_type, node_manager_class=CheckpointManager):
        self._node_manager_class = node_manager_class
        self._node_managers = None
        self._db_prefix = db_prefix
        self._db_type = db_type

    def _task_group(self, func, *args, **kw):
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
            return task_group

    def init(self, nodes, retrieve_from_epoch=None):
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            return
        self._node_managers = []
        for node in nodes:
            with Node(node):
                manager = self._node_manager_class(
                    db=os.path.join(self._db_prefix, node),
                    db_type=self._db_type)
                self._node_managers.append((node, manager))
        return self._task_group(
            self._node_manager_class.init,
            nodes=[node],
            retrieve_from_epoch=retrieve_from_epoch)

    def load(self, epoch):
        return self._task_group(self._node_manager_class.load, epoch)

    def save(self, epoch):
        return self._task_group(self._node_manager_class.save, epoch)


class JobRunner(object):
    """
    Implements the runtime logic for jobs with checkpointing at the epoch
    level. Can be used to run either single-host or distributed jobs. The job
    runner is a callable to be called once from the client, passing a Session
    as argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
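
    Example of usage (a minimal sketch; `session`, `build_trainer` and the
    checkpoint path are illustrative):

        with Job() as job:
            build_trainer(...)
        checkpoint = CheckpointManager(db='/tmp/ckpt/model', db_type='minidb')
        runner = JobRunner(job, checkpoint_manager=checkpoint)
        last_epoch = runner(session)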
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None):
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint = checkpoint_manager
        self.job = job

    def __call__(self, client):
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            client.run(self.job.init_group)

        if self.checkpoint:
            logger.info('Preparing checkpoint ...')
            client.run(self.checkpoint.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            if from_scratch:
                logger.info('Saving first checkpoint ...')
                client.run(self.checkpoint.save(0))
                logger.info('First checkpoint saved.')
            else:
                logger.info('Loading checkpoint for epoch {} ...'.format(
                    self.resume_from_epoch))
                client.run(self.checkpoint.load(self.resume_from_epoch))
                logger.info('Checkpoint loaded.')

        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d.' % epoch)
            client.run(self.job.epoch_group)
            logger.info('Ran epoch %d.' % epoch)
            stop_signals = [o.fetch() for o in self.job.stop_signals]

            if self.checkpoint:
                logger.info('Saving checkpoint ...')
                client.run(self.checkpoint.save(epoch))
                logger.info('Checkpoint saved.')

            if any(stop_signals):
                logger.info('Stopping.')
                break
            epoch += 1
        client.run(self.job.exit_group)
        return epoch


def epoch_limiter(num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.
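
    Example of usage (a minimal sketch; `build_trainer` is illustrative, and
    this must be called while a Job is the active context):

        with Job() as job:
            build_trainer(...)
            epoch_limiter(num_epochs=10)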
    """
    with Job.current().init_group:
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)
    epoch_net = core.Net('epoch_countdown')
    finished = epoch_net.CountDown(counter)
    output = Task(step=epoch_net, outputs=[finished]).outputs()[0]
    Job.current().add_stop_signal(output)