// blob: 0190aaf6b5f20964b50793ee9e9532c6e4d6f735 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_
#define CAFFE2_OPERATORS_MINMAX_OPS_H_
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Common driver for operators that combine several same-sized inputs into a
// single output. RunOnDevice() prepares the output and validates the inputs;
// the actual combination is delegated to the subclass through Compute().
template <typename T, class Context>
class MaxMinOpBase : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)

  // Subclass hook, invoked by RunOnDevice() only when there is more than one
  // input and every input's sizes matched the output's. Its return value
  // becomes the result of RunOnDevice().
  virtual bool Compute() = 0;

  // Resizes the output like input 0 and copies input 0 into it. With a single
  // input that copy is the whole result and true is returned. Otherwise each
  // remaining input is enforced to have the same sizes as the output before
  // control is handed to Compute().
  bool RunOnDevice() override {
    const auto& first_input = Input(0);
    auto* output = Output(0);

    // Seed the output with the first input.
    output->ResizeLike(first_input);
    output->CopyFrom(first_input, &context_);

    const auto num_inputs = InputSize();
    if (num_inputs != 1) {
      // Dimension checking: inputs 1..N-1 must agree with the output.
      for (int i = 1; i < num_inputs; ++i) {
        CAFFE_ENFORCE_EQ(
            output->sizes(),
            Input(i).sizes(),
            "Description: Input #",
            i,
            ", input dimension:",
            Input(i).sizes(),
            " should match output dimension: ",
            output->sizes());
      }
      return this->Compute();
    }

    // Only one input: the copy above is already the final output.
    return true;
  }
};
// Concrete MaxMinOpBase whose Compute() is only declared here; its definition
// is not part of this header section.
// NOTE(review): presumably the element-wise maximum across inputs, judging by
// the name - confirm against the Compute() definition.
template <typename T, class Context>
class MaxOp : public MaxMinOpBase<T, Context> {
  using BaseOp = MaxMinOpBase<T, Context>;

 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Forwards the operator definition and workspace to the base class.
  MaxOp(const OperatorDef& operator_def, Workspace* ws)
      : BaseOp(operator_def, ws) {}

  virtual ~MaxOp() noexcept {}

  // Called from MaxMinOpBase::RunOnDevice() after the output has been seeded
  // with input 0 and the remaining inputs' sizes have been checked.
  bool Compute() override;
};
// Base class shared by MaxGradientOp and MinGradientOp. RunOnDevice() is only
// declared here; its definition lives outside this header section.
template <typename T, class Context>
class SelectGradientOpBase : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase)
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Operator entry point; defined elsewhere.
  bool RunOnDevice() override;
};
// Gradient operator paired with MaxOp by name. It adds no members of its own:
// all behaviour is inherited from SelectGradientOpBase.
template <typename T, class Context>
class MaxGradientOp : public SelectGradientOpBase<T, Context> {
  using BaseOp = SelectGradientOpBase<T, Context>;

 public:
  // Forwards the operator definition and workspace to the base class.
  MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : BaseOp(operator_def, ws) {}

  virtual ~MaxGradientOp() noexcept {}
};
// Concrete MaxMinOpBase whose Compute() is only declared here; its definition
// is not part of this header section.
// NOTE(review): presumably the element-wise minimum across inputs, judging by
// the name - confirm against the Compute() definition.
template <typename T, class Context>
class MinOp : public MaxMinOpBase<T, Context> {
  using BaseOp = MaxMinOpBase<T, Context>;

 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Forwards the operator definition and workspace to the base class.
  MinOp(const OperatorDef& operator_def, Workspace* ws)
      : BaseOp(operator_def, ws) {}

  virtual ~MinOp() noexcept {}

  // Called from MaxMinOpBase::RunOnDevice() after the output has been seeded
  // with input 0 and the remaining inputs' sizes have been checked.
  bool Compute() override;
};
// Gradient operator paired with MinOp by name. It adds no members of its own:
// all behaviour is inherited from SelectGradientOpBase.
template <typename T, class Context>
class MinGradientOp : public SelectGradientOpBase<T, Context> {
  using BaseOp = SelectGradientOpBase<T, Context>;

 public:
  // Forwards the operator definition and workspace to the base class.
  MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : BaseOp(operator_def, ws) {}

  virtual ~MinGradientOp() noexcept {}
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_MINMAX_OPS_H_