# blob: 6878afc668023d4d9a952e7a541a7e0ca2a43f7c [file] [log] [blame]
"""
Defines the base interface for reading and writing operations.
Readers/Writers are objects that produce operations that read/write sequences
of data. Each operation reads or writes a list of BlobReferences.
Readers and Writers must be implemented such that read and write operations
are atomic and thread safe.
Examples of possible Readers and Writers:
HiveReader, HiveWriter,
QueueReader, QueueWriter,
DatasetReader, DatasetWriter,
DBReader, DBWriter,
See `dataset.py` for an example of implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from caffe2.python.schema import Field, from_blob_list
class Reader(object):
    """
    Reader is an abstract class to be implemented in order to provide
    operations capable of iterating through a dataset or stream of data.

    A Reader must implement at least one operation, `read`, which
    adds operations to a net that read the next batch of data. Readers can
    optionally support the `reset` operation, which is useful when multiple
    passes over the data are required.
    """

    def __init__(self, schema=None):
        """Create a reader, optionally bound to a schema.

        Args:
            schema: (optional) a `Field` describing the records produced by
                    this reader. When provided, `read_record` and
                    `read_record_ex` pass the blobs returned by `read`
                    through `from_blob_list` together with this schema.
        """
        if schema is not None:
            assert isinstance(schema, Field)
        self._schema = schema

    def schema(self):
        """
        Return the schema associated with this reader.

        Raises:
            AssertionError: if no schema was provided at construction time.
        """
        assert self._schema is not None, 'Schema not provided for this reader.'
        return self._schema

    def setup_ex(self, init_net, finish_net):
        """Nets to be executed once at startup and finish.
        Experimental extension. Don't use yet"""
        pass

    def read_ex(self, local_init_net, local_finish_net):
        """Experimental extension to the interface. Don't use yet

        The base implementation ignores both arguments: it creates a net
        named 'reader_body', delegates to `read`, and returns the tuple
        ([read_net], should_stop, fields).
        """
        read_net = core.Net('reader_body')
        # `read` returns a (should_stop, fields) tuple; prepend the net list.
        return ([read_net], ) + self.read(read_net)

    def read_record_ex(self, local_init_net, local_finish_net):
        """Experimental extension to the interface. Don't use yet

        Same as `read_ex`, except that `fields` is passed through
        `from_blob_list` when this reader has a schema.
        """
        nets, should_stop, fields = self.read_ex(
            local_init_net, local_finish_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return nets, should_stop, fields

    def read(self, read_net):
        """
        Add operations to read_net that will read the next batch of data
        and return a list of BlobReference representing the blobs that will
        contain the batches produced.

        Operations added to `read_net` must be thread safe and atomic, that is,
        it should be possible to clone `read_net` and run multiple instances of
        it in parallel.

        Args:
            read_net: the net that will be appended with read operations

        Returns:
            A tuple (should_stop, fields), with:
                should_stop: BlobReference pointing to a boolean scalar
                             blob that indicates whether the read operation
                             was successful or whether the end of data has
                             been reached.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Readers must implement `read`.')

    def reset(self, net):
        """Append operations to `net` that will reset the reader.

        This can be used to read the data multiple times.
        Not all readers support this operation.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('This reader cannot be resetted.')

    def read_record(self, read_net):
        """Call `read` and wrap the resulting blobs with the reader's schema.

        Args:
            read_net: the net that will be appended with read operations

        Returns:
            A tuple (should_stop, fields). `fields` is the value returned by
            `read`, passed through `from_blob_list` when this reader has a
            schema.
        """
        should_stop, fields = self.read(read_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return should_stop, fields

    def execution_step(self, reader_net_name=None):
        """Create an execution step with a net containing read operators.

        The execution step will contain a `stop_blob` that knows how to stop
        the execution loop when end of data was reached.

        E.g.:

            read_step, fields = reader.execution_step()
            consume_net = core.Net('consume')
            consume_net.Print(fields[0], [])
            p = core.Plan('reader')
            p.AddStep(read_step.AddNet(consume_net))
            core.RunPlan(p)

        Args:
            reader_net_name: (optional) the name of the reader_net to be
                             created. The execution step will
                             be named accordingly. Defaults to 'reader'.

        Returns:
            A tuple (read_step, fields), with:
                read_step: A newly created execution step containing a net with
                           read operations. The step will have `stop_blob` set,
                           in order to stop the loop on end of data.
                fields: The latest batch of data that was read, as returned
                        by `read_record`.
        """
        # Resolve the default once so the net and the step agree on the name;
        # formatting the raw argument would yield 'None_step' when omitted.
        net_name = reader_net_name or 'reader'
        reader_net = core.Net(net_name)
        should_stop, fields = self.read_record(reader_net)
        read_step = core.execution_step(
            '{}_step'.format(net_name),
            reader_net,
            should_stop_blob=should_stop)
        return (read_step, fields)
class Writer(object):
    """
    Abstract base class for objects that add operations to nets in order to
    feed a data stream or a dataset.

    A Writer must implement 2 operations:
    `write`, which adds operations to a net that write the next batch of
    data, and `commit`, which adds operations to a net in order to indicate
    that no more data will be written.
    """

    def write(self, writer_net, fields):
        """Add operations to `writer_net` that write the next batch of data.

        Operations added to the net must be thread-safe and unique, that is:
        multiple writers must be able to write to the dataset in parallel.

        Args:
            writer_net: the net that will be appended with write operations.
            fields: a tuple of BlobReference containing the batch of data to
                    write.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Writers must implement write.')

    def write_record(self, writer_net, fields):
        """Write `fields`, accepting either a `Field` or a plain blob list.

        A `Field` is first flattened with `field_blobs()`; anything else is
        forwarded to `write` unchanged.
        """
        blobs = fields.field_blobs() if isinstance(fields, Field) else fields
        self.write(writer_net, blobs)

    def setup_ex(self, init_net, finish_net):
        """Experimental, don't use yet.

        The base implementation only registers `commit` on `finish_net`;
        `init_net` is not used.
        """
        self.commit(finish_net)

    def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
        """Experimental extension to the interface. Don't use yet.

        The base implementation creates a net named 'write_net', delegates
        to `write`, and returns that net in a single-element list. The
        remaining arguments are not used here.
        """
        net = core.Net('write_net')
        self.write(net, fields)
        return [net]

    def write_record_ex(
            self, fields, local_init_net, local_finish_net, stop_blob=None):
        """Experimental extension to the interface. Don't use yet.

        Flattens `fields` when it is a `Field`, picks a fresh blob name from
        `local_init_net` when no `stop_blob` is supplied, and delegates to
        `write_ex`.

        Returns:
            A tuple (write_nets, stop_blob).
        """
        blobs = fields.field_blobs() if isinstance(fields, Field) else fields
        status_blob = (
            local_init_net.NextName("dequeue_status")
            if stop_blob is None else stop_blob)
        nets = self.write_ex(
            blobs, local_init_net, local_finish_net, status_blob)
        return (nets, status_blob)

    def commit(self, finish_net):
        """Add operations to `finish_net` that signal end of data.

        This must be implemented by all Writers, but may be no-op for some
        of them. The base implementation does nothing.
        """