# blob: 9494944c159fb5c2f3892b2843319d00df69d12c [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python.dataio import ReaderWithLimit
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup
from caffe2.python.test_util import TestCase
from caffe2.python import core, workspace
import numpy as np
class TestReaderWithLimit(TestCase):
    """Tests for ``ReaderWithLimit`` wrapped around a ``Dataset`` reader."""

    def test_reader_with_limit(self):
        """Pipe a 100-row dataset through ``ReaderWithLimit`` with two limits.

        With ``num_iter=10`` the test expects exactly the first 10 labels in
        the destination and the reader's ``data_finished`` blob to be false.
        With ``num_iter=110`` it expects all 100 labels and ``data_finished``
        to be true.
        """
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # 1. Feed the full dataset.
        src_init = core.Net('src_init')
        src_values = Struct(('label', np.array(range(100))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs)
        FeedRecord(src_blobs, src_values, ws)
        ws.run(src_init)

        # 2. Read with a limit smaller than the size of the dataset.
        dst_init = core.Net('dst_init')
        dst_ds = Dataset(src_values.clone_schema())
        dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)

        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        # Compare against a list: on Python 3 ``range(...)`` is a lazy
        # sequence object and never compares equal to the list that
        # ``sorted`` returns. Sorting is needed because 8 writer threads
        # give no ordering guarantee.
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10)))

        # 3. Read with a limit larger than the size of the dataset.
        # Re-run the init net before the second pass; the assertion below
        # expects the destination to hold only this pass's 100 rows.
        ws.run(dst_init)
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)

        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100)))
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())