from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.dataio import ReaderWithLimit
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python import core, workspace
from caffe2.python.net_builder import ops
import numpy as np


def init_dataset(ws):
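    """Build a source Dataset with a single 'label' field holding 0..99.

    The records are fed into the given workspace under the 'src' name
    scope; tests consume them through src_ds.reader().
    """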
    src_init = core.Net('src_init')
    with core.NameScope('src'):
        src_values = Struct(('label', np.array(range(100))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs)
        FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)
    return src_ds


class TestReaderWithLimit(TestCase):
    def test_runtime_threads(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = init_dataset(ws)
        totals = [None] * 3

        def proc(rec):
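            """Processor that counts records at several scopes.

            counter1 counts every record seen across all threads,
            task_counter counts records within one thread instance,
            counter2 accumulates the per-instance totals, and counter3
            counts how many thread instances ran.
            """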
            # executed once, when the task starts
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread instance
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread instance, when it exits
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once, when the task exits
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # 1. Feed the full dataset
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
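        # All 100 records passed through proc once, the per-instance counts
        # sum back to 100, and one increment per thread instance gives 8.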
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # 2. Add a few steps in between
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
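        # The middle stage stops after 25 reads, so only 25 records reach
        # proc, while the final stage still runs 6 thread instances.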
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)

    def test_reader_with_limit(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # 1. Feed the full dataset
        src_ds = init_dataset(ws)

        # 2. Read with a limit smaller than the size of the dataset
        dst_init = core.Net('dst_init')
        with core.NameScope('dst'):
            dst_ds = Dataset(src_ds.content().clone_schema())
            dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
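
        # data_finished() is False because reading stopped at the iteration
        # limit, not because the source dataset was exhausted.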
        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10))
        )

        # 3. Read with a limit larger than the size of the dataset
        ws.run(dst_init)
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_runtime_threads=8)
        session.run(tg)
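        # The limit (110) exceeds the dataset size (100), so all records are
        # read and the underlying reader runs out of data.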
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100))
        )
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

        # 4. Read without a counter
        ws.run(dst_init)
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=None)
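            # num_iter=None disables the limit, so the reader is drained
            # until the source dataset is exhausted.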
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100))
        )
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

        # 5. Read again with the same reader, without resetting the workspace
        session.run(tg)
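        # dst_init was not re-run, so the destination still holds the 100
        # records from step 4 and the second run appends 100 more.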
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            sorted(list(range(100)) * 2)
        )