|  | #pragma once | 
|  |  | 
|  | #include "caffe2/core/operator.h" | 
|  | #include "c10/util/irange.h" | 
|  |  | 
|  | #include <cmath> | 
|  | #include <limits> | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <typename Context> | 
|  | class QuantileOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | QuantileOp(const OperatorDef& operator_def, Workspace* ws) | 
|  | : Operator<Context>(operator_def, ws), | 
|  | quantile_(this->template GetSingleArgument<float>("quantile", -1.0)), | 
|  | abs_(this->template GetSingleArgument<bool>("abs", true)), | 
|  | tol_(this->template GetSingleArgument<float>("tol", 1e-3)) { | 
|  | CAFFE_ENFORCE_GE( | 
|  | quantile_, | 
|  | 0, | 
|  | "input quantile should be ", | 
|  | "no less than 0, got ", | 
|  | quantile_); | 
|  | CAFFE_ENFORCE_GE( | 
|  | 1.0f, | 
|  | quantile_, | 
|  | "input quantile should be ", | 
|  | "no larger than 1, got ", | 
|  | quantile_); | 
|  | CAFFE_ENFORCE_GT( | 
|  | tol_, 0, "tolerance should be ", "no less than 0, got ", tol_); | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0)); | 
|  | } | 
|  |  | 
|  | template <typename T> | 
|  | bool DoRunWithType() { | 
|  | Output(QUANTILE_VAL)->Resize(1); | 
|  | auto* quantile_val = Output(QUANTILE_VAL)->template mutable_data<T>(); | 
|  |  | 
|  | auto& input_zero = Input(0); | 
|  | int64_t numel = input_zero.numel(); | 
|  | for (const auto i : c10::irange(1, InputSize())) { | 
|  | CAFFE_ENFORCE_EQ( | 
|  | Input(i).dtype(), | 
|  | input_zero.dtype(), | 
|  | "All inputs must have the same type, expected: ", | 
|  | input_zero.dtype().name(), | 
|  | " but got: ", | 
|  | Input(i).dtype().name(), | 
|  | " for input: ", | 
|  | i); | 
|  | numel += Input(i).numel(); | 
|  | } | 
|  | CAFFE_ENFORCE_GT( | 
|  | numel, | 
|  | 0, | 
|  | "number of total element in input tensor should be ", | 
|  | "larger than 0, got ", | 
|  | numel); | 
|  |  | 
|  | // the expected number of elements lessEq to the target value | 
|  | const int64_t target_cnt = | 
|  | static_cast<int64_t>(std::ceil(numel * quantile_)); | 
|  |  | 
|  | T hi = 0.0; | 
|  | T lo = 0.0; | 
|  | GetRangeFromInputs(&lo, &hi); | 
|  | if (target_cnt == 0) { | 
|  | // lowest possible value | 
|  | quantile_val[0] = lo; | 
|  | return true; | 
|  | } | 
|  | if (target_cnt == numel) { | 
|  | // highest possible value | 
|  | quantile_val[0] = hi; | 
|  | return true; | 
|  | } | 
|  | int64_t lo_cnt = CountLowerEq(lo); | 
|  | if (lo_cnt >= target_cnt) { | 
|  | // the target is one of the lowest value | 
|  | quantile_val[0] = lo; | 
|  | return true; | 
|  | } | 
|  | while (std::abs(hi - lo) > tol_ * (std::abs(hi) + std::abs(lo))) { | 
|  | // keep hi_cnt > target_idx and lo_cnt <= target_idx | 
|  | const T mid = lo + (hi - lo) / 2.0; | 
|  | const int64_t mid_cnt = CountLowerEq(mid); | 
|  | if (mid_cnt > target_cnt) { | 
|  | CAFFE_ENFORCE_NE( | 
|  | hi, mid, "numeric precision at limit, unable to continue bisect"); | 
|  | hi = mid; | 
|  | } else if (mid_cnt < target_cnt) { | 
|  | CAFFE_ENFORCE_NE( | 
|  | lo, mid, "numeric precision at limit, unable to continue bisect"); | 
|  | lo = mid; | 
|  | } else { | 
|  | // mid_cnt == target_cnt | 
|  | quantile_val[0] = mid; | 
|  | return true; | 
|  | } | 
|  | } | 
|  | quantile_val[0] = hi; | 
|  | return true; | 
|  | } | 
|  |  | 
|  | protected: | 
|  | float quantile_; | 
|  | bool abs_; | 
|  | float tol_; | 
|  | OUTPUT_TAGS(QUANTILE_VAL); | 
|  |  | 
|  | template <typename T> | 
|  | void GetRangeFromInputs(T* lo, T* hi) { | 
|  | *hi = std::numeric_limits<T>::lowest(); | 
|  | *lo = std::numeric_limits<T>::max(); | 
|  | for (const auto i : c10::irange(InputSize())) { | 
|  | const auto* input = Input(i).template data<T>(); | 
|  | for (const auto j : c10::irange(Input(i).numel())) { | 
|  | const T val = abs_ ? std::abs(input[j]) : input[j]; | 
|  | if (*hi < val) { | 
|  | *hi = val; | 
|  | } | 
|  | if (*lo > val) { | 
|  | *lo = val; | 
|  | } | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | template <typename T> | 
|  | int64_t CountLowerEq(const T& thd) { | 
|  | int64_t count = 0; | 
|  | for (const auto i : c10::irange(InputSize())) { | 
|  | const auto* input = Input(i).template data<T>(); | 
|  | for (const auto j : c10::irange(Input(i).numel())) { | 
|  | const T val = abs_ ? std::abs(input[j]) : input[j]; | 
|  | if (val <= thd) { | 
|  | count++; | 
|  | } | 
|  | } | 
|  | } | 
|  | return count; | 
|  | } | 
|  | }; | 
|  |  | 
|  | } // namespace caffe2 |