#include "caffe2/operators/top_k.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <queue>
#include <utility>
#include <vector>

#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

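// Comparator for the std::priority_queue below. It inverts the usual
// ordering so that the queue acts as a min-heap on value: pq.top() is the
// smallest of the kept candidates. On ties, the pair with the larger index
// sits nearer the top and is evicted first, so equal values resolve to the
// smaller original index.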
template <typename T>
struct ValueComp {
  bool operator()(
      const std::pair<T, int64_t>& lhs,
      const std::pair<T, int64_t>& rhs) const {
    return lhs.first > rhs.first ||
        (lhs.first == rhs.first && lhs.second < rhs.second);
  }
};

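// Computes the top-k entries of a single length-n slice of `input`. The
// slice starts at `src_offset` and its elements are `stride` apart. Values
// and within-slice indices are written to `values`/`indices` starting at
// `dst_offset` (same stride), in descending value order; `flatten_indices`,
// if non-null, additionally receives positions in the flattened tensor.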
template <typename T>
void GetTopK(
    const T* input,
    const int64_t n,
    const int64_t k,
    const int64_t src_offset,
    const int64_t dst_offset,
    const int64_t stride,
    T* values,
    int64_t* indices,
    int64_t* flatten_indices) {
  const T* src_ptr = input + src_offset;
  std::vector<std::pair<T, int64_t>> heap_data;
  heap_data.reserve(k);
  // Seed the heap with the first min(k, n) elements of the slice.
  for (int64_t i = 0; i < k && i < n; ++i) {
    heap_data.emplace_back(*src_ptr, i);
    src_ptr += stride;
  }
  std::priority_queue<
      std::pair<T, int64_t>,
      std::vector<std::pair<T, int64_t>>,
      ValueComp<T>>
      pq(ValueComp<T>(), std::move(heap_data));
  // Scan the remainder of the slice, replacing the smallest kept element
  // whenever a larger one is found.
  for (int64_t i = k; i < n; ++i) {
    if (pq.top().first < *src_ptr) {
      pq.pop();
      pq.emplace(*src_ptr, i);
    }
    src_ptr += stride;
  }
  // The heap pops in ascending value order, so fill the output back to front
  // to produce descending order.
  int64_t dst_pos = dst_offset + (std::min(k, n) - 1) * stride;
  while (!pq.empty()) {
    const auto& item = pq.top();
    values[dst_pos] = item.first;
    indices[dst_pos] = item.second;
    if (flatten_indices != nullptr) {
      flatten_indices[dst_pos] = src_offset + item.second * stride;
    }
    pq.pop();
    dst_pos -= stride;
  }
}

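// Scatters the k gradient values of one output slice back to their original
// positions within the corresponding input slice. Entries with index -1
// (slots left unfilled because the slice had fewer than k elements) carry no
// gradient and are skipped.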
template <typename T>
void SetTopKGradient(
    const T* values,
    const int64_t* indices,
    const int k,
    const int64_t src_offset,
    const int64_t dst_offset,
    const int64_t stride,
    T* gradient) {
  int64_t src_pos = src_offset;
  for (int i = 0; i < k; ++i) {
    // Skip padded entries (index -1), but still advance to the next element.
    if (indices[src_pos] >= 0) {
      gradient[dst_offset + indices[src_pos] * stride] = values[src_pos];
    }
    src_pos += stride;
  }
}

} // namespace

template <typename T, class Context>
bool TopKOp<T, Context>::RunOnDevice() {
  const auto& input = Input(0);
  auto* values = Output(0);
  auto* indices = Output(1);
  auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;

  // k comes either from the "k" argument or, when that is left at its
  // default of -1, from an optional second input holding a single int64.
  int64_t k = k_;
  if (k == -1 && InputSize() == 2) {
    k = Input(1).template data<int64_t>()[0];
  }
  CAFFE_ENFORCE(k >= 1, "k argument must be >= 1");

  at::IntArrayRef input_dims = input.sizes();
  if (axis_ == -1) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    axis_ = input_dims.size() - 1;
  }
  CAFFE_ENFORCE_GE(axis_, 0);
  CAFFE_ENFORCE_LT(axis_, input_dims.size());

  std::vector<int64_t> output_dims = input_dims.vec();
  output_dims[axis_] = k;
  values->Resize(output_dims);
  indices->Resize(output_dims);
  if (flatten_indices != nullptr) {
    flatten_indices->Resize(indices->numel());
  }
  const T* input_data = input.template data<T>();
  T* values_data = values->template mutable_data<T>();
  int64_t* indices_data = indices->template mutable_data<int64_t>();
  int64_t* flatten_indices_data = flatten_indices == nullptr
      ? nullptr
      : flatten_indices->template mutable_data<int64_t>();
  // Initialize the outputs to their defaults (0 for values, -1 for indices)
  // so that slots beyond the slice length remain marked as unfilled when the
  // reduced axis has fewer than k elements.
  math::Set<T, Context>(values->numel(), T(0), values_data, &context_);
  math::Set<int64_t, Context>(
      indices->numel(), int64_t(-1), indices_data, &context_);
  if (flatten_indices_data != nullptr) {
    math::Set<int64_t, Context>(
        flatten_indices->numel(), int64_t(-1), flatten_indices_data, &context_);
  }

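  // View the input as a (prev_size, n, next_size) tensor, where n is the
  // size of the reduced axis. Top-k then runs independently on each of the
  // prev_size * next_size slices along the middle axis, whose elements are
  // next_size apart in memory.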
  const int64_t prev_size = std::accumulate(
      input_dims.cbegin(),
      input_dims.cbegin() + axis_,
      int64_t(1),
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      std::multiplies<int64_t>());
  const int64_t next_size = std::accumulate(
      input_dims.cbegin() + axis_ + 1,
      input_dims.cend(),
      int64_t(1),
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      std::multiplies<int64_t>());
  const int64_t src_offset_stride = input_dims[axis_] * next_size;
  const int64_t dst_offset_stride = k * next_size;
  int64_t src_offset = 0;
  int64_t dst_offset = 0;
  for (int64_t i = 0; i < prev_size; ++i) {
    for (int64_t j = 0; j < next_size; ++j) {
      GetTopK(
          input_data,
          input_dims[axis_],
          k,
          src_offset + j,
          dst_offset + j,
          next_size,
          values_data,
          indices_data,
          flatten_indices_data);
    }
    src_offset += src_offset_stride;
    dst_offset += dst_offset_stride;
  }
  return true;
}

template <typename T, class Context>
bool TopKGradientOp<T, Context>::RunOnDevice() {
  const auto& values = Input(0);
  const auto& indices = Input(1);
  const auto& original_input = Input(2);
  auto* output = Output(0);
  at::IntArrayRef values_dims = values.sizes();
  at::IntArrayRef origin_dims = original_input.sizes();
  CAFFE_ENFORCE_EQ(values_dims.size(), origin_dims.size());
  output->Resize(origin_dims);
  const T* values_data = values.template data<T>();
  const int64_t* indices_data = indices.template data<int64_t>();
  T* output_data = output->template mutable_data<T>();
  if (axis_ == -1) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    axis_ = values_dims.size() - 1;
  }
  const int k = values_dims[axis_];
  // The gradient is zero everywhere except at the top-k positions selected
  // in the forward pass.
  math::Set<T, Context>(output->numel(), T(0), output_data, &context_);
  const int64_t prev_size = std::accumulate(
      values_dims.cbegin(),
      values_dims.cbegin() + axis_,
      int64_t(1),
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      std::multiplies<int64_t>());
  const int64_t next_size = std::accumulate(
      values_dims.cbegin() + axis_ + 1,
      values_dims.cend(),
      int64_t(1),
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      std::multiplies<int64_t>());
  const int64_t src_offset_stride = k * next_size;
  const int64_t dst_offset_stride = origin_dims[axis_] * next_size;
  int64_t src_offset = 0;
  int64_t dst_offset = 0;
  for (int64_t i = 0; i < prev_size; ++i) {
    for (int64_t j = 0; j < next_size; ++j) {
      SetTopKGradient(
          values_data,
          indices_data,
          k,
          src_offset + j,
          dst_offset + j,
          next_size,
          output_data);
    }
    src_offset += src_offset_stride;
    dst_offset += dst_offset_stride;
  }
  return true;
}

REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(TopK)
    .NumInputs(1, 2)
    .NumOutputs(2, 3)
    // Note: shape inference assumes the top-k reduction runs along the last
    // axis, and reports -1 for the reduced dimension when k is passed as an
    // input tensor rather than as an argument.
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      vector<TensorShape> out = {in[0], in[0]};
      ArgumentHelper helper(def);
      auto k = helper.GetSingleArgument("k", -1);
      auto dims_size = in[0].dims_size();
      out[0].set_dims(dims_size - 1, k);
      out[1].set_dims(dims_size - 1, k);
      // The indices outputs are produced as int64.
      out[1].set_data_type(TensorProto_DataType_INT64);
      if (def.output_size() > 2) {
        TensorShape flatten_indices_shape;
        flatten_indices_shape.set_data_type(TensorProto_DataType_INT64);
        flatten_indices_shape.add_dims(
            std::accumulate(
                in[0].dims().begin(),
                in[0].dims().end() - 1,
                int64_t(1),
                // NOLINTNEXTLINE(modernize-use-transparent-functors)
                std::multiplies<int64_t>()) *
            k);
        out.push_back(flatten_indices_shape);
      }
      return out;
    })
    .SetDoc(R"DOC(
Retrieve the top-K elements along the last dimension.
Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$, this operator returns up to three outputs. `k` can be passed either as an integer argument or as a second input: a 1D tensor containing a single integer.

1. Value tensor of shape $(a_1, a_2, ..., a_n, k)$ containing the values of the top-k elements along the last dimension.
2. Index tensor of shape $(a_1, a_2, ..., a_n, k)$ containing the indices of the top-k elements within the last dimension of the input tensor.
3. [OPTIONAL] Flattened index tensor of shape $(a_1 * a_2 * ... * a_n * k,)$ containing the indices of the top-k elements within the flattened input tensor.

Given two equal values, this operator breaks ties by index along the last dimension: the element with the lower index appears first.

Github Links:
- https://github.com/pytorch/pytorch/blob/main/caffe2/operators/top_k.cc


<details>

<summary> <b>Example</b> </summary>

**Code**

```

from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "TopK",
    ["X"],
    ["Values", "Indices", "Flattened_indices"],
    k=2
)

workspace.FeedBlob("X", np.random.randint(10, size=(3,3,3)).astype(np.float32))
print("X:", workspace.FetchBlob("X"))
workspace.RunOperatorOnce(op)
print("Values:", workspace.FetchBlob("Values"))
print("Indices:", workspace.FetchBlob("Indices"))
print("Flattened_indices:", workspace.FetchBlob("Flattened_indices"))

```

**Result**

```

X:
[[[6. 7. 0.]
  [8. 7. 7.]
  [1. 5. 6.]]

 [[0. 6. 1.]
  [2. 8. 4.]
  [1. 2. 9.]]

 [[4. 3. 7.]
  [0. 1. 7.]
  [0. 1. 8.]]]
Values:
[[[7. 6.]
  [8. 7.]
  [6. 5.]]

 [[6. 1.]
  [8. 4.]
  [9. 2.]]

 [[7. 4.]
  [7. 1.]
  [8. 1.]]]
Indices:
[[[1 0]
  [0 1]
  [2 1]]

 [[1 2]
  [1 2]
  [2 1]]

 [[2 0]
  [2 1]
  [2 1]]]
Flattened_indices: [ 1  0  3  4  8  7 10 11 13 14 17 16 20 18 23 22 26 25]

```
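
`k` can also be supplied at run time as a second input instead of the `k` argument. A minimal sketch of that variant, reusing the `X` blob fed above; it assumes the second input is an int64 tensor holding a single element and that the `k` argument is simply omitted:

```
# Sketch: pass k as a second input blob rather than as an argument.
op_runtime_k = core.CreateOperator("TopK", ["X", "k"], ["Values", "Indices"])
workspace.FeedBlob("k", np.array([2], dtype=np.int64))
workspace.RunOperatorOnce(op_runtime_k)
```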

</details>

  )DOC")
    .Input(
        0,
        "X",
        "(*Tensor`<float>`*): input tensor of shape $(a_1, a_2, ..., a_n, r)$")
    .Input(
        1,
        "k",
        "(*Tensor`<int64>`*): 1D tensor containing a single integer, the number of top elements to retrieve; an alternative to the `k` argument")
    .Output(
        0,
        "Values",
        "(*Tensor`<float>`*): output tensor of shape $(a_1, a_2, ..., a_n, k)$")
    .Output(
        1,
        "Indices",
        "(*Tensor`<int>`*): tensor of indices of shape $(a_1, a_2, ..., a_n, k)$; indices values refer to each element's index in the last dimension of the `X` input tensor")
    .Output(
        2,
        "Flattened_indices",
        "(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`");

OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);

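// Gradient wiring: TopKGradient consumes the gradient of the Values output
// (GO(0)), the Indices output of the forward pass (O(1)), and the forward
// input (I(0), used only for its shape), and emits the gradient of X (GI(0)).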
class GetTopKGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "TopKGradient",
        "",
        vector<string>{GO(0), O(1), I(0)},
        vector<string>{GI(0)});
  }
};

REGISTER_GRADIENT(TopK, GetTopKGradient);

} // namespace caffe2