blob: edea20e8454cb011600a252300181719067f931c [file] [log] [blame]
"""
Defines the base interface for reading and writing operations.
Readers/Writers are objects that produce operations that read/write sequences
of data. Each operation reads or writes a list of BlobReferences.
Readers and Writers must be implemented such that read and write operations
are atomic and thread safe.
Examples of possible Readers and Writers:
HiveReader, HiveWriter,
QueueReader, QueueWriter,
DatasetReader, DatasetWriter,
DBReader, DBWriter,
See `dataset.py` for an example of implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from caffe2.python.schema import Field, Struct, from_blob_list
import numpy as np
class Reader(object):
    """
    Reader is an abstract class to be implemented in order to provide
    operations capable of iterating through a dataset or stream of data.

    A Reader must implement at least one operation, `read`, which
    adds operations to a net that read the next batch of data. Readers can
    optionally support the `reset` operation, which is useful when multiple
    passes over the data are required.
    """

    def __init__(self, schema=None):
        """
        Args:
            schema: (optional) a schema `Field` describing the records
                    produced by this reader.
        """
        if schema is not None:
            assert isinstance(schema, Field)
        self._schema = schema

    def schema(self):
        """
        Return the schema associated with this Reader.

        Raises:
            AssertionError: if no schema was provided for this reader.
        """
        assert self._schema is not None, 'Schema not provided for this reader.'
        return self._schema

    def _set_schema(self, schema):
        """Replace the schema stored on this reader."""
        self._schema = schema

    def setup_ex(self, init_net, finish_net):
        """Nets to be executed once at startup and finish.
           Experimental extension. Don't use yet"""
        pass

    def read_ex(self, local_init_net, local_finish_net):
        """Experimental extension to the interface. Don't use yet

        Creates a fresh net, delegates to `read` and returns a tuple
        (nets, should_stop, fields) where `nets` is a one-element list
        holding the newly created net.
        """
        read_net = core.Net('reader_body')
        # `read` returns a (should_stop, fields) tuple; prepend the net list.
        return ([read_net], ) + self.read(read_net)

    def read_record_ex(self, local_init_net, local_finish_net):
        """Experimental extension to the interface. Don't use yet

        Same as `read_ex`, but when the reader has a schema the returned
        fields are wrapped into a record via `from_blob_list`.
        """
        nets, should_stop, fields = self.read_ex(
            local_init_net, local_finish_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return nets, should_stop, fields

    def read(self, read_net):
        """
        Add operations to read_net that will read the next batch of data
        and return a list of BlobReference representing the blobs that will
        contain the batches produced.

        Operations added to `read_net` must be thread safe and atomic, that is,
        it should be possible to clone `read_net` and run multiple instances of
        it in parallel.

        Args:
            read_net: the net that will be appended with read operations

        Returns:
            A tuple (should_stop, fields), with:

                should_stop: BlobReference pointing to a boolean scalar
                             blob that indicates whether the read operation
                             was successful or whether the end of data has
                             been reached.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.
        """
        raise NotImplementedError('Readers must implement `read`.')

    def reset(self, net):
        """Append operations to `net` that will reset the reader.

        This can be used to read the data multiple times.
        Not all readers support this operation.
        """
        raise NotImplementedError('This reader cannot be resetted.')

    def read_record(self, read_net):
        """
        Same as `read`, but when the reader has a schema the returned fields
        are wrapped into a record via `from_blob_list`.

        Returns:
            A tuple (should_stop, fields).
        """
        should_stop, fields = self.read(read_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return should_stop, fields

    def execution_step(self, reader_net_name=None):
        """Create an execution step with a net containing read operators.

        The execution step will contain a `stop_blob` that knows how to stop
        the execution loop when end of data was reached.

        E.g.:

            read_step, fields = reader.execution_step()
            consume_net = core.Net('consume')
            consume_net.Print(fields[0], [])
            p = core.Plan('reader')
            p.AddStep(read_step.AddNet(consume_net))
            core.RunPlan(p)

        Args:
            reader_net_name: (optional) the name of the reader_net to be
                             created. The execution step will
                             be named accordingly. Defaults to 'reader'.

        Returns:
            A tuple (read_step, fields), with:

                read_step: A newly created execution step containing a net with
                           read operations. The step will have `stop_blob` set,
                           in order to stop the loop on end of data.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.
        """
        # Resolve the default once so that the net and the step are named
        # consistently (otherwise the step would be called 'None_step').
        net_name = reader_net_name or 'reader'
        reader_net = core.Net(net_name)
        should_stop, fields = self.read_record(reader_net)
        read_step = core.execution_step(
            '{}_step'.format(net_name),
            reader_net,
            should_stop_blob=should_stop)
        return (read_step, fields)
class Writer(object):
    """
    Abstract base class for objects that add write operations to nets in
    order to feed a data stream or a dataset.

    A Writer exposes 2 operations:
        `write`, which adds operations to a net that write the next batch of
        data, and `commit`, which adds operations to a net in order to indicate
        that no more data will be written.
    """
    # Set by `write_record` / `write_record_ex` when they are handed a Field.
    _schema = None

    def schema(self):
        """Return the schema recorded so far, or None if there is none."""
        return self._schema

    def write(self, writer_net, fields):
        """Add operations to `writer_net` that write the next batch of data.

        Operations added to the net must be thread-safe and unique, that is:
        multiple writers must be able to write to the dataset in parallel.

        Args:
            writer_net: the net that write operations are appended to.
            fields: a tuple of BlobReference containing the batch of data to
                    write.
        """
        raise NotImplementedError('Writers must implement write.')

    def write_record(self, writer_net, fields):
        """Like `write`, but also accepts a schema `Field`.

        When `fields` is a Field it is remembered as this writer's schema
        and its blobs (`field_blobs()`) are what gets written.
        """
        blobs = fields
        if isinstance(fields, Field):
            self._schema = fields
            blobs = fields.field_blobs()
        self.write(writer_net, blobs)

    def setup_ex(self, init_net, finish_net):
        """Experimental, don't use yet"""
        # The default setup only registers the commit ops on the finish net.
        self.commit(finish_net)

    def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
        """Experimental extension to the interface. Don't use yet

        Returns a one-element list with a new net holding the write ops.
        """
        body_net = core.Net('write_net')
        self.write(body_net, fields)
        return [body_net]

    def write_record_ex(
            self, fields, local_init_net, local_finish_net, stop_blob=None):
        """Experimental extension to the interface. Don't use yet.

        Returns a tuple (write_nets, stop_blob). When no `stop_blob` is
        supplied a new name is requested from `local_init_net`.
        """
        blobs = fields
        if isinstance(fields, Field):
            self._schema = fields
            blobs = fields.field_blobs()
        status_blob = stop_blob
        if status_blob is None:
            status_blob = local_init_net.NextName("dequeue_status")
        nets = self.write_ex(
            blobs, local_init_net, local_finish_net, status_blob)
        return (nets, status_blob)

    def commit(self, finish_net):
        """Add operations to `finish_net` that signal end of data.

        This must be implemented by all Writers, but may be no-op for some
        of them.
        """
        pass
class ReaderBuilder(object):
    """ Allow usage of a reader in distributed fashion.

    Pure interface: every method raises NotImplementedError and has to be
    overridden by concrete builders.
    """

    def schema(self):
        """Abstract. Presumably returns the schema of the built readers
        (TODO confirm against implementations)."""
        raise NotImplementedError()

    def enqueue_splits(self, net, split_queue):
        """Abstract. Takes a net and a split queue; semantics are defined
        by subclasses."""
        raise NotImplementedError()

    def splits(self, net):
        """Abstract. Takes a net; semantics are defined by subclasses."""
        raise NotImplementedError()

    def new_reader(self, split_queue):
        """Abstract. Presumably creates a reader fed from `split_queue`
        (TODO confirm against implementations)."""
        raise NotImplementedError()
class Pipe(object):
    """
    Base class for pipes: keeps a schema, an optional object key and counts
    of the readers and writers that were created on it.

    Concrete pipes must override `reader` and `writer`.
    """

    def __init__(self, schema=None, obj_key=None):
        """
        Args:
            schema: (optional) schema of the pipe. If omitted, it is taken
                    from the first writer that provides one.
            obj_key: (optional) key under which this pipe is attached, via
                     `add_attribute`, to the init net of each new
                     reader/writer.
        """
        self._schema = schema
        self._obj_key = obj_key
        self._num_readers = 0
        self._num_writers = 0

    def schema(self):
        """Return the schema currently known for this pipe (may be None)."""
        return self._schema

    def setup(self, global_init_net):
        """Hook for one-time setup on `global_init_net`; no-op by default."""
        pass

    def reader(self):
        """Abstract: to be overridden by concrete pipes."""
        raise NotImplementedError()

    def writer(self):
        """Abstract: to be overridden by concrete pipes."""
        raise NotImplementedError()

    def num_readers(self):
        """Return how many times `_new_reader` has been called."""
        return self._num_readers

    def num_writers(self):
        """Return how many times `_new_writer` has been called."""
        return self._num_writers

    def _new_writer(self, writer_schema, writer_init_net):
        """Account for a new writer and adopt its schema if none is set."""
        if self._schema is None and writer_schema is not None:
            self._schema = writer_schema
        self._num_writers += 1
        if self._obj_key is None:
            return
        writer_init_net.add_attribute(self._obj_key, self)

    def _new_reader(self, reader_init_net):
        """Account for a new reader."""
        self._num_readers += 1
        if self._obj_key is None:
            return
        reader_init_net.add_attribute(self._obj_key, self)
class CounterReader(Reader):
    """ Reader that produces increasing integers. """

    def __init__(self):
        # Records consist of a single int64 field named 'iter'.
        Reader.__init__(self, schema=Struct(('iter', np.int64)))
        # Both blobs are created lazily by `setup_ex`.
        self.counter = None
        self.should_stop = None

    def setup_ex(self, global_init_net, global_finish_net):
        """Create the counter and the constant `should_stop` blob, once.

        Subsequent calls are no-ops as soon as the counter exists.
        """
        if self.counter is not None:
            return
        self.counter = global_init_net.CreateCounter([], init_count=0)
        # Filled with False: this blob is never updated by `read_ex`.
        self.should_stop = global_init_net.ConstantFill(
            [], shape=[], dtype=core.DataType.BOOL, value=False)

    def read_ex(self, local_init_net, local_finish_net):
        """Return ([net], should_stop, [value]) where `net` counts up by one
        call of `CountUp` on the shared counter."""
        counter_net = core.Net('limited_reader_counter')
        next_value = counter_net.CountUp([self.counter], 1)
        return [counter_net], self.should_stop, [next_value]
class ReaderWithLimit(Reader):
    """ Reader that stops after `num_iter` calls. """

    def __init__(self, reader, num_iter=1):
        """
        Args:
            reader: the underlying Reader to wrap.
            num_iter: value used as `init_count` of the countdown counter.
        """
        Reader.__init__(self, schema=reader._schema)
        self.reader = reader
        self.num_iter = num_iter
        # This net is not kept; it only provides unique blob names and the
        # corresponding blob references.
        net = core.Net('reader_with_limit')
        self._data_finished = net.AddExternalInput(
            net.NextName('data_finished'))
        self.counter = net.AddExternalInput(net.NextName('counter'))

    def setup_ex(self, global_init_net, global_finish_net):
        """Create the countdown counter, set up the wrapped reader and
        initialize the `data_finished` blob to False."""
        global_init_net.CreateCounter(
            [], [self.counter], init_count=int(self.num_iter))
        self.reader.setup_ex(global_init_net, global_finish_net)
        global_init_net.ConstantFill(
            [], [self._data_finished],
            shape=[], value=False, dtype=core.DataType.BOOL)

    def read_ex(self, local_init_net, local_finish_net):
        """Wrap the underlying reader's `read_ex` with an iteration limit.

        Returns:
            A tuple (nets, should_stop, fields) where `nets` is the counter
            net, followed by the wrapped reader's nets, followed by a net
            that propagates the wrapped reader's stop flag.
        """
        # 1. check if we reached number of iterations
        count_net = core.Net('limited_reader_counter')
        should_stop = count_net.CountDown([self.counter], 1)

        # 2. call original reader
        inner_nets, inner_finished, fields = self.reader.read_ex(
            local_init_net, local_finish_net)
        self._set_schema(self.reader._schema)

        # 3. check if original reader is done.
        check_done_net = core.Net('limited_reader_post')
        check_done_net.Copy(inner_finished, should_stop)
        check_done_net.Copy([inner_finished], [self._data_finished])

        # this relies on `should_stop` being called after each net.
        all_nets = [count_net] + inner_nets + [check_done_net]
        return all_nets, should_stop, fields

    def data_finished(self):
        """
        Return a blob that can be checked after the end of the reading task,
        which will contain a boolean scalar indicating whether the underlying
        reader has been exhausted (True) or whether we stopped because reached
        the limit of iterations (False).
        """
        return self._data_finished
def CountUntil(num_iter):
    """Return a `ReaderWithLimit` wrapping a fresh `CounterReader`, limited
    to `num_iter` iterations."""
    counting_reader = CounterReader()
    return ReaderWithLimit(counting_reader, num_iter)