#pragma once

#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "blobs_queue.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

#include <c10/util/irange.h>

namespace caffe2 {

/**
 * Builds a BlobsQueue and stores it, wrapped in a std::shared_ptr, in the
 * operator's single output blob.
 *
 * Arguments read on every run:
 *   - "capacity"            (default 1)
 *   - "num_blobs"           (default 1)
 *   - "enforce_unique_name" (default false)
 *   - "field_names"         (repeated string)
 * All of them are forwarded to the BlobsQueue constructor together with the
 * workspace and the queue name.
 */
template <typename Context>
class CreateBlobsQueueOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // The queue name is taken from the first output listed in the operator def.
  CreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        ws_(ws),
        name_(operator_def.output().Get(0)) {}

  // Always returns true; any failed enforcement throws instead.
  bool RunOnDevice() override {
    const auto queueCapacity = GetSingleArgument("capacity", 1);
    const auto blobsPerRecord = GetSingleArgument("num_blobs", 1);
    const auto requireUniqueName =
        GetSingleArgument("enforce_unique_name", false);
    const auto columnNames =
        OperatorBase::template GetRepeatedArgument<std::string>("field_names");

    // Exactly one output blob is expected: the one that will hold the queue.
    CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
    auto queuePtr =
        this->Outputs()[0]->template GetMutable<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(queuePtr);

    *queuePtr = std::make_shared<BlobsQueue>(
        ws_,
        name_,
        queueCapacity,
        blobsPerRecord,
        requireUniqueName,
        columnNames);
    return true;
  }

 private:
  // Workspace handed to the BlobsQueue constructor.
  Workspace* ws_{nullptr};
  // Name of the first output blob, used as the queue name.
  const std::string name_;
};

/**
 * Pushes one record into a BlobsQueue.
 *
 * Input 0 must hold a std::shared_ptr<BlobsQueue>. The operator's output
 * blobs are what is handed to BlobsQueue::blockingWrite, and their count must
 * equal the queue's getNumBlobs().
 */
template <typename Context>
class EnqueueBlobsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using Operator<Context>::Operator;

  // Returns the result of blockingWrite; failed enforcements throw.
  bool RunOnDevice() override {
    CAFFE_ENFORCE(InputSize() > 1);

    const auto* queueBlob = this->Inputs()[0];
    // Copy the shared_ptr so the queue stays alive for the whole call.
    auto queue = queueBlob->template Get<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(
        queue && static_cast<size_t>(OutputSize()) == queue->getNumBlobs());

    const bool written = queue->blockingWrite(this->Outputs());
    return written;
  }
};

/**
 * Pops one record from a BlobsQueue into the operator's outputs.
 *
 * The single input must hold a std::shared_ptr<BlobsQueue>, and the number of
 * outputs must equal the queue's getNumBlobs(). The "timeout_secs" argument
 * (default 0) is forwarded unchanged to BlobsQueue::blockingRead.
 */
template <typename Context>
class DequeueBlobsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  DequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        timeout_secs_(
            OperatorBase::GetSingleArgument<float>("timeout_secs", 0)) {}

  // Returns the result of blockingRead; failed enforcements throw.
  bool RunOnDevice() override {
    CAFFE_ENFORCE(InputSize() == 1);

    const auto* queueBlob = this->Inputs()[0];
    // Copy the shared_ptr so the queue stays alive for the whole call.
    auto queue = queueBlob->template Get<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(
        queue && static_cast<size_t>(OutputSize()) == queue->getNumBlobs());

    const bool readOk = queue->blockingRead(this->Outputs(), timeout_secs_);
    return readOk;
  }

 private:
  // Value of the "timeout_secs" argument, passed to blockingRead.
  float timeout_secs_;
};

/**
 * Closes the BlobsQueue held by the operator's single input.
 *
 * The input blob must contain a non-null std::shared_ptr<BlobsQueue>.
 */
template <typename Context>
class CloseBlobsQueueOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using Operator<Context>::Operator;

  // Always returns true; failed enforcements throw.
  bool RunOnDevice() override {
    CAFFE_ENFORCE_EQ(InputSize(), 1);

    const auto* queueBlob = this->Inputs()[0];
    auto queue = queueBlob->template Get<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(queue);

    queue->close();
    return true;
  }
};

/**
 * Enqueue operator that reports failure through an extra output instead of
 * through its return value.
 *
 * Input 0 must hold a non-null std::shared_ptr<BlobsQueue>. The operator must
 * have getNumBlobs() + 1 outputs: all outputs are handed to
 * BlobsQueue::blockingWrite, and the last one (index getNumBlobs()) is then
 * overwritten with a single bool that is true when blockingWrite returned
 * false.
 */
template <typename Context>
class SafeEnqueueBlobsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using Operator<Context>::Operator;

  // Always returns true; the write status is delivered via the last output.
  bool RunOnDevice() override {
    auto queue = Operator<Context>::Inputs()[0]
                     ->template Get<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(queue);
    auto size = queue->getNumBlobs();
    // Report the actual output count, and let the enforce macro build the
    // message only when the check fails.
    CAFFE_ENFORCE(
        static_cast<size_t>(OutputSize()) == size + 1,
        "Expected ",
        size + 1,
        ", got: ",
        OutputSize());
    bool status = queue->blockingWrite(this->Outputs());
    // Resize with no dimensions, then store the inverted status flag.
    Output(size)->Resize();
    math::Set<bool, Context>(
        1, !status, Output(size)->template mutable_data<bool>(), &context_);
    return true;
  }

  // Closes the queue on cancellation. A null queue pointer is ignored rather
  // than dereferenced.
  void Cancel() override {
    auto queue = Operator<Context>::Inputs()[0]
                     ->template Get<std::shared_ptr<BlobsQueue>>();
    if (queue) {
      queue->close();
    }
  }
};

/**
 * Dequeue operator that reports failure through an extra output instead of
 * through its return value.
 *
 * The single input must hold a non-null std::shared_ptr<BlobsQueue>, and the
 * operator must have getNumBlobs() + 1 outputs. The last output (index
 * getNumBlobs()) receives a single bool that is true when the dequeue failed.
 *
 * The "num_records" argument (default 1, must be > 0) selects the mode:
 *   - 1: one blockingRead straight into the outputs;
 *   - more than 1: up to that many records are read and concatenated along
 *     the first dimension of each output tensor.
 */
template <typename Context>
class SafeDequeueBlobsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  using Operator<Context>::Operator;

  SafeDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        numRecords_(OperatorBase::GetSingleArgument<int>("num_records", 1)) {
    CAFFE_ENFORCE_GT(numRecords_, 0);
  }

  // Reads up to numRecords_ records into scratch blobs and appends each one
  // to the outputs. Returns false only if the very first read fails.
  bool dequeueMany(std::shared_ptr<BlobsQueue>& queue) {
    auto size = queue->getNumBlobs();

    // (Re)build the scratch blobs and the pointer view over them whenever the
    // queue's column count differs from what is cached.
    if (blobs_.size() != size) {
      blobs_.resize(size);
      blobPtrs_.resize(size);
      for (auto col : c10::irange(size)) {
        blobPtrs_.at(col) = &blobs_.at(col);
      }
    }

    // Passed to Tensor::Extend as its growth-percentage argument.
    const int kTensorGrowthPct = 40;
    for (const auto i : c10::irange(numRecords_)) {
      if (!queue->blockingRead(blobPtrs_)) {
        // if we read at least one record, status is still true
        return i > 0;
      }
      for (auto col : c10::irange(size)) {
        auto* out = this->Output(col);
        const auto& in = blobPtrs_.at(col)->template Get<Tensor>();
        if (i == 0) {
          // First record: the output becomes a copy of it.
          out->CopyFrom(in);
        } else {
          // Element count before growing; used as the append offset.
          auto oldSize = out->numel();

          CAFFE_ENFORCE(
              in.dim() > 0,
              "Empty tensor to dequeue at column ",
              col,
              " within ",
              size,
              " total columns");

          out->Extend(in.sizes()[0], kTensorGrowthPct);
          // Byte address just past the previously stored elements.
          auto* dst = static_cast<char*>(out->raw_mutable_data()) +
              oldSize * in.dtype().itemsize();
          context_.template CopyItems<Context, Context>(
              in.meta(), in.numel(), in.raw_data(), dst);
        }
      }
    }
    return true;
  }

  // Reads a single record directly into the operator's outputs.
  bool dequeueOne(std::shared_ptr<BlobsQueue>& queue) {
    return queue->blockingRead(this->Outputs());
  }

  // Always returns true; the read status is delivered via the last output.
  bool RunOnDevice() override {
    CAFFE_ENFORCE(InputSize() == 1);
    auto queue = Operator<Context>::Inputs()[0]
                     ->template Get<std::shared_ptr<BlobsQueue>>();
    CAFFE_ENFORCE(queue);

    auto size = queue->getNumBlobs();
    // Cast keeps the comparison unsigned on both sides.
    CAFFE_ENFORCE_EQ(static_cast<size_t>(OutputSize()), size + 1);

    bool status = numRecords_ > 1 ? dequeueMany(queue) : dequeueOne(queue);

    // Resize with no dimensions, then store the inverted status flag.
    Output(size)->Resize();
    math::Set<bool, Context>(
        1, !status, Output(size)->template mutable_data<bool>(), &context_);
    return true;
  }

  // Closes the queue on cancellation. A null queue pointer is ignored rather
  // than dereferenced.
  void Cancel() override {
    auto queue = Operator<Context>::Inputs()[0]
                     ->template Get<std::shared_ptr<BlobsQueue>>();
    if (queue) {
      queue->close();
    }
  }

 private:
  // Value of the "num_records" argument; enforced to be positive.
  int numRecords_;
  // Scratch blobs that receive each record in dequeueMany.
  std::vector<Blob> blobs_;
  // Pointers into blobs_, in the form blockingRead expects.
  std::vector<Blob*> blobPtrs_;
};

/**
 * Dequeues one record from a randomly chosen input queue.
 *
 * Every input must hold a std::shared_ptr<BlobsQueue>. The repeated "weights"
 * argument gives one non-negative weight per input; when it is empty every
 * input gets weight 1. The weights are normalised and accumulated into
 * cumProbs_, and each run draws a uniform number in [0, 1] to pick a queue.
 *
 * The operator must have getNumBlobs() + 1 outputs for the chosen queue. The
 * last output receives a single bool that is true when the read failed. If
 * the "table_idx_blob" argument (default -1) is non-negative, the output at
 * that index is overwritten with the index of the chosen queue.
 */
template <typename Context>
class WeightedSampleDequeueBlobsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  WeightedSampleDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        table_idx_blob_(
            OperatorBase::GetSingleArgument<int>("table_idx_blob", -1)) {
    // The index output, if requested, must not be the trailing status output.
    CAFFE_ENFORCE_LT(table_idx_blob_, OutputSize() - 1);
    std::vector<float> weights =
        OperatorBase::GetRepeatedArgument<float>("weights");
    if (weights.empty()) {
      weights.resize(InputSize(), 1.0f);
    }
    CAFFE_ENFORCE_EQ(static_cast<size_t>(InputSize()), weights.size());

    const float sum = std::accumulate(weights.begin(), weights.end(), 0.0f);
    CAFFE_ENFORCE(sum > 0.0f, "Sum of weights must be positive");
    cumProbs_.resize(weights.size());
    for (auto i : c10::irange(weights.size())) {
      cumProbs_[i] = weights[i] / sum;
      CAFFE_ENFORCE_GE(
          cumProbs_[i], 0.0f, "Each probability must be non-negative");
    }
    std::partial_sum(cumProbs_.begin(), cumProbs_.end(), cumProbs_.begin());
    // Put last value to be 1.0001 to avoid numerical issues.
    cumProbs_.back() = 1.0001f;

    LOG(INFO) << "Dequeue weights: " << weights;
    LOG(INFO) << "cumProbs: " << cumProbs_;
  }

  // Always returns true; the read status is delivered via the last output.
  bool RunOnDevice() override {
    float r;
    math::RandUniform<float, Context>(1, 0.0f, 1.0f, &r, &context_);
    // First cumulative probability that is >= r selects the queue.
    auto lb = std::lower_bound(cumProbs_.begin(), cumProbs_.end(), r);
    CAFFE_ENFORCE(lb != cumProbs_.end(), "Cannot find ", r, " in cumProbs_.");
    const int32_t idx = static_cast<int32_t>(lb - cumProbs_.begin());
    auto queue = Operator<Context>::Inputs()[idx]
                     ->template Get<std::shared_ptr<BlobsQueue>>();

    CAFFE_ENFORCE(queue);
    auto size = queue->getNumBlobs();
    // Cast keeps the comparison unsigned on both sides.
    CAFFE_ENFORCE_EQ(static_cast<size_t>(OutputSize()), size + 1);
    bool status = queue->blockingRead(this->Outputs());
    if (table_idx_blob_ >= 0) {
      // Written after the read, so it replaces whatever was dequeued there.
      auto* table_idx_blob_out =
          Output(table_idx_blob_, {1}, at::dtype<int32_t>());
      int32_t* data = table_idx_blob_out->template mutable_data<int32_t>();
      data[0] = idx;
    }

    // Resize with no dimensions, then store the inverted status flag.
    Output(size)->Resize();
    math::Set<bool, Context>(
        1, !status, Output(size)->template mutable_data<bool>(), &context_);
    return true;
  }

 private:
  // Cumulative normalised weights, one entry per input queue.
  std::vector<float> cumProbs_;
  // Output index that receives the chosen queue index, or negative if unused.
  int table_idx_blob_;
};
} // namespace caffe2
