#include "caffe2/operators/flexible_top_k.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
namespace {
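// Comparator for (value, index) pairs. std::priority_queue keeps the
// element that compares "largest" at the top; by ranking larger values
// (and, on ties, smaller indices) as lower priority, the queue acts as a
// min-heap: the top is always the smallest kept value and, on ties, the
// pair with the larger index is evicted first.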
template <typename T>
struct ValueCmp {
bool operator()(
const std::pair<T, int64_t>& lhs,
const std::pair<T, int64_t>& rhs) {
return (
lhs.first > rhs.first ||
(lhs.first == rhs.first && lhs.second < rhs.second));
}
};
} // namespace
template <typename T, class Context>
bool FlexibleTopKOp<T, Context>::RunOnDevice() {
auto& input = Input(0);
auto& k = Input(1);
const T* input_data = input.template data<T>();
const int64_t* k_data = k.template data<int64_t>();
  // Get the flattened shape of the input.
CAFFE_ENFORCE_GT(input.dim(), 0);
vector<int64_t> input_dims = input.sizes().vec();
vector<int64_t> linear_shape = {
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
size_to_dim_(input_dims.size() - 1, input_dims), input_dims.back()};
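  // The input is now viewed as a 2-D [outer, r] matrix, where outer is the
  // product of the first n-1 dims; K must supply exactly one k per row.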
CAFFE_ENFORCE_EQ(
linear_shape[0],
k.numel(),
"first n-1 dims of input data and K does not match.");
int64_t output_size = 0;
for (int64_t i = 0; i < linear_shape[0]; ++i) {
CAFFE_ENFORCE(
linear_shape[1] >= k_data[i],
"k should not be greater than last dim, error at index ",
i,
", with value: ",
k_data[i]);
CAFFE_ENFORCE(
k_data[i] > 0,
"k should be greater than 0, error at index ",
i,
", with value: ",
k_data[i]);
output_size += k_data[i];
}
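  // Since each row contributes a different number of elements, both outputs
  // are 1-D and concatenate the per-row results.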
auto* values = Output(0, {output_size}, at::dtype<T>());
auto* indices = Output(1, {output_size}, at::dtype<int64_t>());
T* values_data = values->template mutable_data<T>();
int64_t* indices_data = indices->template mutable_data<int64_t>();
int64_t output_offset = 0;
  // Select the top k elements of each row, preserving their original indices.
for (int64_t i = 0; i < linear_shape[0]; ++i) {
    // Build a min-heap whose elements are (value, index) pairs;
    // the top of the heap is the smallest value currently kept.
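    // Capping the heap at size k gives O(r log k) work per row instead of
    // sorting the entire row.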
std::priority_queue<
std::pair<T, int64_t>,
std::vector<std::pair<T, int64_t>>,
ValueCmp<T>>
PQ;
int64_t k_ = k_data[i];
for (int64_t j = 0; j < linear_shape[1]; ++j) {
const T value = input_data[i * linear_shape[1] + j];
      if (static_cast<int64_t>(PQ.size()) < k_ || value > PQ.top().first) {
PQ.push(std::make_pair(value, j));
}
      if (static_cast<int64_t>(PQ.size()) > k_) {
PQ.pop();
}
}
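    // Drain the heap: the smallest survivor pops first, so write the results
    // back-to-front to emit the k elements in descending value order.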
for (int64_t j = 0; j < k_; ++j) {
auto& pqElem = PQ.top();
values_data[output_offset + k_ - j - 1] = pqElem.first;
indices_data[output_offset + k_ - j - 1] = pqElem.second;
PQ.pop();
}
output_offset += k_;
}
return true;
}
template <typename T, class Context>
bool FlexibleTopKGradientOp<T, Context>::RunOnDevice() {
auto& original_input = Input(0);
auto& k = Input(1);
auto& values = Input(2);
auto& indices = Input(3);
const int64_t* k_data = k.template data<int64_t>();
const T* values_data = values.template data<T>();
const int64_t* indices_data = indices.template data<int64_t>();
  // Resize the output tensor to the size of original_input and
  // initialize it with 0.
CAFFE_ENFORCE_GT(original_input.dim(), 0);
vector<int64_t> original_dims = original_input.sizes().vec();
auto* output = Output(0, original_dims, at::dtype<T>());
T* output_data = output->template mutable_data<T>();
math::Set<T, Context>(
output->numel(), static_cast<T>(0), output_data, &context_);
int64_t index_offset = 0;
for (int64_t i = 0; i < k.numel(); ++i) {
// offset of output_data
int64_t output_offset = i * original_dims.back();
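    // Indices are relative to the last dimension, so offset by the start of
    // this row's slice when scattering the gradient values back.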
for (int64_t j = 0; j < k_data[i]; ++j) {
int64_t index = indices_data[index_offset + j];
T value = values_data[index_offset + j];
output_data[output_offset + index] = value;
}
index_offset += k_data[i];
}
return true;
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CPU_OPERATOR(FlexibleTopK, FlexibleTopKOp<float, CPUContext>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_CPU_OPERATOR(
FlexibleTopKGradient,
FlexibleTopKGradientOp<float, CPUContext>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
OPERATOR_SCHEMA(FlexibleTopK)
.NumInputs(2)
.NumOutputs(2)
.SetDoc(R"DOC(
Given two tensors: X and K,
retrieve the top K[..., 1] elements from X on the last dimension.
X is an input tensor of shape [a_1, a_2, ..., a_n, r].
K is an input tensor of shape [a_1, a_2, ..., a_n, 1],
where for each element, r >= K[..., 1] > 0.
Produce two outputs:
- Flattened values tensor of shape [ \sum_i K[i, 1] ], which contains the
values of the top K[..., 1] elements along the last dimension.
- Flattened indices tensor of shape [ \sum_i K[i, 1] ], which contains the
indices of the top K[..., 1] elements (indices are relative to each slice
along the last dimension, i.e. in the range [0, r)).
These two outputs should be used together with the input K, so that we know
which indices in X are picked.
Given two equivalent values, this operator uses the indices along the last
dimension as a tiebreaker: the element with the lower index appears first.
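For example, given X = [[1, 2, 3], [4, 5, 6]] and K = [[2], [1]], the outputs
are values = [3, 2, 6] and indices = [2, 1, 2].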
)DOC")
.Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]")
.Input(1, "K", "Tensor of shape [a_1, a_2, ..., a_n, 1]")
.Output(
0,
"Flatten values",
"Tensor of shape [ \\sum_i K[i, 1] ] containing"
" top K[..., 1] values from the input tensor")
    .Output(
        1,
        "Flatten indices",
        "Tensor of shape [ \\sum_i K[i, 1] ] containing the indices "
        "of the top K[..., 1] elements along the last dimension");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
OPERATOR_SCHEMA(FlexibleTopKGradient).NumInputs(4).NumOutputs(1);
class GetFlexibleTopKGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
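    // Inputs to the gradient op: the original X (for its shape), K, the
    // gradient flowing into the values output, and the forward indices
    // output.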
return SingleGradientDef(
"FlexibleTopKGradient",
"",
vector<string>{I(0), I(1), GO(0), O(1)},
vector<string>{GI(0)});
}
};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
REGISTER_GRADIENT(FlexibleTopK, GetFlexibleTopKGradient);
} // namespace caffe2