blob: ff8d1a73bce101bbcf7326bf9c691e9c0544ec82 [file] [log] [blame]
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
## @package dataio
# Module caffe2.python.dataio
"""
Defines the base interface for reading and writing operations.
Readers/Writers are objects that produce operations that read/write sequences
of data. Each operation reads or writes a list of BlobReferences.
Readers and Writers must be implemented such that read and write operations
are atomic and thread safe.
Examples of possible Readers and Writers:
QueueReader, QueueWriter,
DatasetReader, DatasetWriter,
See `dataset.py` for an example of implementation.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from caffe2.python.schema import Field, Struct, from_blob_list
import numpy as np
class Reader(object):
def __init__(self, schema=None):
if schema is not None:
assert isinstance(schema, Field)
self._schema = schema
def schema(self):
"""
Return the schema associated with the Reader
"""
assert self._schema is not None, 'Schema not provided for this reader.'
return self._schema
def _set_schema(self, schema):
self._schema = schema
def setup_ex(self, init_net, finish_net):
"""Nets to be executed once at startup and finish."""
pass
def read_ex(self, local_init_net, local_finish_net):
read_net = core.Net('reader_body')
return ([read_net], ) + self.read(read_net)
def read_record_ex(self, local_init_net, local_finish_net):
nets, should_stop, fields = self.read_ex(
local_init_net, local_finish_net)
if self._schema:
fields = from_blob_list(self._schema, fields)
return nets, should_stop, fields
"""
Reader is an abstract class to be implemented in order to provide
operations capable of iterating through a dataset or stream of data.
A Reader must implement at least one operation, `read`, which
adds operations to a net that read the next batch of data. Readers can
optionally support the `reset` operation, which is useful when multiple
passes over the data are required.
"""
def read(self, read_net):
"""
Add operations to read_net that will read the read batch of data
and return a list of BlobReference representing the blobs that will
contain the batches produced.
Operations added to `read_net` must be thread safe and atomic, that is,
it should be possible to clone `read_net` and run multiple instances of
it in parallel.
Args:
read_net: the net that will be appended with read operations
Returns:
A tuple (should_stop, fields), with:
should_stop: BlobReference pointing to a boolean scalar
blob that indicates whether the read operation
was succesfull or whether the end of data has
been reached.
fields: A tuple of BlobReference containing the latest batch
of data that was read.
"""
raise NotImplementedError('Readers must implement `read`.')
def reset(self, net):
"""Append operations to `net` that will reset the reader.
This can be used to read the data multiple times.
Not all readers support this operation.
"""
raise NotImplementedError('This reader cannot be resetted.')
def read_record(self, read_net):
should_stop, fields = self.read(read_net)
if self._schema:
fields = from_blob_list(self._schema, fields)
return should_stop, fields
def execution_step(self, reader_net_name=None, external_should_stop=None):
"""Create an execution step with a net containing read operators.
The execution step will contain a `stop_blob` that knows how to stop
the execution loop when end of data was reached.
E.g.:
read_step, fields = reader.execution_step()
consume_net = core.Net('consume')
consume_net.Print(fields[0], [])
p = core.Plan('reader')
p.AddStep(read_step.AddNet(consume_net))
core.RunPlan(p)
Args:
reader_net_name: (optional) the name of the reader_net to be
created. The execution step will
be named accordingly.
Returns:
A tuple (read_step, fields), with:
read_step: A newly created execution step containing a net with
read operations. The step will have `stop_blob` set,
in order to stop the loop on end of data.
fields: A tuple of BlobReference containing the latest batch
of data that was read.
"""
reader_net = core.Net(reader_net_name or 'reader')
should_stop, fields = self.read_record(reader_net)
if external_should_stop is not None:
should_stop = reader_net.Or([external_should_stop, should_stop])
read_step = core.execution_step(
'{}_step'.format(reader_net_name),
reader_net,
should_stop_blob=should_stop)
return (read_step, fields)
class Writer(object):
"""
Writer is an abstract class to be implemented in order to provide
operations capable of feeding a data stream or a dataset.
A Writer must implement 2 operations:
`write`, which adds operations to a net that write the write batch of
data, and `commit`, which adds operations to a net in order to indicate
that no more data will be written.
"""
_schema = None
def schema(self):
return self._schema
def write(self, writer_net, fields):
"""Add operations to `writer_net` that write the next batch of data.
Operations added to the net must be thread-safe and unique, that is:
multiple writers must be able to write to the dataset in parallel.
Args:
fields: a tuple of BlobReference containing the batch of data to
write.
"""
raise NotImplementedError('Writers must implement write.')
def write_record(self, writer_net, fields):
if isinstance(fields, Field):
self._schema = fields
fields = fields.field_blobs()
self.write(writer_net, fields)
def setup_ex(self, init_net, finish_net):
"""Experimental, don't use yet"""
self.commit(finish_net)
def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
"""Experimental extension to the interface. Don't use yet"""
write_net = core.Net('write_net')
self.write(write_net, fields)
return [write_net]
def write_record_ex(
self, fields, local_init_net, local_finish_net, stop_blob=None):
"""Experimental extension to the interface. Don't use yet."""
if isinstance(fields, Field):
self._schema = fields
fields = fields.field_blobs()
if stop_blob is None:
stop_blob = local_init_net.NextName("dequeue_status")
write_nets = self.write_ex(
fields, local_init_net, local_finish_net, stop_blob)
return (write_nets, stop_blob)
def commit(self, finish_net):
"""Add operations to `finish_net` that signal end of data.
This must be implemented by all Writers, but may be no-op for some
of them.
"""
pass
class ReaderBuilder(object):
""" Allow usage of a reader in distributed fashion. """
def schema(self):
raise NotImplementedError()
def enqueue_splits(self, net, split_queue):
raise NotImplementedError()
def splits(self, net):
raise NotImplementedError()
def new_reader(self, split_reader=None, **kwargs):
raise NotImplementedError()
class PipedReaderBuilder(ReaderBuilder):
"""
ReaderBuilder that modifies underlying builder by calling `piper`
function on each new reader produced, and return the result of
the function. This way, it is possible to append data processing
pipelines that will be replicated for each reader that gets created.
E.g.:
PipedReaderBuilder(
ReaderBuilder(...),
lambda reader: pipe(reader, processor=my_proc))
"""
def __init__(self, builder, piper):
self._builder = builder
self._piper = piper
def schema(self):
return self._builder.schema()
def enqueue_splits(self, net, split_queue):
return self._builder.enqueue_splits(net, split_queue)
def splits(self, net):
return self._builder.splits(net)
def new_reader(self, split_reader=None, **kwargs):
output = self._piper(self._builder.new_reader(split_reader), **kwargs)
return output if isinstance(output, Reader) else output.reader()
class Pipe(object):
def __init__(self, schema=None, obj_key=None):
self._num_writers = 0
self._num_readers = 0
self._schema = schema
self._obj_key = obj_key
def schema(self):
return self._schema
def setup(self, global_init_net):
pass
def reader(self):
raise NotImplementedError()
def writer(self):
raise NotImplementedError()
def num_readers(self):
return self._num_readers
def num_writers(self):
return self._num_writers
def _new_writer(self, writer_schema, writer_init_net):
if writer_schema is not None and self._schema is None:
self._schema = writer_schema
self._num_writers += 1
if self._obj_key is not None:
writer_init_net.add_attribute(self._obj_key, self)
def _new_reader(self, reader_init_net):
self._num_readers += 1
if self._obj_key is not None:
reader_init_net.add_attribute(self._obj_key, self)
class CounterReader(Reader):
""" Reader that produces increasing integers. """
def __init__(self):
Reader.__init__(self, schema=Struct(('iter', np.int64)))
self.counter = None
self.should_stop = None
def setup_ex(self, global_init_net, global_finish_net):
if self.counter is None:
self.counter = global_init_net.CreateCounter([], init_count=0)
self.should_stop = global_init_net.ConstantFill(
[], shape=[], dtype=core.DataType.BOOL, value=False)
def read_ex(self, local_init_net, local_finish_net):
count_net = core.Net('limited_reader_counter')
value = count_net.CountUp([self.counter], 1)
return [count_net], self.should_stop, [value]
class ReaderWithLimit(Reader):
"""
Reader that stops after `num_iter` calls.
If num_iter is None it becomes just a simple reader that exports a global
flag for "out of data".
"""
def __init__(self, reader, num_iter=1):
Reader.__init__(self, schema=reader._schema)
self.reader = reader
self.counter = None
self.num_iter = num_iter
net = core.Net('reader_with_limit')
self._data_finished = net.AddExternalInput(
net.NextName('data_finished'))
if self.num_iter is not None:
self.counter = net.AddExternalInput(net.NextName('counter'))
def setup_ex(self, global_init_net, global_finish_net):
if self.counter:
global_init_net.CreateCounter(
[], [self.counter], init_count=int(self.num_iter))
self.reader.setup_ex(global_init_net, global_finish_net)
global_init_net.ConstantFill(
[], [self._data_finished],
shape=[], value=False, dtype=core.DataType.BOOL)
def read_ex(self, local_init_net, local_finish_net):
""" 1. check if we reached number of iterations and populate the same
should_stop blob """
count_net = core.Net('limited_reader_counter')
if self.counter:
should_stop = count_net.CountDown([self.counter], 1)
else:
should_stop = count_net.ConstantFill(
[], 1,
shape=[], value=False, dtype=core.DataType.BOOL)
""" 2. call original reader """
nets, local_data_finished, fields = self.reader.read_ex(
local_init_net, local_finish_net)
self._set_schema(self.reader._schema)
""" 3. check if original reader is done. """
check_done_net = core.Net('limited_reader_post')
# copy to the same blob as the counter output to trigger reader
# stopping
check_done_net.Copy(local_data_finished, should_stop)
# update global flag that underlying reader is done
check_done_net.Or([self._data_finished, local_data_finished],
[self._data_finished])
# this relies on `should_stop` being called after each net.
return [count_net] + nets + [check_done_net], should_stop, fields
def data_finished(self):
"""
Return a blob that can be checked after the end of the reading task,
which will contain a scalar float indicating whether the underlying
reader has been exhausted (True) or whether we stopped because reached
the limit of iterations (False).
"""
return self._data_finished
def CountUntil(num_iter):
return ReaderWithLimit(CounterReader(), num_iter)