blob: a38c0880077641b4e582d8ee966c39b37ab7a012 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import shutil
import tempfile

import numpy as np

from caffe2.python.schema import Struct, ConstRecord
from caffe2.python import core, workspace
from caffe2.python.session import LocalSession
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.checkpoint import (
    CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner)
from caffe2.python.task import Task, Node
from caffe2.python.test_util import TestCase
from caffe2.python.dataio import ReaderWithLimit
def build_job(num_nodes):
    """Build a toy accumulation job spread over ``num_nodes`` reader nodes.

    For every node index ``i`` the job contains:
      * an init step on net ``init_net<i>`` that materializes the constant
        record ``val = range(10)``, opens a reader on it, and seeds a
        running-total blob with ``[100]``;
      * a pipeline that reads through ``ReaderWithLimit(..., num_iter=3)``
        and adds each record's ``val`` into that running total;
      * a job stop signal tied to the limited reader's ``data_finished()``.

    Returns:
        ``(job, total_fetcher)``: the Job, plus a Task over an empty net
        whose outputs are the per-node running-total blobs, in node order.
    """
    node_totals = []
    with Job() as job:
        for idx in range(num_nodes):
            suffix = str(idx)
            with Node('reader' + suffix):
                with job.init_group:
                    setup_net = core.Net('init_net' + suffix)
                    source = Struct(('val', np.array(range(10))))
                    record = ConstRecord(setup_net, source)
                    dataset = Dataset(record, name='dataset' + suffix)
                    reader = dataset.reader(setup_net)
                    running_total = setup_net.Const([100])
                    Task(step=setup_net)

                def inc_total(rec):
                    # Processor handed to pipe(): returns a net that adds
                    # this record's ``val`` into the node's running total.
                    step_net = core.Net('inc_total' + suffix)
                    step_net.Add([running_total, rec.val()], [running_total])
                    return [step_net]

                limited = ReaderWithLimit(reader, num_iter=3)
                pipe(limited, processor=inc_total)
                job.add_stop_signal(limited.data_finished())
                node_totals.append(running_total)
    fetcher = Task(step=core.Net('empty'), outputs=node_totals)
    return job, fetcher
# Expected running total (seeded with 100 in build_job) after each epoch,
# where every epoch consumes at most three further values of range(10).
EXPECTED_TOTALS = [
    103,  # epoch 1: 100 + 0 + 1 + 2
    115,  # epoch 2: 103 + 3 + 4 + 5
    136,  # epoch 3: 115 + 6 + 7 + 8
    145,  # epoch 4: 136 + 9
]
class TestCheckpoint(TestCase):
    """Checkpoint save / resume / load tests for the job from build_job."""

    @staticmethod
    def _make_builder(manager_cls, path):
        """Return a zero-arg callable producing a fresh (session, manager).

        Each call creates a brand-new workspace and LocalSession, and a new
        ``manager_cls(path, 'minidb')`` checkpoint manager.
        """
        def builder():
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = manager_cls(path, 'minidb')
            return session, checkpoint
        return builder

    def run_with(self, builder):
        """Run the single-node job, then verify resume and per-epoch load.

        Args:
            builder: zero-argument callable returning a
                ``(session, checkpoint_manager)`` pair. It is called once for
                the initial run and once per resumed run; it is expected to
                return a fresh session each time.
        """
        job, output_fetcher = build_job(num_nodes=1)

        def fetch_total(session):
            # Run the output-fetching task, then read back its only output.
            session.run(output_fetcher)
            return output_fetcher.outputs()[0].fetch()

        # Full run from scratch.
        session, checkpoint = builder()
        compiled_job = job.compile(LocalSession)
        num_epochs = JobRunner(compiled_job, checkpoint)(session)
        self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
        self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

        # Resuming from any saved epoch must reach the same final total.
        for initial_epoch in range(1, num_epochs + 1):
            session, checkpoint = builder()
            JobRunner(
                compiled_job,
                checkpoint, resume_from_epoch=initial_epoch)(session)
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[-1])

        # Loading a given epoch's checkpoint restores that epoch's total.
        for epoch in range(1, num_epochs + 1):
            session.run(checkpoint.load(epoch))
            self.assertEqual(fetch_total(session), EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
        """Exercise run_with with single-node and multi-node managers."""
        # Single node. One db is written per epoch for this path, so keep the
        # path inside a private temp directory that is removed wholesale; a
        # bare NamedTemporaryFile would only clean up its own file.
        # mkdtemp() stays outside `try` so a failure there is not masked by
        # an unbound name in `finally`.
        single_dir = tempfile.mkdtemp()
        try:
            self.run_with(self._make_builder(
                CheckpointManager, os.path.join(single_dir, 'checkpoint')))
        finally:
            shutil.rmtree(single_dir)

        # Multi-node: the manager is given the directory itself.
        multi_dir = tempfile.mkdtemp()
        try:
            self.run_with(
                self._make_builder(MultiNodeCheckpointManager, multi_dir))
        finally:
            shutil.rmtree(multi_dir)

    def test_load_model_from_checkpoints(self):
        """Load selected blobs from a multi-node checkpoint into a new ws."""
        tmpdir = tempfile.mkdtemp()
        try:
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            job, _ = build_job(num_nodes=3)
            compiled_job = job.compile(LocalSession)
            job_runner = JobRunner(compiled_job, checkpoint)
            num_epochs = job_runner(session)
            self.assertEqual(num_epochs, len(EXPECTED_TOTALS))
            # There are 44 blobs after finishing up the job runner.
            self.assertEqual(len(ws.blobs), 44)

            # Start over in a brand-new, empty workspace.
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEqual(len(ws.blobs), 0)
            model_blob_names = ['init_net0/GivenTensorInt64Fill:0',
                                'init_net1/GivenTensorInt64Fill:0']
            job_runner.load_blobs_from_checkpoints(
                blob_names=model_blob_names, epoch=1, session=session)
            # In addition to the two model blobs, we also have 3 output blobs
            # and one runnable blob. So there are 6 blobs in total.
            self.assertEqual(len(ws.blobs), 6)
            # Check that all the model blobs are loaded with epoch-1 totals.
            for blob_name in model_blob_names:
                self.assertTrue(ws.has_blob(blob_name))
                # Element-wise comparison; plain assertEqual would rely on
                # the truthiness of a numpy boolean array.
                np.testing.assert_array_equal(
                    ws.fetch_blob(blob_name), np.array([103]))
        finally:
            shutil.rmtree(tmpdir)