// blob: 79f1a2054ac03f9a14c304658a004e78ebcf8328 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_TOP_K_H_
#define CAFFE2_OPERATORS_TOP_K_H_
#include <queue>
#include <utility>
#include <vector>

#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace {
// Comparator over (value, index) pairs.
//
// Returns true when `lhs` ranks ahead of `rhs`: it has the strictly larger
// value, or the values are equal and `lhs` has the smaller index.
//
// When plugged into std::priority_queue (which keeps the element that
// compares "greatest" at top()), this ordering puts the pair with the
// smallest value at the top; among equal values the one with the largest
// index is at the top.
template <typename T>
struct ValueCmp {
  // const-qualified: the comparison reads no state, and this lets the functor
  // be invoked through a const object or const reference.
  bool operator()(
      const std::pair<T, TIndex>& lhs,
      const std::pair<T, TIndex>& rhs) const {
    return (
        lhs.first > rhs.first ||
        (lhs.first == rhs.first && lhs.second < rhs.second));
  }
};
// Alias templates for Eigen::Map views over dynamically sized, row-major 2-D
// matrices. They allow lookup into the 2d tensors like mytensor(i, j).
//
// EigenMatrixMapRowMajor<T>: map over a non-const Eigen matrix of T.
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// ConstEigenMatrixMapRowMajor<T>: same layout, but mapping a const Eigen
// matrix of T.
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
} // namespace
// Selects the k largest entries along the last dimension of the input.
//
//   Input(0):  tensor of T with shape [a_1, ..., a_n, d]. It must have at
//              least one dimension and d must be >= k.
//   Output(0): values, shape [a_1, ..., a_n, k]. Each row holds the k largest
//              entries of the corresponding input row in descending order.
//   Output(1): indices (TIndex), same shape as Output(0). Each entry is the
//              position of the matching value within the input's last
//              dimension. Equal values are ordered by ascending index.
//
// Argument:
//   k (int, required): number of entries to keep per row; must be >= 1.
template <typename T, class Context>
class TopKOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  TopKOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws), OP_SINGLE_ARG(int, "k", k_, -1) {
    // The default of -1 makes a missing "k" argument fail this check.
    CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
  }

  // Computes both outputs. Returns true on success; invalid shapes are
  // reported through CAFFE_ENFORCE.
  bool RunOnDevice() override {
    auto& input = Input(0);
    auto* values = Output(0);
    auto* indices = Output(1);

    // Copy (not reference) the dims: the outputs are resized below.
    vector<TIndex> in_dims = input.dims();

    // Without at least one dimension, `in_dims.size() - 1` below would wrap
    // around and the last-dimension lookup would read out of bounds.
    CAFFE_ENFORCE(
        !in_dims.empty(), "input must have at least one dimension");

    // Each row must be able to supply k elements; otherwise the heap would
    // hold fewer than k entries and draining it below would call top()/pop()
    // on an empty queue.
    const TIndex k = k_;
    const TIndex last_dim = in_dims.back();
    CAFFE_ENFORCE(
        k <= last_dim,
        "k argument (",
        k,
        ") must not be greater than the size of the last dimension (",
        last_dim,
        ")");

    // Linearize input tensor except for last dimension
    // e.g. [3, 4, 5] -> [12, 5]
    //      [5]       -> [5]
    vector<TIndex> linear_shape = {size_to_dim_(in_dims.size() - 1, in_dims),
                                   last_dim};
    auto input_map = ConstEigenMatrixMapRowMajor<T>(
        static_cast<const T*>(input.raw_data()),
        linear_shape[0],
        linear_shape[1]);

    // Resize output tensors to be the same shape as the linearized input
    // except for the last dimension, which will be of size k. E.g. for an
    // input tensor of shape [3, 4, 5] and k=2, both outputs are first sized
    // [12, 2] and reshaped to [3, 4, 2] at the end.
    vector<TIndex> output_linear_shape = {linear_shape[0], k};
    values->Resize(output_linear_shape);
    indices->Resize(output_linear_shape);

    // Use Eigen maps to allow indexing into the 2d tensors like
    // values_map(i, j)
    auto values_map = EigenMatrixMapRowMajor<T>(
        values->template mutable_data<T>(), linear_shape[0], k);
    auto indices_map = EigenMatrixMapRowMajor<TIndex>(
        indices->template mutable_data<TIndex>(), linear_shape[0], k);

    // Sort preserving indices
    for (TIndex i = 0; i < linear_shape[0]; ++i) {
      // Build a min-heap, the heap element is pair of (value, idx)
      // the top of the heap is the smallest value
      std::priority_queue<
          std::pair<T, TIndex>,
          std::vector<std::pair<T, TIndex>>,
          ValueCmp<T>>
          PQ;

      // Maintain the size of heap to be less or equal to k, so the
      // heap will hold the k largest values
      for (TIndex j = 0; j < linear_shape[1]; ++j) {
        const auto value = input_map(i, j);
        // Sizes are compared as TIndex to avoid a signed/unsigned mismatch.
        // A value equal to the current minimum is not inserted once the heap
        // is full, so earlier indices win ties.
        if (static_cast<TIndex>(PQ.size()) < k || value > PQ.top().first) {
          PQ.push(std::make_pair(value, j));
        }
        if (static_cast<TIndex>(PQ.size()) > k) {
          PQ.pop();
        }
      }

      // The heap now holds exactly k entries (guaranteed by k <= last_dim).
      // They pop smallest-first, so fill the output row back to front to get
      // descending order.
      for (TIndex j = 0; j < k; ++j) {
        const auto& pqElem = PQ.top();
        values_map(i, k - j - 1) = pqElem.first;
        indices_map(i, k - j - 1) = pqElem.second;
        PQ.pop();
      }
    }

    // Reshape output tensors to [a_1, a_2, ..., a_n, k]
    auto out_dims = in_dims;
    out_dims[out_dims.size() - 1] = k;
    values->Reshape(out_dims);
    indices->Reshape(out_dims);
    return true;
  }

 private:
  // Number of top entries to keep per row; validated >= 1 in the constructor.
  int k_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_TOP_K_H_