blob: 9f5a090c53fde4fc65e35ffb768fdd9dccaf1d05 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_ARG_OPS_H_
#define CAFFE2_OPERATORS_ARG_OPS_H_
#include <algorithm>
#include <iterator>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
namespace caffe2 {
template <typename T, class Context>
class ArgOpBase : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ArgOpBase(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
OP_SINGLE_ARG(int, "axis", axis_, -1),
OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
bool RunOnDevice() override {
const auto& X = Input(0);
auto* Y = Output(0);
const int ndim = X.ndim();
if (axis_ == -1) {
axis_ = ndim - 1;
}
CAFFE_ENFORCE_GE(axis_, 0);
CAFFE_ENFORCE_LT(axis_, ndim);
const std::vector<TIndex>& X_dims = X.dims();
std::vector<TIndex> Y_dims;
Y_dims.reserve(ndim);
TIndex prev_size = 1;
TIndex next_size = 1;
for (int i = 0; i < axis_; ++i) {
Y_dims.push_back(X_dims[i]);
prev_size *= X_dims[i];
}
if (keep_dims_) {
Y_dims.push_back(1);
}
for (int i = axis_ + 1; i < ndim; ++i) {
Y_dims.push_back(X_dims[i]);
next_size *= X_dims[i];
}
Y->Resize(Y_dims);
const TIndex n = X_dims[axis_];
return Compute(
X.template data<T>(),
prev_size,
next_size,
n,
Y->template mutable_data<TIndex>());
}
protected:
virtual bool Compute(
const T* X,
const TIndex prev_size,
const TIndex next_size,
const TIndex n,
TIndex* Y) = 0;
private:
int axis_;
const bool keep_dims_;
};
template <typename T, class Context>
class ArgMaxOp final : public ArgOpBase<T, Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ArgMaxOp(const OperatorDef& operator_def, Workspace* ws)
: ArgOpBase<T, Context>(operator_def, ws) {}
protected:
bool Compute(
const T* X,
const TIndex prev_size,
const TIndex next_size,
const TIndex n,
TIndex* Y) override;
};
template <typename T, class Context>
class ArgMinOp final : public ArgOpBase<T, Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ArgMinOp(const OperatorDef& operator_def, Workspace* ws)
: ArgOpBase<T, Context>(operator_def, ws) {}
protected:
bool Compute(
const T* X,
const TIndex prev_size,
const TIndex next_size,
const TIndex n,
TIndex* Y) override;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_ARG_OPS_H_