blob: 838426c341411dbff288c9c4b20b5292cb22539e [file] [log] [blame]
#include "caffe2/operators/segment_reduction_op.h"
namespace caffe2 {
// Registering the 5-input, 2-output gradient that also takes the forward
// op's main (data) input — see AbstractLengthsWithMainInputGradientOp.
// (Previous comment said "4 input"; the schema declares NumInputs(5).)
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient)
.NumInputs(5)
.NumOutputs(2);
REGISTER_CPU_OPERATOR(
SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient,
AbstractLengthsWithMainInputGradientOp<
float,
int,
CPUContext,
WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
true /*SparseFused*/,
true /*GradientNeedIndices*/>);
// Registering the 4-input version (WeightedSum gradient without main input).
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumGradient)
.NumInputs(4)
.NumOutputs(1);
REGISTER_CPU_OPERATOR(
SparseLengthsIndicesInGradientWeightedSumGradient,
AbstractLengthsGradientOp<
float,
int,
CPUContext,
WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
true /*GradientNeedIndices*/>);
// Registering the 3-input version (Sum gradient, sparse variant).
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientSumGradient)
.NumInputs(3)
.NumOutputs(1);
REGISTER_CPU_OPERATOR(
SparseLengthsIndicesInGradientSumGradient,
AbstractLengthsGradientOp<
float,
int,
CPUContext,
SumReducerDef::template ReducerGradient<float, CPUContext>,
true /*GradientNeedIndices*/>);
// Dense (non-sparse) Lengths Sum gradient; same backward op as above.
OPERATOR_SCHEMA(LengthsIndicesInGradientSumGradient).NumInputs(3).NumOutputs(1);
REGISTER_CPU_OPERATOR(
LengthsIndicesInGradientSumGradient,
AbstractLengthsGradientOp<
float,
int,
CPUContext,
SumReducerDef::template ReducerGradient<float, CPUContext>,
true /*GradientNeedIndices*/>);
namespace {
// Instantiates a segment-reduction definition's documentation template:
// substitutes the "{op}" and "{op_doc}" placeholders with the concrete
// operator's name and per-op documentation, and returns the result.
template <typename Def>
string FormatDoc() {
  string formatted = Def::doc;
  ReplaceAll(formatted, "{op}", Def::OpDef::name);
  ReplaceAll(formatted, "{op_doc}", Def::OpDef::doc);
  return formatted;
}
// Helper macro when the main op is defined elsewhere, and we only need to
// define the schema, and the gradient op.
//
// For a given Def type this expands to:
//   1. the forward op's schema ("<basename><name>"): input count from
//      ForwardOp::kNumInputs, one output, doc from FormatDoc<Def>(), plus
//      Def-specific schema filled in by Def::PopulateSchema;
//   2. registration of the CPU gradient operator "<basename><name>Gradient";
//   3. that gradient op's schema (BackwardOp::kNumInputs inputs, 1 output);
//   4. registration of the gradient factory Def::GetGradient.
// __VA_ARGS__ is used instead of a single named parameter so that template
// arguments containing commas (e.g. AbstractLengthsDef<float, int, ...>)
// can be passed without extra parentheses.
#define REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(...) \
OPERATOR_SCHEMA_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name)) \
.NumInputs(__VA_ARGS__::ForwardOp::kNumInputs) \
.NumOutputs(1) \
.SetDoc(FormatDoc<__VA_ARGS__>()) \
.Output(0, "OUTPUT", "Aggregated tensor") \
.FillUsing(__VA_ARGS__::PopulateSchema); \
REGISTER_CPU_OPERATOR_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + "Gradient", \
__VA_ARGS__::BackwardOp); \
OPERATOR_SCHEMA_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + "Gradient") \
.NumInputs(__VA_ARGS__::BackwardOp::kNumInputs) \
.NumOutputs(1); \
REGISTER_GRADIENT_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \
__VA_ARGS__::GetGradient)
// Full registration: same as the macro above, but also registers the
// forward CPU operator (Def::ForwardOp) itself.
#define REGISTER_SEGMENT_DEF(...) \
REGISTER_CPU_OPERATOR_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name), \
__VA_ARGS__::ForwardOp); \
REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(__VA_ARGS__)
// SortedSegmentRange ops: one full registration per range reducer.
REGISTER_SEGMENT_DEF(
AbstractSortedSegmentRangeDef<float, int, CPUContext, SumRangeReducerDef>);
REGISTER_SEGMENT_DEF(AbstractSortedSegmentRangeDef<
float,
int,
CPUContext,
LogSumExpRangeReducerDef>);
REGISTER_SEGMENT_DEF(AbstractSortedSegmentRangeDef<
float,
int,
CPUContext,
LogMeanExpRangeReducerDef>);
REGISTER_SEGMENT_DEF(
AbstractSortedSegmentRangeDef<float, int, CPUContext, MeanRangeReducerDef>);
REGISTER_SEGMENT_DEF(
AbstractSortedSegmentRangeDef<float, int, CPUContext, MaxRangeReducerDef>);
// Registers the five non-Lengths segment op families (ReduceFront, sorted /
// unsorted segment, and their sparse variants) for a given reducer.
#define REGISTER_REDUCER_WITH_OPS(reducer_def) \
REGISTER_SEGMENT_DEF( \
AbstractReduceFrontDef<float, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractSortedSegmentDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractSparseSortedSegmentDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractUnsortedSegmentDef<float, int, CPUContext, reducer_def>); \
REGISTER_SEGMENT_DEF( \
AbstractSparseUnsortedSegmentDef<float, int, CPUContext, reducer_def>)
// Registers the Lengths-based op family for a given reducer;
// GradientNeedIndices selects whether the gradient also receives indices.
#define REGISTER_REDUCER_WITH_LENGTH_OPS(reducer_def, GradientNeedIndices) \
REGISTER_SEGMENT_DEF(AbstractLengthsDef< \
float, \
int, \
CPUContext, \
reducer_def, \
GradientNeedIndices>)
// Convenience: both families, with GradientNeedIndices defaulting to false.
#define REGISTER_REDUCER_WITH_ALL_OPS(reducer_def) \
REGISTER_REDUCER_WITH_OPS(reducer_def) \
REGISTER_REDUCER_WITH_LENGTH_OPS(reducer_def, false)
REGISTER_REDUCER_WITH_OPS(SumReducerDef);
REGISTER_REDUCER_WITH_LENGTH_OPS(SumReducerDef, true);
REGISTER_REDUCER_WITH_ALL_OPS(WeightedSumReducerDef);
REGISTER_REDUCER_WITH_ALL_OPS(MeanReducerDef);
// SparseLengths[Sum,WeightedSum,Mean] are now implemented separately,
// so we only rely on the historical implementation for the backward + schema.
REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(AbstractSparseLengthsDef<
float,
int,
CPUContext,
SumReducerDef,
true /*GradientNeedIndices*/>)
REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(AbstractSparseLengthsDef<
float,
int,
CPUContext,
WeightedSumReducerDef,
true /*GradientNeedIndices*/>)
REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(
AbstractSparseLengthsDef<float, int, CPUContext, MeanReducerDef>)
// ReduceBack (reduce over trailing dimensions) for Sum and Mean.
REGISTER_SEGMENT_DEF(AbstractReduceBackDef<float, CPUContext, SumReducerDef>);
REGISTER_SEGMENT_DEF(AbstractReduceBackDef<float, CPUContext, MeanReducerDef>);
// Auxiliary output gradients are currently implemented only for Lengths version
//
// Registers the "<basename><name>WithMainInputGradient" CPU operator
// (Def::WithMainInputBackwardOp) and its schema. This gradient variant also
// receives the forward op's main input; it may emit additional auxiliary
// gradient outputs, hence NumOutputs(1, INT_MAX).
#define REGISTER_GRADIENT_WITH_MAIN_INPUT(...) \
REGISTER_CPU_OPERATOR_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \
"WithMainInputGradient", \
__VA_ARGS__::WithMainInputBackwardOp); \
OPERATOR_SCHEMA_STR( \
string(__VA_ARGS__::basename) + (__VA_ARGS__::OpDef::name) + \
"WithMainInputGradient") \
.NumInputs(__VA_ARGS__::WithMainInputBackwardOp::kNumInputs) \
.NumOutputs(1, INT_MAX)
REGISTER_GRADIENT_WITH_MAIN_INPUT(
AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
REGISTER_GRADIENT_WITH_MAIN_INPUT(
AbstractSparseLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
}
}