blob: 767806e448438165339150aca9c9325581b3b16d [file] [log] [blame]
#pragma once
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include "caffe2/core/blob_stats.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/stats.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/workspace.h"
namespace caffe2 {
// A thread-safe, bounded, blocking queue.
// Modelled as a circular buffer.
// Containing blobs are owned by the workspace.
// On read, we swap out the underlying data for the blob passed in for blobs
class BlobsQueue : public std::enable_shared_from_this<BlobsQueue> {
public:
BlobsQueue(
Workspace* ws,
const std::string& queueName,
size_t capacity,
size_t numBlobs,
bool enforceUniqueName,
const std::vector<std::string>& fieldNames = {})
: numBlobs_(numBlobs), stats_(queueName) {
if (!fieldNames.empty()) {
CAFFE_ENFORCE_EQ(
fieldNames.size(), numBlobs, "Wrong number of fieldNames provided.");
stats_.queue_dequeued_bytes.setDetails(fieldNames);
}
queue_.reserve(capacity);
for (auto i = 0; i < capacity; ++i) {
std::vector<Blob*> blobs;
blobs.reserve(numBlobs);
for (auto j = 0; j < numBlobs; ++j) {
const auto blobName =
queueName + "_" + to_string(i) + "_" + to_string(j);
if (enforceUniqueName) {
CAFFE_ENFORCE(
!ws->GetBlob(blobName),
"Queue internal blob already exists: ",
blobName);
}
blobs.push_back(ws->CreateBlob(blobName));
}
queue_.push_back(blobs);
}
DCHECK_EQ(queue_.size(), capacity);
}
~BlobsQueue() {
close();
}
bool blockingRead(const std::vector<Blob*>& inputs) {
auto keeper = this->shared_from_this();
std::unique_lock<std::mutex> g(mutex_);
auto canRead = [this]() {
CAFFE_ENFORCE_LE(reader_, writer_);
return reader_ != writer_;
};
CAFFE_EVENT(stats_, queue_balance, -1);
cv_.wait(g, [this, canRead]() { return closing_ || canRead(); });
if (!canRead()) {
return false;
}
DCHECK(canRead());
auto& result = queue_[reader_ % queue_.size()];
CAFFE_ENFORCE(inputs.size() >= result.size());
for (auto i = 0; i < result.size(); ++i) {
auto bytes = BlobStat::sizeBytes(*result[i]);
CAFFE_EVENT(stats_, queue_dequeued_bytes, bytes, i);
using std::swap;
swap(*(inputs[i]), *(result[i]));
}
CAFFE_EVENT(stats_, queue_dequeued_records);
++reader_;
cv_.notify_all();
return true;
}
bool tryWrite(const std::vector<Blob*>& inputs) {
auto keeper = this->shared_from_this();
std::unique_lock<std::mutex> g(mutex_);
if (!canWrite()) {
return false;
}
CAFFE_EVENT(stats_, queue_balance, 1);
DCHECK(canWrite());
doWrite(inputs);
return true;
}
bool blockingWrite(const std::vector<Blob*>& inputs) {
auto keeper = this->shared_from_this();
std::unique_lock<std::mutex> g(mutex_);
CAFFE_EVENT(stats_, queue_balance, 1);
cv_.wait(g, [this]() { return closing_ || canWrite(); });
if (!canWrite()) {
return false;
}
DCHECK(canWrite());
doWrite(inputs);
return true;
}
void close() {
closing_ = true;
std::lock_guard<std::mutex> g(mutex_);
cv_.notify_all();
}
size_t getNumBlobs() const {
return numBlobs_;
}
private:
bool canWrite() {
// writer is always within [reader, reader + size)
// we can write if reader is within [reader, reader + size)
CAFFE_ENFORCE_LE(reader_, writer_);
CAFFE_ENFORCE_LE(writer_, reader_ + queue_.size());
return writer_ != reader_ + queue_.size();
}
void doWrite(const std::vector<Blob*>& inputs) {
auto& result = queue_[writer_ % queue_.size()];
CAFFE_ENFORCE(inputs.size() >= result.size());
for (auto i = 0; i < result.size(); ++i) {
using std::swap;
swap(*(inputs[i]), *(result[i]));
}
++writer_;
cv_.notify_all();
}
std::atomic<bool> closing_{false};
size_t numBlobs_;
std::mutex mutex_; // protects all variables in the class.
std::condition_variable cv_;
int64_t reader_{0};
int64_t writer_{0};
std::vector<std::vector<Blob*>> queue_;
struct QueueStats {
CAFFE_STAT_CTOR(QueueStats);
CAFFE_EXPORTED_STAT(queue_balance);
CAFFE_EXPORTED_STAT(queue_dequeued_records);
CAFFE_DETAILED_EXPORTED_STAT(queue_dequeued_bytes);
} stats_;
};
}