blob: a6520008b3d646d8ce497b262f29fe734016edb6 [file] [log] [blame]
#include "caffe2/operators/segment_reduction_op.h"
namespace caffe2 {
namespace {
// Produces the schema documentation for a segment-reduction definition.
//
// Starts from the definition's doc template (Def::doc). It substitutes the
// "{op}" placeholder with the reducer's name (Def::OpDef::name). It then
// substitutes "{op_doc}" with the reducer's own description
// (Def::OpDef::doc). Both substitutions go through ReplaceAll.
template <typename Def>
string FormatDoc() {
  string formatted(Def::doc);
  ReplaceAll(formatted, "{op}", Def::OpDef::name);
  ReplaceAll(formatted, "{op_doc}", Def::OpDef::doc);
  return formatted;
}
// Registers one segment-reduction definition (the macro argument, "Def")
// with Caffe2. The name "Op" stands for string(Def::basename) + Def::OpDef::name.
//
//  * Registers the CPU forward operator "Op", implemented by Def::ForwardOp.
//    Its schema has Def::ForwardOp::kNumInputs inputs and a single output,
//    "OUTPUT". The doc comes from FormatDoc<Def>(), and Def::PopulateSchema
//    fills in the rest of the schema.
//  * Registers the CPU operator "OpGradient", implemented by Def::BackwardOp.
//    It has Def::BackwardOp::kNumInputs inputs and one output.
//  * Registers Def::GetGradient as the gradient maker for "Op".
//
// The macro is variadic (__VA_ARGS__) so that a template-id containing
// commas, e.g. AbstractSortedSegmentRangeDef<float, int, CPUContext, X>,
// is passed through as a single argument instead of being split by the
// preprocessor. The expansion has no trailing semicolon; call sites add it.
// NOTE: comments must stay outside the #define. A `//` comment on a
// backslash-continued line would absorb the following lines of the macro.
#define REGISTER_SEGMENT_DEF(...) \
REGISTER_CPU_OPERATOR_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \
__VA_ARGS__::ForwardOp); \
OPERATOR_SCHEMA_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name)) \
.NumInputs(__VA_ARGS__::ForwardOp::kNumInputs) \
.NumOutputs(1) \
.SetDoc(FormatDoc<__VA_ARGS__>()) \
.Output(0, "OUTPUT", "Aggregated tensor") \
.FillUsing(__VA_ARGS__::PopulateSchema); \
REGISTER_CPU_OPERATOR_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + "Gradient", \
__VA_ARGS__::BackwardOp); \
OPERATOR_SCHEMA_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + "Gradient") \
.NumInputs(__VA_ARGS__::BackwardOp::kNumInputs) \
.NumOutputs(1); \
REGISTER_GRADIENT_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \
__VA_ARGS__::GetGradient)
// Sorted-segment *range* reducers. Each reducer definition is wrapped in
// AbstractSortedSegmentRangeDef<float, int, CPUContext, ...> and handed to
// REGISTER_SEGMENT_DEF. The helper macro only removes the repeated template
// boilerplate and is undefined immediately after use. The registration
// order (Sum, LogSumExp, LogMeanExp, Mean, Max) is unchanged.
#define REGISTER_SORTED_SEGMENT_RANGE_DEF(range_reducer_def) \
  REGISTER_SEGMENT_DEF(                                      \
      AbstractSortedSegmentRangeDef<float, int, CPUContext, range_reducer_def>)
REGISTER_SORTED_SEGMENT_RANGE_DEF(SumRangeReducerDef);
REGISTER_SORTED_SEGMENT_RANGE_DEF(LogSumExpRangeReducerDef);
REGISTER_SORTED_SEGMENT_RANGE_DEF(LogMeanExpRangeReducerDef);
REGISTER_SORTED_SEGMENT_RANGE_DEF(MeanRangeReducerDef);
REGISTER_SORTED_SEGMENT_RANGE_DEF(MaxRangeReducerDef);
#undef REGISTER_SORTED_SEGMENT_RANGE_DEF
// Registers the full family of segment-reduction operators for one reducer
// definition, via REGISTER_SEGMENT_DEF, in this order:
//   ReduceFront, SortedSegment, SparseSortedSegment, UnsortedSegment,
//   SparseUnsortedSegment, Lengths, SparseLengths.
// Each variant is instantiated with float and CPUContext. All variants
// except ReduceFront also take int as their second template argument.
//
// Every expansion except the last is terminated with `;`. Previously the
// SparseUnsortedSegment and Lengths expansions lacked it, which only worked
// if REGISTER_GRADIENT_STR happened to expand to something that needs no
// terminator. The last expansion is left open so the call site supplies the
// final semicolon, e.g. `REGISTER_REDUCER_WITH_ALL_OPS(SumReducerDef);`.
// Comments must stay outside the #define; see REGISTER_SEGMENT_DEF.
#define REGISTER_REDUCER_WITH_ALL_OPS(reducer_def) \
REGISTER_SEGMENT_DEF( \
AbstractReduceFrontDef<float, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractSortedSegmentDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractSparseSortedSegmentDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractUnsortedSegmentDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractSparseUnsortedSegmentDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractLengthsDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractSparseLengthsDef<float, int, CPUContext, reducer_def>)
// Full operator family (front/segment/lengths variants) for the three
// general-purpose reducers.
REGISTER_REDUCER_WITH_ALL_OPS(SumReducerDef);
REGISTER_REDUCER_WITH_ALL_OPS(WeightedSumReducerDef);
REGISTER_REDUCER_WITH_ALL_OPS(MeanReducerDef);

// ReduceBack is registered here only for Sum and Mean; WeightedSum has no
// ReduceBack registration in this file. The helper macro is local to this
// spot and undefined right after use.
#define REGISTER_REDUCE_BACK_DEF(reducer_def) \
  REGISTER_SEGMENT_DEF(AbstractReduceBackDef<float, CPUContext, reducer_def>)
REGISTER_REDUCE_BACK_DEF(SumReducerDef);
REGISTER_REDUCE_BACK_DEF(MeanReducerDef);
#undef REGISTER_REDUCE_BACK_DEF
// Auxiliary output gradients are currently implemented only for the Lengths
// versions of the ops (Lengths and SparseLengths).
// Registers the CPU operator named
//   string(Def::basename) + Def::OpDef::name + "WithMainInputGradient",
// implemented by Def::WithMainInputBackwardOp, where Def is the macro argument.
// Its schema takes Def::WithMainInputBackwardOp::kNumInputs inputs and
// allows one or more outputs (NumOutputs(1, INT_MAX)).
// Only this operator and its schema are registered. There is no forward op
// and no gradient maker, unlike REGISTER_SEGMENT_DEF.
// Variadic so template-ids containing commas pass through as one argument.
// The expansion has no trailing semicolon; call sites add it.
#define REGISTER_GRADIENT_WITH_MAIN_INPUT(...) \
REGISTER_CPU_OPERATOR_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \
"WithMainInputGradient", \
__VA_ARGS__::WithMainInputBackwardOp); \
OPERATOR_SCHEMA_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \
"WithMainInputGradient") \
.NumInputs(__VA_ARGS__::WithMainInputBackwardOp::kNumInputs) \
.NumOutputs(1, INT_MAX)
REGISTER_GRADIENT_WITH_MAIN_INPUT(
AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
REGISTER_GRADIENT_WITH_MAIN_INPUT(
AbstractSparseLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
}
}