// blob: a328f7a55fc1c6ae054d68e9eb20040a997d04b4 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_REDUCTION_OPS_H_
#define CAFFE2_OPERATORS_REDUCTION_OPS_H_
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/**
 * SumElementsOp: reduces all elements of Input(0) into a single value that is
 * written to Output(0).
 *
 * Arguments:
 *   average (bool, default false) - when true, the sum is scaled by
 *     1 / (number of input elements), i.e. the op outputs the mean.
 *
 * When the input has no elements the averaging step is skipped. Otherwise
 * 1 / X.size() would divide by zero: infinity for floating-point T (which
 * would then be multiplied into the result) and undefined behavior for
 * integral T. The output is then whatever math::Sum produces for zero
 * elements (presumably 0 -- confirm against math::Sum).
 */
template <typename T, class Context>
class SumElementsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Takes the averaging mode from the "average" argument of operator_def
  // (false when the argument is absent).
  SumElementsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        average_(OperatorBase::GetSingleArgument<bool>("average", false)) {}

  // Uses an explicitly supplied averaging mode instead of reading the
  // "average" argument.
  SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
      : Operator<Context>(operator_def, ws), average_(average) {}

  ~SumElementsOp() {}

  bool RunOnDevice() override {
    auto& X = Input(0);
    auto* sum = Output(0);
    // Resize to an empty dimension list: the output holds a single value.
    sum->Resize(vector<TIndex>());
    T* data = sum->template mutable_data<T>();
    math::Sum<T, Context>(X.size(), X.template data<T>(), data, &context_);
    // Skip averaging for empty input so we never compute 1 / 0.
    if (average_ && X.size() > 0) {
      math::Scale<T, Context>(
          1,
          static_cast<T>(1.) / X.size(),
          sum->template data<T>(),
          data,
          &context_);
    }
    return true;
  }

 private:
  // True => output the mean of the elements rather than their sum.
  bool average_;
};
/**
 * SumElementsGradientOp: gradient operator paired with the SumElements
 * reduction. Only the constructors and the averaging flag live in this
 * header; RunOnDevice() is declared here but defined elsewhere.
 *
 * Arguments:
 *   average (bool, default false) - stored in average_. Presumably this
 *     mirrors the forward op's averaging flag (TODO confirm against the
 *     out-of-line RunOnDevice definition).
 */
template <typename T, class Context>
class SumElementsGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Build from an operator definition; the averaging mode comes from its
  // "average" argument and falls back to false when the argument is missing.
  SumElementsGradientOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws),
        average_(OperatorBase::GetSingleArgument<bool>("average", false)) {}

  // Build with a caller-chosen averaging mode.
  SumElementsGradientOp(const OperatorDef& def, Workspace* ws, bool average)
      : Operator<Context>(def, ws), average_(average) {}

  ~SumElementsGradientOp() = default;

  // Definition is out of line (not part of this header).
  bool RunOnDevice() override;

 private:
  // Averaging mode captured at construction time.
  bool average_;
};
/**
 * SumSqrElementsOp: computes the sum of squares of all elements of Input(0)
 * and writes it as a single value to Output(0).
 *
 * Arguments:
 *   average (bool, default false) - when true, the sum of squares is scaled
 *     by 1 / (number of input elements), i.e. the op outputs the mean square.
 *
 * When the input has no elements the averaging step is skipped. Otherwise
 * 1 / X.size() would divide by zero: infinity for floating-point T and
 * undefined behavior for integral T.
 */
template <typename T, class Context>
class SumSqrElementsOp : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(SumSqrElementsOp)
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    // The averaging mode is read from the operator arguments on each run.
    bool average = OperatorBase::GetSingleArgument<bool>("average", false);
    auto& X = Input(0);
    auto* sum = Output(0);
    // Resize to an empty dimension list: the output holds a single value.
    sum->Resize(vector<TIndex>());
    T* data = sum->template mutable_data<T>();
    math::SumSqr<T, Context>(X.size(), X.template data<T>(), data, &context_);
    // Skip averaging for empty input so we never compute 1 / 0.
    if (average && X.size() > 0) {
      math::Scale<T, Context>(
          1,
          static_cast<T>(1.) / X.size(),
          sum->template data<T>(),
          data,
          &context_);
    }
    return true;
  }
};
} // namespace caffe2
#endif