blob: a703358d0311870cef8472b23da7c8efb6913f26 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, dataio
class QueueReader(dataio.Reader):
    """Reader that pulls records out of a blobs queue.

    The number of blobs fetched per read is taken from `schema` when one is
    given (one blob per schema field name), otherwise from `num_blobs`.
    """

    def __init__(self, queue, num_blobs=None, schema=None):
        """Store the queue and work out how many blobs each read yields.

        Args:
            queue: queue blob to read from.
            num_blobs: number of blobs per record; may be omitted when
                `schema` is given, and must agree with it when both are set.
            schema: optional schema, forwarded to dataio.Reader.__init__;
                its field_names() determine the blob count.

        Raises:
            AssertionError: if neither `schema` nor `num_blobs` is provided,
                or if `num_blobs` disagrees with the schema's field count.
        """
        dataio.Reader.__init__(self, schema)
        assert not (schema is None and num_blobs is None), (
            'Either schema or num_blobs must be provided.')
        self.queue = queue
        self.num_blobs = num_blobs
        if schema is None:
            return
        # With a schema, the field count is authoritative.
        field_count = len(schema.field_names())
        assert num_blobs is None or num_blobs == field_count
        self.num_blobs = field_count

    def setup_ex(self, init_net, exit_net):
        """Add a CloseBlobsQueue operator for this queue to `exit_net`.

        `init_net` is not used here.
        """
        exit_net.CloseBlobsQueue([self.queue], 0)

    def read_ex(self, local_init_net, local_finish_net):
        """Build a fresh net that dequeues one record from the queue.

        Neither `local_init_net` nor `local_finish_net` is used here.

        Returns:
            A 3-tuple: a one-element list holding the new 'dequeue_net',
            the status blob, and the list of dequeued field blobs.
        """
        reading_net = core.Net('dequeue_net')
        record_blobs, status = dequeue(
            reading_net, self.queue, self.num_blobs)
        return [reading_net], status, record_blobs
class QueueWriter(dataio.Writer):
    """Writer that pushes records into a blobs queue."""

    def __init__(self, queue):
        """Remember the queue blob that records are written to."""
        self.queue = queue

    def setup_ex(self, init_net, exit_net):
        """Add a CloseBlobsQueue operator for this queue to `exit_net`.

        `init_net` is not used here.
        """
        exit_net.CloseBlobsQueue([self.queue], 0)

    def write_ex(self, fields, local_init_net, local_finish_net, status):
        """Build a fresh net that enqueues `fields` into the queue.

        Neither `local_init_net` nor `local_finish_net` is used here.

        Args:
            fields: blobs making up the record to enqueue.
            status: status blob handed through to `enqueue`.

        Returns:
            A one-element list holding the new 'enqueue_net'.
        """
        writing_net = core.Net('enqueue_net')
        enqueue(writing_net, self.queue, fields, status)
        return [writing_net]
class QueueWrapper(object):
    """Bundle a blobs queue with the schema of the records it carries."""

    def __init__(self, init_net, capacity, schema):
        """Create the queue on `init_net`.

        Args:
            init_net: net on which the CreateBlobsQueue operator is added.
            capacity: forwarded as the operator's `capacity` argument.
            schema: record schema; the number of its field names is
                forwarded as the operator's `num_blobs` argument.
        """
        field_count = len(schema.field_names())
        self._queue = init_net.CreateBlobsQueue(
            [], capacity=capacity, num_blobs=field_count)
        self._schema = schema

    def reader(self):
        """Return a new QueueReader over this queue, using its schema."""
        return QueueReader(self._queue, schema=self._schema)

    def writer(self):
        """Return a new QueueWriter over this queue."""
        return QueueWriter(self._queue)

    def queue(self):
        """Return the queue blob created in the constructor."""
        return self._queue

    def schema(self):
        """Return the schema passed to the constructor."""
        return self._schema
def enqueue(net, queue, data_blobs, status=None):
    """Add a SafeEnqueueBlobs operator to `net` and return its status output.

    Args:
        net: net the operator is added to; must provide `NextName` and
            `SafeEnqueueBlobs`.
        queue: queue blob to enqueue into.
        data_blobs: blobs to enqueue. Any iterable is accepted (list, tuple,
            generator); it is copied into a list before use.
        status: optional blob that receives the operator's status. When
            None, a name based on the queue is requested from `net.NextName`.

    Returns:
        The last element of the operator call's result; `status` is passed
        as the last requested output.
    """
    if status is None:
        status = net.NextName("%s_enqueue_status" % str(queue))
    # Materialize once so non-list sequences can be concatenated with lists
    # below (a tuple here used to raise TypeError).
    blobs = list(data_blobs)
    results = net.SafeEnqueueBlobs([queue] + blobs, blobs + [status])
    return results[-1]
def dequeue(net, queue, num_blobs, status=None):
    """Add a SafeDequeueBlobs operator to `net`.

    Output blob names are derived from the queue name, mirroring how
    `enqueue` names its status blob. (Previously the "%s" placeholder was
    never substituted and ended up verbatim in the generated names.)

    Args:
        net: net the operator is added to; must provide `NextName` and
            `SafeDequeueBlobs`.
        queue: queue blob to dequeue from.
        num_blobs: number of data blobs to request from the queue.
        status: optional blob that receives the operator's status. When
            None, a name based on the queue is requested from `net.NextName`.

    Returns:
        A pair `(data_blobs, status_blob)`: the list of all results except
        the last, and the last result (the status is requested last).
    """
    queue_name = str(queue)
    data_prefix = "%s_dequeue_data" % queue_name
    # The index is handed to NextName as its second argument, as before.
    data_names = [net.NextName(data_prefix, i) for i in range(num_blobs)]
    if status is None:
        status = net.NextName("%s_dequeue_status" % queue_name)
    results = list(net.SafeDequeueBlobs(queue, data_names + [status]))
    status_blob = results.pop(-1)
    return results, status_blob
def close_queue(step, *queues):
    """Wrap `step` in an execution step that also closes the given queues.

    A new net named "close_queue_net" receives one CloseBlobsQueue operator
    per queue. That net becomes its own execution step, which is then
    combined with `step` (listed first) into the returned wrapper step.
    NOTE(review): presumably the sub-steps run in the order listed, so the
    queues close after `step` — confirm against core.execution_step.

    Args:
        step: execution step to wrap.
        *queues: queue blobs to close.

    Returns:
        The wrapper execution step built from `[step, <close step>]`.
    """
    closing_net = core.Net("close_queue_net")
    for q in queues:
        closing_net.CloseBlobsQueue([q], 0)
    net_label = str(closing_net)
    closing_step = core.execution_step("%s_step" % net_label, closing_net)
    wrapper_name = "%s_wraper_step" % net_label
    return core.execution_step(wrapper_name, [step, closing_step])