// blob: db571e3fea8006665d1e9c7b063010d3c8ebe958 [file] [log] [blame]
#include "caffe2/operators/top_k.h"

#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2.pb.h"
namespace caffe2 {
namespace {
// Registers TopKOp<float, CPUContext> as the CPU implementation of the
// operator named "TopK".
REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
// Schema for the TopK operator: exactly one input and two outputs (values
// and indices). The shape-inference function copies the input shape to both
// outputs and replaces the last dimension with the "k" argument.
OPERATOR_SCHEMA(TopK)
    .NumInputs(1)
    .NumOutputs(2)
    .TensorInferenceFunction(
        [](const OperatorDef& def, const vector<TensorShape>& in) {
          // Both outputs start as copies of the input shape; only the last
          // dimension (and the data type of the index output) is changed.
          vector<TensorShape> out = {in[0], in[0]};
          ArgumentHelper helper(def);
          // Falls back to -1 when the "k" argument is absent, in which case
          // -1 is what gets written into the last dimension below.
          const auto k = helper.GetSingleArgument("k", -1);
          const auto dims_size = in[0].dims_size();
          // With no dimensions, `dims_size - 1` would be -1 and the
          // set_dims() calls below would index outside the repeated `dims`
          // field, so fail explicitly instead.
          CAFFE_ENFORCE(
              dims_size > 0,
              "TopK shape inference requires an input with at least one "
              "dimension.");
          const auto last_dim = dims_size - 1;
          out[0].set_dims(last_dim, k);
          out[1].set_dims(last_dim, k);
          // The index output is declared as INT32 regardless of input type.
          out[1].set_data_type(TensorProto_DataType_INT32);
          return out;
        })
    .SetDoc(R"DOC(
Retrieve the top-K elements for the last dimension. Given an input tensor of
shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs:
-Value tensor of shape [a_1, a_2, ..., a_n, k] which contains the values of
the top k elements along the last dimension
-Index tensor of shape [a_1, a_2, ..., a_n, k] which contains the indices
of the top k elements (original indices from the input tensor).
Given two equivalent values, this operator uses the indices along the last dim-
ension as a tiebreaker. That is, the element with the lower index will appear
first.
)DOC")
    .Input(0, "X", "Tensor of shape [a_1, a_2, ..., a_n, r]")
    .Output(
        0,
        "Values",
        "Tensor of shape [a_1, a_2, ..., a_n, k] containing"
        " top K values from the input tensor")
    .Output(
        1,
        "Indices",
        "Tensor of shape [a_1, a_2, ..., a_n, k] containing"
        " the corresponding input tensor indices for the top K values.")
    .Arg("k", "Number of top elements to retrieve");
}
}