from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.schema import Struct, ConstRecord
from caffe2.python import core, workspace
from caffe2.python.session import LocalSession
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.checkpoint import (
    CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner)
from caffe2.python.task import Task, Node
from caffe2.python.test_util import TestCase
from caffe2.python.dataio import ReaderWithLimit
import tempfile
import numpy as np
import shutil


def build_job(num_nodes):
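    """Builds a multi-node Job that accumulates per-node running totals.

    Each of the `num_nodes` reader nodes owns a 10-row dataset (values
    0..9) and a `total` blob initialized to 100. A length-limited reader
    pipes up to 3 rows per epoch into a net that adds each value to
    `total`, and the job stops once a reader runs out of data. Returns
    the job together with a Task whose outputs expose the totals.
    """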
    all_outputs = []
    with Job() as job:
        for node_id in range(num_nodes):
            with Node('reader' + str(node_id)):
                with job.init_group:
                    init_net = core.Net('init_net' + str(node_id))
                    data_arr = Struct(('val', np.array(range(10))))
                    data = ConstRecord(init_net, data_arr)
                    ds = Dataset(data, name='dataset' + str(node_id))
                    full_reader = ds.reader(init_net)
                    total = init_net.Const([100])
                    Task(step=init_net)

                def inc_total(rec):
                    net = core.Net('inc_total' + str(node_id))
                    net.Add([total, rec.val()], [total])
                    return [net]

                # Read at most 3 rows per epoch; with 10 rows of data, the
                # job runs for 4 epochs before the stop signal fires.
                epoch_reader = ReaderWithLimit(full_reader, num_iter=3)
                pipe(epoch_reader, processor=inc_total)
                job.add_stop_signal(epoch_reader.data_finished())
                all_outputs.append(total)

    total_fetcher = Task(step=core.Net('empty'), outputs=all_outputs)
    return job, total_fetcher


# Totals after each epoch: 100 + (0+1+2) = 103, then + (3+4+5) = 115,
# + (6+7+8) = 136, and + 9 = 145 once the final row is read.
EXPECTED_TOTALS = [103, 115, 136, 145]


class TestCheckpoint(TestCase):
    def run_with(self, builder):
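        """Runs the job to completion, then checks that resuming from any
        epoch, or loading any epoch's checkpoint directly, reproduces the
        expected totals.
        """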
        job, output_fetcher = build_job(num_nodes=1)

        def fetch_total(session):
            session.run(output_fetcher)
            return output_fetcher.outputs()[0].fetch()

        session, checkpoint = builder()
        compiled_job = job.compile(LocalSession)
        num_epochs = JobRunner(compiled_job, checkpoint)(session)
        self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
        self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

        # Resuming from any later epoch must still reach the final total.
        for initial_epoch in range(1, num_epochs + 1):
            session, checkpoint = builder()
            JobRunner(
                compiled_job,
                checkpoint, resume_from_epoch=initial_epoch)(session)
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

        # Loading an epoch's checkpoint directly restores that epoch's total.
        for epoch in range(1, num_epochs + 1):
            session.run(checkpoint.load(epoch))
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
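        """Checks checkpointing on a single node (CheckpointManager) and
        across nodes (MultiNodeCheckpointManager).
        """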
        # test single node
        with tempfile.NamedTemporaryFile() as tmp:

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = CheckpointManager(tmp.name, 'minidb')
                return session, checkpoint

            self.run_with(builder)

        # test multi-node; create the temp dir before the try block so the
        # finally clause never references an unbound name if mkdtemp fails
        tmpdir = tempfile.mkdtemp()
        try:

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

    def test_load_model_from_checkpoints(self):
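        """Runs a 3-node job to completion, then loads two model blobs
        from the epoch-1 checkpoints into a fresh workspace.
        """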
        tmpdir = tempfile.mkdtemp()
        try:
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')

            job, output_fetcher = build_job(num_nodes=3)
            compiled_job = job.compile(LocalSession)
            job_runner = JobRunner(compiled_job, checkpoint)
            num_epochs = job_runner(session)
            self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
            # There are 44 blobs after the job runner finishes.
            self.assertEqual(len(ws.blobs), 44)

            # Load the model blobs into a fresh, empty workspace.
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEqual(len(ws.blobs), 0)
            model_blob_names = ['init_net0/GivenTensorInt64Fill:0',
                                'init_net1/GivenTensorInt64Fill:0']
            job_runner.load_blobs_from_checkpoints(blob_names=model_blob_names,
                                                   epoch=1, session=session)
            # In addition to the two model blobs, we also have 3 output blobs
            # and one runnable blob, so there are 6 blobs in total.
            self.assertEqual(len(ws.blobs), 6)
            # Check that all the model blobs were loaded with their epoch-1
            # values (total = 100 + 0 + 1 + 2 = 103).
            for blob_name in model_blob_names:
                self.assertTrue(ws.has_blob(blob_name))
                self.assertEqual(ws.fetch_blob(blob_name), np.array([103]))
        finally:
            shutil.rmtree(tmpdir)