|  | #pragma once | 
|  |  | 
|  | #include <memory> | 
|  | #include "blobs_queue.h" | 
|  | #include "caffe2/core/operator.h" | 
|  | #include "caffe2/utils/math.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <typename Context> | 
|  | class CreateBlobsQueueOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | CreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws), | 
|  | ws_(ws), | 
|  | name(operator_def.output().Get(0)) {} | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | const auto capacity = GetSingleArgument("capacity", 1); | 
|  | const auto numBlobs = GetSingleArgument("num_blobs", 1); | 
|  | const auto enforceUniqueName = | 
|  | GetSingleArgument("enforce_unique_name", false); | 
|  | const auto fieldNames = | 
|  | OperatorBase::template GetRepeatedArgument<std::string>("field_names"); | 
|  | CAFFE_ENFORCE_EQ(this->OutputSize(), 1); | 
|  | auto queuePtr = Operator<Context>::Outputs()[0] | 
|  | ->template GetMutable<std::shared_ptr<BlobsQueue>>(); | 
|  | CAFFE_ENFORCE(queuePtr); | 
|  | *queuePtr = std::make_shared<BlobsQueue>( | 
|  | ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | Workspace* ws_{nullptr}; | 
|  | const std::string name; | 
|  | }; | 
|  |  | 
|  | template <typename Context> | 
|  | class EnqueueBlobsOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | using Operator<Context>::Operator; | 
|  | bool RunOnDevice() override { | 
|  | CAFFE_ENFORCE(InputSize() > 1); | 
|  | auto queue = Operator<Context>::Inputs()[0] | 
|  | ->template Get<std::shared_ptr<BlobsQueue>>(); | 
|  | CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs()); | 
|  | return queue->blockingWrite(this->Outputs()); | 
|  | } | 
|  |  | 
|  | private: | 
|  | }; | 
|  |  | 
|  | template <typename Context> | 
|  | class DequeueBlobsOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | DequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws) { | 
|  | timeout_secs_ = OperatorBase::GetSingleArgument<float>("timeout_secs", 0); | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | CAFFE_ENFORCE(InputSize() == 1); | 
|  | auto queue = | 
|  | OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>(); | 
|  | CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs()); | 
|  | return queue->blockingRead(this->Outputs(), timeout_secs_); | 
|  | } | 
|  |  | 
|  | private: | 
|  | float timeout_secs_; | 
|  | }; | 
|  |  | 
|  | template <typename Context> | 
|  | class CloseBlobsQueueOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | using Operator<Context>::Operator; | 
|  | bool RunOnDevice() override { | 
|  | CAFFE_ENFORCE_EQ(InputSize(), 1); | 
|  | auto queue = | 
|  | OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>(); | 
|  | CAFFE_ENFORCE(queue); | 
|  | queue->close(); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | }; | 
|  |  | 
|  | template <typename Context> | 
|  | class SafeEnqueueBlobsOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | using Operator<Context>::Operator; | 
|  | bool RunOnDevice() override { | 
|  | auto queue = Operator<Context>::Inputs()[0] | 
|  | ->template Get<std::shared_ptr<BlobsQueue>>(); | 
|  | CAFFE_ENFORCE(queue); | 
|  | auto size = queue->getNumBlobs(); | 
|  | CAFFE_ENFORCE( | 
|  | OutputSize() == size + 1, | 
|  | "Expected " + c10::to_string(size + 1) + ", " + | 
|  | " got: " + c10::to_string(size)); | 
|  | bool status = queue->blockingWrite(this->Outputs()); | 
|  | Output(size)->Resize(); | 
|  | math::Set<bool, Context>( | 
|  | 1, !status, Output(size)->template mutable_data<bool>(), &context_); | 
|  | return true; | 
|  | } | 
|  | }; | 
|  |  | 
|  | template <typename Context> | 
|  | class SafeDequeueBlobsOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | using Operator<Context>::Operator; | 
|  |  | 
|  | SafeDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws), | 
|  | numRecords_(OperatorBase::GetSingleArgument<int>("num_records", 1)) { | 
|  | CAFFE_ENFORCE_GT(numRecords_, 0); | 
|  | } | 
|  |  | 
|  | bool dequeueMany(std::shared_ptr<BlobsQueue>& queue) { | 
|  | auto size = queue->getNumBlobs(); | 
|  |  | 
|  | if (blobs_.size() != size) { | 
|  | blobs_.resize(size); | 
|  | blobPtrs_.resize(size); | 
|  | for (int col = 0; col < size; ++col) { | 
|  | blobPtrs_.at(col) = &blobs_.at(col); | 
|  | } | 
|  | } | 
|  |  | 
|  | const int kTensorGrowthPct = 40; | 
|  | for (int i = 0; i < numRecords_; ++i) { | 
|  | if (!queue->blockingRead(blobPtrs_)) { | 
|  | // if we read at least one record, status is still true | 
|  | return i > 0; | 
|  | } | 
|  | for (int col = 0; col < size; ++col) { | 
|  | auto* out = this->Output(col); | 
|  | const auto& in = blobPtrs_.at(col)->template Get<Tensor>(); | 
|  | if (i == 0) { | 
|  | out->CopyFrom(in); | 
|  | } else { | 
|  | auto oldSize = out->numel(); | 
|  |  | 
|  | CAFFE_ENFORCE( | 
|  | in.dim() > 0, | 
|  | "Empty tensor to dequeue at column ", | 
|  | col, | 
|  | " within ", | 
|  | size, | 
|  | " total columns"); | 
|  |  | 
|  | out->Extend(in.sizes()[0], kTensorGrowthPct); | 
|  | auto* dst = | 
|  | (char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize(); | 
|  | context_.template CopyItems<Context, Context>( | 
|  | in.meta(), in.numel(), in.raw_data(), dst); | 
|  | } | 
|  | } | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool dequeueOne(std::shared_ptr<BlobsQueue>& queue) { | 
|  | return queue->blockingRead(this->Outputs()); | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | CAFFE_ENFORCE(InputSize() == 1); | 
|  | auto queue = Operator<Context>::Inputs()[0] | 
|  | ->template Get<std::shared_ptr<BlobsQueue>>(); | 
|  | CAFFE_ENFORCE(queue); | 
|  |  | 
|  | auto size = queue->getNumBlobs(); | 
|  | CAFFE_ENFORCE_EQ(OutputSize(), size + 1); | 
|  |  | 
|  | bool status = numRecords_ > 1 ? dequeueMany(queue) : dequeueOne(queue); | 
|  |  | 
|  | Output(size)->Resize(); | 
|  | math::Set<bool, Context>( | 
|  | 1, !status, Output(size)->template mutable_data<bool>(), &context_); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | int numRecords_; | 
|  | std::vector<Blob> blobs_; | 
|  | std::vector<Blob*> blobPtrs_; | 
|  | }; | 
|  |  | 
|  | template <typename Context> | 
|  | class WeightedSampleDequeueBlobsOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | WeightedSampleDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws), | 
|  | table_idx_blob_( | 
|  | OperatorBase::GetSingleArgument<int>("table_idx_blob", -1)) { | 
|  | CAFFE_ENFORCE_LT(table_idx_blob_, OutputSize() - 1); | 
|  | vector<float> weights = OperatorBase::GetRepeatedArgument<float>("weights"); | 
|  | if (weights.empty()) { | 
|  | weights.resize(InputSize(), 1.0f); | 
|  | } | 
|  | CAFFE_ENFORCE_EQ(InputSize(), weights.size()); | 
|  |  | 
|  | float sum = accumulate(weights.begin(), weights.end(), 0.0f); | 
|  | CAFFE_ENFORCE(sum > 0.0f, "Sum of weights must be positive"); | 
|  | cumProbs_.resize(weights.size()); | 
|  | for (int i = 0; i < weights.size(); i++) { | 
|  | cumProbs_[i] = weights[i] / sum; | 
|  | CAFFE_ENFORCE_GE( | 
|  | cumProbs_[i], 0.0f, "Each probability must be non-negative"); | 
|  | } | 
|  | std::partial_sum(cumProbs_.begin(), cumProbs_.end(), cumProbs_.begin()); | 
|  | // Put last value to be 1.0001 to avoid numerical issues. | 
|  | cumProbs_.back() = 1.0001f; | 
|  |  | 
|  | LOG(INFO) << "Dequeue weights: " << weights; | 
|  | LOG(INFO) << "cumProbs: " << cumProbs_; | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | float r; | 
|  | math::RandUniform<float, Context>(1, 0.0f, 1.0f, &r, &context_); | 
|  | auto lb = lower_bound(cumProbs_.begin(), cumProbs_.end(), r); | 
|  | CAFFE_ENFORCE(lb != cumProbs_.end(), "Cannot find ", r, " in cumProbs_."); | 
|  | const int32_t idx = lb - cumProbs_.begin(); | 
|  | auto queue = Operator<Context>::Inputs()[idx] | 
|  | ->template Get<std::shared_ptr<BlobsQueue>>(); | 
|  |  | 
|  | CAFFE_ENFORCE(queue); | 
|  | auto size = queue->getNumBlobs(); | 
|  | CAFFE_ENFORCE_EQ(OutputSize(), size + 1); | 
|  | bool status = queue->blockingRead(this->Outputs()); | 
|  | if (table_idx_blob_ >= 0) { | 
|  | auto* table_idx_blob_out = | 
|  | Output(table_idx_blob_, {1}, at::dtype<int32_t>()); | 
|  | int32_t* data = table_idx_blob_out->template mutable_data<int32_t>(); | 
|  | data[0] = idx; | 
|  | } | 
|  |  | 
|  | Output(size)->Resize(); | 
|  | math::Set<bool, Context>( | 
|  | 1, !status, Output(size)->template mutable_data<bool>(), &context_); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | vector<float> cumProbs_; | 
|  | int table_idx_blob_; | 
|  | }; | 
|  | } // namespace caffe2 |