blob: 26f0859eef028a05509e52b6c2733dbfb19ba462 [file] [log] [blame]
#pragma once
#include <memory>
#include "blobs_queue.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
template <typename Context>
class CreateBlobsQueueOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
CreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws), ws_(ws) {}
bool RunOnDevice() override {
const auto capacity =
OperatorBase::template GetSingleArgument<int>("capacity", 1);
const auto numBlobs =
OperatorBase::template GetSingleArgument<int>("num_blobs", 1);
const auto enforceUniqueName =
OperatorBase::template GetSingleArgument<int>(
"enforce_unique_name", false);
CAFFE_ENFORCE_EQ(def().output().size(), 1);
const auto name = def().output().Get(0);
auto queuePtr = Operator<Context>::Outputs()[0]
->template GetMutable<std::shared_ptr<BlobsQueue>>();
CAFFE_ENFORCE(queuePtr);
*queuePtr = std::make_shared<BlobsQueue>(
ws_, name, capacity, numBlobs, enforceUniqueName);
return true;
}
private:
Workspace* ws_{nullptr};
};
template <typename Context>
class EnqueueBlobsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
using Operator<Context>::Operator;
bool RunOnDevice() override {
CAFFE_ENFORCE(InputSize() > 1);
auto queue = Operator<Context>::Inputs()[0]
->template Get<std::shared_ptr<BlobsQueue>>();
CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
return queue->blockingWrite(this->Outputs());
}
private:
};
template <typename Context>
class DequeueBlobsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
using Operator<Context>::Operator;
bool RunOnDevice() override {
CAFFE_ENFORCE(InputSize() == 1);
auto queue =
OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
return queue->blockingRead(this->Outputs());
}
private:
};
template <typename Context>
class CloseBlobsQueueOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
using Operator<Context>::Operator;
bool RunOnDevice() override {
CAFFE_ENFORCE_EQ(InputSize(), 1);
auto queue =
OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
CHECK(queue);
queue->close();
return true;
}
private:
};
template <typename Context>
class SafeEnqueueBlobsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
using Operator<Context>::Operator;
bool RunOnDevice() override {
auto queue = Operator<Context>::Inputs()[0]
->template Get<std::shared_ptr<BlobsQueue>>();
CAFFE_ENFORCE(queue);
auto size = queue->getNumBlobs();
CAFFE_ENFORCE(
OutputSize() == size + 1,
"Expected " + caffe2::to_string(size + 1) + ", " + " got: " +
caffe2::to_string(size));
bool status = queue->blockingWrite(this->Outputs());
Output(size)->Resize();
*Output(size)->template mutable_data<bool>() = !status;
return true;
}
};
template <typename Context>
class SafeDequeueBlobsOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
using Operator<Context>::Operator;
bool RunOnDevice() override {
CAFFE_ENFORCE(InputSize() == 1);
auto queue = Operator<Context>::Inputs()[0]
->template Get<std::shared_ptr<BlobsQueue>>();
CAFFE_ENFORCE(queue);
auto size = queue->getNumBlobs();
CAFFE_ENFORCE(
OutputSize() == size + 1,
"Expected " + caffe2::to_string(size + 1) + ", " + " got: " +
caffe2::to_string(size));
bool status = queue->blockingRead(this->Outputs());
Output(size)->Resize();
*Output(size)->template mutable_data<bool>() = !status;
return true;
}
private:
};
}