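"""Tests for Caffe2 reader utilities: CompositeReader / CompositeReaderBuilder,
ReaderWithLimit / ReaderWithTimeLimit, CachedReader, and DBFileReader."""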
from caffe2.python.dataio import (
    CompositeReader,
    CompositeReaderBuilder,
    Reader,
    ReaderBuilder,
    ReaderWithDelay,
    ReaderWithLimit,
    ReaderWithTimeLimit,
)
from caffe2.python.dataset import Dataset
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python.cached_reader import CachedReader
from caffe2.python import core, workspace, schema
from caffe2.python.net_builder import ops

import numpy as np
import numpy.testing as npt
import os
import shutil
import unittest
import tempfile
import time


def make_source_dataset(ws, size=100, offset=0, name=None):
    """Create a Dataset whose 'label' field holds range(offset, offset + size),
    feed its blobs into the workspace, and return the Dataset."""
    name = name or "src"
    src_init = core.Net("{}_init".format(name))
    with core.NameScope(name):
        src_values = Struct(('label', np.array(range(offset, offset + size))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs, name=name)
        FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)
    return src_ds


def make_destination_dataset(ws, schema, name=None):
    """Create an empty destination Dataset with the given schema and
    initialize it in the workspace."""
    name = name or 'dst'
    dst_init = core.Net('{}_init'.format(name))
    with core.NameScope(name):
        dst_ds = Dataset(schema, name=name)
        dst_ds.init_empty(dst_init)
    ws.run(dst_init)
    return dst_ds


class TestReaderBuilder(ReaderBuilder):
    """ReaderBuilder backed by make_source_dataset(); used by the
    CompositeReaderBuilder test."""

    def __init__(self, name, size, offset):
        self._schema = schema.Struct(
            ('label', schema.Scalar()),
        )
        self._name = name
        self._size = size
        self._offset = offset
        self._src_ds = None

    def schema(self):
        return self._schema

    def setup(self, ws):
        self._src_ds = make_source_dataset(ws, offset=self._offset, size=self._size,
                                           name=self._name)
        return {}

    def new_reader(self, **kwargs):
        return self._src_ds


class TestCompositeReader(TestCase):
    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_dses = [make_source_dataset(ws, offset=offset, size=size, name=name)
                    for (name, offset) in zip(names, offsets)]

        data = [ws.fetch_blob(str(src.field_blobs[0])) for src in src_dses]
        # Sanity check we didn't overwrite anything
        for d, offset in zip(data, offsets):
            npt.assert_array_equal(d, range(offset, offset + size))

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds.content().clone_schema())
                for name, src_ds in zip(names, src_dses)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader = CompositeReader(names,
                                     [src_ds.reader() for src_ds in src_dses])
            pipe(reader, dst_ds.writer(), num_runtime_threads=3)
        session.run(tg)

        for i in range(num_srcs):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[names[i]].label())))
            npt.assert_array_equal(data[i], written_data, "i: {}".format(i))

    @unittest.skipIf(os.environ.get('JENKINS_URL'), 'Flaky test on Jenkins')
    def test_composite_reader_builder(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        num_srcs = 3
        names = ["src_{}".format(i) for i in range(num_srcs)]
        size = 100
        offsets = [i * size for i in range(num_srcs)]
        src_ds_builders = [
            TestReaderBuilder(offset=offset, size=size, name=name)
            for (name, offset) in zip(names, offsets)
        ]

        # Make an identically-sized empty destination dataset
        dst_ds_schema = schema.Struct(
            *[
                (name, src_ds_builder.schema())
                for name, src_ds_builder in zip(names, src_ds_builders)
            ]
        )
        dst_ds = make_destination_dataset(ws, dst_ds_schema)

        with TaskGroup() as tg:
            reader_builder = CompositeReaderBuilder(
                names, src_ds_builders)
            reader_builder.setup(ws=ws)
            pipe(reader_builder.new_reader(), dst_ds.writer(),
                 num_runtime_threads=3)
        session.run(tg)

        for name, offset in zip(names, offsets):
            written_data = sorted(
                ws.fetch_blob(str(dst_ds.content()[name].label())))
            npt.assert_array_equal(range(offset, offset + size), written_data,
                                   "name: {}".format(name))


class TestReaderWithLimit(TestCase):
    def test_runtime_threads(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = make_source_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)

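    # The helpers below are shared by the count-limit and time-limit reader
    # tests; `expected_read_len_threshold` allows a fuzzy match on the number
    # of records read, since time-limited runs are inherently nondeterministic.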
    def _test_limit_reader_init_shared(self, size):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # Make source dataset
        src_ds = make_source_dataset(ws, size=size)

        # Make an identically-sized empty destination Dataset
        dst_ds = make_destination_dataset(ws, src_ds.content().clone_schema())

        return ws, session, src_ds, dst_ds

    def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                                  expected_read_len_threshold,
                                  expected_finish, num_threads, read_delay,
                                  **limiter_args):
        ws, session, src_ds, dst_ds = \
            self._test_limit_reader_init_shared(size)

        # Read through the limited reader.
        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            if read_delay > 0:
                reader = reader_class(ReaderWithDelay(src_ds.reader(),
                                                      read_delay),
                                      **limiter_args)
            else:
                reader = reader_class(src_ds.reader(), **limiter_args)
            pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
        session.run(tg)
        read_len = len(sorted(ws.blobs[str(dst_ds.content().label())].fetch()))

        # Do a fuzzy match (expected_read_len +/- expected_read_len_threshold)
        # to eliminate flakiness for time-limited tests
        self.assertGreaterEqual(
            read_len,
            expected_read_len - expected_read_len_threshold)
        self.assertLessEqual(
            read_len,
            expected_read_len + expected_read_len_threshold)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(read_len))
        )
        self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
                         expected_finish)

    def test_count_limit_reader_without_limit(self):
        # No iter count specified, should read all records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=None)

    def test_count_limit_reader_with_zero_limit(self):
        # Zero iter count specified, should read 0 records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=0,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=0)

    def test_count_limit_reader_with_low_limit(self):
        # Read with limit smaller than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=10,
                                       expected_read_len_threshold=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=10)

    def test_count_limit_reader_with_high_limit(self):
        # Read with limit larger than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=110)

    def test_time_limit_reader_without_limit(self):
        # No duration specified, should read all records.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0.1,
                                       duration=0)

    def test_time_limit_reader_with_short_limit(self):
        # Read with insufficient time limit
        size = 50
        num_threads = 4
        sleep_duration = 0.25
        duration = 1
        expected_read_len = int(round(num_threads * duration / sleep_duration))
        # Because the time limit check happens before the delay + read op,
        # subtract a little bit of time to ensure we don't fit in an extra read.
        duration = duration - 0.25 * sleep_duration

        # NOTE: `expected_read_len_threshold` was added because this test case
        # has significant execution variation under stress. Under stress, we may
        # read strictly less than the expected # of samples; anywhere from
        # [0, N] where N = expected_read_len.
        # Hence we set expected_read_len to N/2, plus or minus N/2.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=size,
                                       expected_read_len=expected_read_len / 2,
                                       expected_read_len_threshold=expected_read_len / 2,
                                       expected_finish=False,
                                       num_threads=num_threads,
                                       read_delay=sleep_duration,
                                       duration=duration)

    def test_time_limit_reader_with_long_limit(self):
        # Read with ample time limit
        # NOTE: we don't use `expected_read_len_threshold` because the duration,
        # read_delay, and # threads should be more than sufficient
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=50,
                                       expected_read_len=50,
                                       expected_read_len_threshold=0,
                                       expected_finish=True,
                                       num_threads=4,
                                       read_delay=0.2,
                                       duration=10)


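# The cached-reader and DB-file-reader tests below write a LevelDB cache file,
# so they are skipped when the LevelDB backend is not compiled in.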
class TestDBFileReader(TestCase):
    def setUp(self):
        self.temp_paths = []

    def tearDown(self):
        # In case any test method fails, clean up temp paths.
        for path in self.temp_paths:
            self._delete_path(path)

    @staticmethod
    def _delete_path(path):
        if os.path.isfile(path):
            os.remove(path)  # Remove file.
        elif os.path.isdir(path):
            shutil.rmtree(path)  # Remove dir recursively.

    def _make_temp_path(self):
        # Make a temp path as db_path.
        with tempfile.NamedTemporaryFile() as f:
            temp_path = f.name
        self.temp_paths.append(temp_path)
        return temp_path

    @staticmethod
    def _build_source_reader(ws, size):
        src_ds = make_source_dataset(ws, size)
        return src_ds.reader()

    @staticmethod
    def _read_all_data(ws, reader, session):
        dst_ds = make_destination_dataset(ws, reader.schema().clone_schema())

        with TaskGroup() as tg:
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)

        return ws.blobs[str(dst_ds.content().label())].fetch()

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_cached_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Read data for the first time.
        cached_reader1 = CachedReader(
            self._build_source_reader(ws, 100), db_path, loop_over=False,
        )
        build_cache_step = cached_reader1.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache.
        cached_reader2 = CachedReader(
            self._build_source_reader(ws, 200), db_path,
        )
        build_cache_step = cached_reader2.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)

        # We removed the cache, so we expect to receive data from the original reader.
        cached_reader3 = CachedReader(
            self._build_source_reader(ws, 300), db_path,
        )
        build_cache_step = cached_reader3.build_cache_step()
        session.run(build_cache_step)

        data = self._read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        self._delete_path(db_path)

    @unittest.skipIf("LevelDB" not in core.C.registered_dbs(), "Need LevelDB")
    def test_db_file_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        db_path = self._make_temp_path()

        # Build a cache DB file.
        cached_reader = CachedReader(
            self._build_source_reader(ws, 100),
            db_path=db_path,
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        session.run(build_cache_step)

        # Read data from cache DB file.
        db_file_reader = DBFileReader(
            db_path=db_path,
            db_type='LevelDB',
        )
        data = self._read_all_data(ws, db_file_reader, session)
        self.assertEqual(sorted(data), list(range(100)))

        self._delete_path(db_path)