blob: 0a3e700b4b887200f16192d7a6f990dc60b02a48 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_REDUCTION_OPS_H_
#define CAFFE2_OPERATORS_REDUCTION_OPS_H_
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Sums every element of Input(0) into a 0-dim Output(0). When averaging is
// enabled and the input is non-empty, the sum is then scaled by 1 / numel.
template <typename T, class Context>
class SumElementsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Averaging mode comes from the def's "average" argument (default false).
  SumElementsOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        average_(this->template GetSingleArgument<bool>("average", false)) {}

  // Averaging mode is supplied directly by the caller.
  SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
      : Operator<Context>(operator_def, ws), average_(average) {}

  ~SumElementsOp() {}

  bool RunOnDevice() override {
    auto& input = Input(0);
    auto* output = Output(0);

    // An empty dims vector makes the output a scalar (0-dim) tensor.
    output->Resize(vector<int64_t>());
    T* result = output->template mutable_data<T>();
    math::Sum<T, Context>(
        input.numel(), input.template data<T>(), result, &context_, &scratch_);

    // Nothing more to do unless averaging was requested. Empty inputs are
    // also skipped, which avoids dividing by zero.
    if (!average_ || input.numel() <= 0) {
      return true;
    }

    // Scale the single output value in place by 1 / numel.
    const auto inv_count = static_cast<T>(1.) / input.numel();
    math::Scale<float, T, Context>(1, inv_count, result, result, &context_);
    return true;
  }

 private:
  // Whether the sum is divided by the element count.
  bool average_;
  // Temporary device buffer handed to math::Sum.
  Tensor scratch_{Context::GetDeviceType()};
};
// Sums every element of Input(0) into a 0-dim Output(0). This op has no
// averaging option.
template <typename T, class Context>
class SumElementsIntOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  SumElementsIntOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {}

  ~SumElementsIntOp() {}

  bool RunOnDevice() override {
    auto& input = Input(0);
    auto* total = Output(0);

    // Shape the output as a scalar, then obtain its writable buffer before
    // reading the input.
    total->Resize(vector<int64_t>());
    T* total_data = total->template mutable_data<T>();

    const auto count = input.numel();
    const T* input_data = input.template data<T>();
    math::Sum<T, Context>(count, input_data, total_data, &context_, &scratch_);
    return true;
  }

 private:
  // Temporary device buffer handed to math::Sum.
  Tensor scratch_{Context::GetDeviceType()};
};
// Gradient operator paired with the element-sum reduction. Only the
// declaration lives here; RunOnDevice is defined outside this header.
template <typename T, class Context>
class SumElementsGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Reads the averaging mode from the def's "average" argument
  // (default false).
  SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        average_(this->template GetSingleArgument<bool>("average", false)) {}

  // Takes the averaging mode from the caller instead of the def.
  SumElementsGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws,
      bool average)
      : Operator<Context>(operator_def, ws),
        average_(average) {}

  ~SumElementsGradientOp() {}

  // Implementation provided out of line.
  bool RunOnDevice() override;

 private:
  // Averaging flag fixed at construction.
  bool average_;
};
// Computes the sum of squares of all elements of Input(0) into a 0-dim
// Output(0). If the "average" argument is true and the input is non-empty,
// the result is divided by the element count.
template <class Context>
class SumSqrElementsOp : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(SumSqrElementsOp)
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Dispatches on Input(0)'s dtype; only float is registered here.
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    auto& X = Input(0);
    auto* sum = Output(0);
    // An empty dims vector makes the output a scalar (0-dim) tensor.
    sum->Resize(vector<int64_t>());
    // Fetch the output buffer once and reuse it below, instead of
    // re-querying the tensor for every use.
    T* sum_data = sum->template mutable_data<T>();
    math::SumSqr<T, Context>(
        X.numel(), X.template data<T>(), sum_data, &context_, &scratch_);
    // The numel() > 0 guard avoids dividing by zero for empty inputs.
    if (average_ && X.numel() > 0) {
      math::Scale<float, T, Context>(
          1,
          1.0f / static_cast<float>(X.numel()),
          sum_data,
          sum_data,
          &context_);
    }
    return true;
  }

 private:
  // "average" is parsed once, when the operator is constructed, rather than
  // on every run. This matches how the other reduction ops in this header
  // cache it. The default member initializer runs inside whatever
  // constructor USE_SIMPLE_CTOR_DTOR generates, after the Operator base is
  // built, so the constructor interface is unchanged.
  bool average_ = this->template GetSingleArgument<bool>("average", false);
  // Temporary device buffer handed to math::SumSqr.
  Tensor scratch_{Context::GetDeviceType()};
};
// Max-reduces a 3-D input of shape (batch_size, M, N) along one axis:
//   ROWWISE == true : output shape (batch_size, M), one value per row of N
//   ROWWISE == false: output shape (batch_size, N), one value per column of M
template <typename T, class Context, bool ROWWISE>
class MaxReductionOp : public Operator<Context> {
 public:
  USE_SIMPLE_CTOR_DTOR(MaxReductionOp)
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    auto& X = Input(0);
    CAFFE_ENFORCE_EQ(X.ndim(), 3);
    const int batch_size = X.dim32(0);
    const int M = X.dim32(1);
    const int N = X.dim32(2);

    auto* Y = Output(0);
    if (ROWWISE) {
      Y->Resize(batch_size, M);
    } else {
      Y->Resize(batch_size, N);
    }

    // Hoisted out of the per-batch loop: both pointers are loop-invariant.
    const T* X_data = X.template data<T>();
    T* Y_data = Y->template mutable_data<T>();

    if (ROWWISE) {
      // NOTE(review): RowwiseMax takes an int row count, so this assumes
      // batch_size * M fits in int. Confirm limits for very large inputs.
      math::RowwiseMax<T, Context>(
          batch_size * M, N, X_data, Y_data, &context_);
    } else {
      // Offsets are computed in 64 bits. The product i * M * N can exceed
      // INT_MAX (signed overflow, i.e. UB) even when each dimension fits
      // in int.
      const int64_t input_size = static_cast<int64_t>(M) * N;
      for (int i = 0; i < batch_size; ++i) {
        math::ColwiseMax<T, Context>(
            M,
            N,
            X_data + i * input_size,
            Y_data + static_cast<int64_t>(i) * N,
            &context_);
      }
    }
    return true;
  }
};
// Gradient operator paired with MaxReductionOp. Only the declaration lives
// here; RunOnDevice is defined outside this header.
template <typename T, class Context, bool ROWWISE>
class MaxReductionGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MaxReductionGradientOp)

  // Implementation provided out of line.
  bool RunOnDevice() override;
};
} // namespace caffe2
#endif