blob: 275ddf683aa43b05d6201920906c6907933269d4 [file] [log] [blame]
#include "lengths_reducer_ops.h"
#include "caffe2/operators/segment_reduction_op.h"
namespace caffe2 {
// Use _STR option because the schema is declared using _STR version too in
// generic fashion. Otherwise it'd break schema declaration check.
// TODO(dzhulgakov): remove _STR when all lengths ops are off generic version.
//
// Aliases for the fake-fp16 sparse-lengths reduction variants. The integer
// template flags after the TensorTypes list select the variant; judging by
// the alias names they appear to mean, in order:
//   with_weights, is_mean, <third flag>, acc_fp16, embedding_only
// (confirm against the SparseLengthsReductionFakeFp16Op declaration).
using SparseLengthsSumOp =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 0, 0>;
using SparseLengthsWeightedSumOp =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 1, 0>;
using SparseLengthsMeanOp =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 0, 1>;
// Variants with the fourth (acc_fp16) flag set.
using SparseLengthsSumAccFP16Op =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 0, 0, 0, 1>;
using SparseLengthsWeightedSumAccFP16Op =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 1, 0, 0, 1>;
using SparseLengthsMeanAccFP16Op =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 0, 1, 0, 1>;
// Variants with the fifth (embedding_only) flag set — presumably these apply
// the fake-fp16 conversion to the embedding table only.
using SparseLengthsSumFakeFP16EmbeddingOnlyOp =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 0, 0, 0, 0, 1>;
using SparseLengthsWeightedSumFakeFP16EmbeddingOnlyOp =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 1, 0, 0, 0, 1>;
using SparseLengthsMeanFakeFP16EmbeddingOnlyOp =
    SparseLengthsReductionFakeFp16Op<TensorTypes<float, at::Half>, 0, 1, 0, 0, 1>;
// Register the CPU implementations for every fake-fp16 variant declared above.
REGISTER_CPU_OPERATOR(SparseLengthsSumFakeFP16, SparseLengthsSumOp);
REGISTER_CPU_OPERATOR(SparseLengthsWeightedSumFakeFP16, SparseLengthsWeightedSumOp);
REGISTER_CPU_OPERATOR(SparseLengthsMeanFakeFP16, SparseLengthsMeanOp);
REGISTER_CPU_OPERATOR(SparseLengthsSumFakeFP16AccFP16, SparseLengthsSumAccFP16Op);
REGISTER_CPU_OPERATOR(
    SparseLengthsWeightedSumFakeFP16AccFP16,
    SparseLengthsWeightedSumAccFP16Op);
REGISTER_CPU_OPERATOR(SparseLengthsMeanFakeFP16AccFP16, SparseLengthsMeanAccFP16Op);
REGISTER_CPU_OPERATOR(
    SparseLengthsSumFakeFP16EmbeddingOnly,
    SparseLengthsSumFakeFP16EmbeddingOnlyOp);
REGISTER_CPU_OPERATOR(
    SparseLengthsWeightedSumFakeFP16EmbeddingOnly,
    SparseLengthsWeightedSumFakeFP16EmbeddingOnlyOp);
REGISTER_CPU_OPERATOR(
    SparseLengthsMeanFakeFP16EmbeddingOnly,
    SparseLengthsMeanFakeFP16EmbeddingOnlyOp);
// Builds the final schema documentation string for an op by expanding the
// "{op}" and "{op_doc}" placeholders in Def::doc with the wrapped op's name
// and doc. The trailing enforce asserts that the template contains no
// "{extra}" placeholder (c10::ReplaceAll reports how many substitutions it
// performed; here it must be zero).
template <typename Def>
string FormatDoc() {
  string docStr = Def::doc;
  c10::ReplaceAll(docStr, "{op}", Def::OpDef::name);
  c10::ReplaceAll(docStr, "{op_doc}", Def::OpDef::doc);
  const auto extraCount = c10::ReplaceAll(docStr, "{extra}", "");
  CAFFE_ENFORCE_EQ(extraCount, 0);
  return docStr;
}
// ---------------------------------------------------------------------------
// Base fake-fp16 schemas. These reuse the generic AbstractSparseLengthsDef
// descriptors so the fake-fp16 ops expose the same input/output interface and
// documentation as the regular float SparseLengthsSum/WeightedSum/Mean ops.
// ---------------------------------------------------------------------------
using SparseLengthsSumDef = AbstractSparseLengthsDef<
float,
int,
CPUContext,
SumReducerDef,
true /*GradientNeedIndices*/>;
OPERATOR_SCHEMA(SparseLengthsSumFakeFP16)
// Input count comes from the regular forward op's definition.
.NumInputs(SparseLengthsSumDef::ForwardOp::kNumInputs)
.NumOutputs(1)
// Input fillers keyed by the op's DATA/INDICES/LENGTHS input indices —
// presumably used for test/benchmark input generation; confirm with callers.
.ValueKeyLengthInputFillers(
SparseLengthsSumOp::DATA,
SparseLengthsSumOp::INDICES,
SparseLengthsSumOp::LENGTHS)
.SetDoc(FormatDoc<SparseLengthsSumDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsSumDef::PopulateSchema)
.InheritOnnxSchema();
// Forward-only fake op: no gradient is registered.
NO_GRADIENT(SparseLengthsSumFakeFP16);
using SparseLengthsWeightedSumDef = AbstractSparseLengthsDef<
float,
int,
CPUContext,
WeightedSumReducerDef,
true /*GradientNeedIndices*/>;
OPERATOR_SCHEMA(SparseLengthsWeightedSumFakeFP16)
.NumInputs(SparseLengthsWeightedSumDef::ForwardOp::kNumInputs)
.NumOutputs(1)
// Weighted variant additionally fills the WEIGHT input.
.WeightedValueKeyLengthInputFillers(
SparseLengthsWeightedSumOp::DATA,
SparseLengthsWeightedSumOp::INDICES,
SparseLengthsWeightedSumOp::LENGTHS,
SparseLengthsWeightedSumOp::WEIGHT)
.SetDoc(FormatDoc<SparseLengthsWeightedSumDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsWeightedSumDef::PopulateSchema)
.InheritOnnxSchema();
NO_GRADIENT(SparseLengthsWeightedSumFakeFP16);
using SparseLengthsMeanDef = AbstractSparseLengthsDef<
float,
int,
CPUContext,
MeanReducerDef,
true /*GradientNeedIndices*/>;
// NOTE(review): unlike the Sum/WeightedSum schemas above, none of the Mean
// schemas in this file call .InheritOnnxSchema() — looks intentional (no ONNX
// counterpart to inherit), but confirm.
OPERATOR_SCHEMA(SparseLengthsMeanFakeFP16)
.NumInputs(SparseLengthsMeanDef::ForwardOp::kNumInputs)
.NumOutputs(1)
.ValueKeyLengthInputFillers(
SparseLengthsMeanOp::DATA,
SparseLengthsMeanOp::INDICES,
SparseLengthsMeanOp::LENGTHS)
.SetDoc(FormatDoc<SparseLengthsMeanDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsMeanDef::PopulateSchema);
NO_GRADIENT(SparseLengthsMeanFakeFP16);
// ---------------------------------------------------------------------------
// AccFP16 variants. Schemas are identical to the base fake-fp16 ops above
// (same Defs, same fillers, same docs); only the registered implementation
// differs (the *AccFP16Op aliases with the acc flag set).
// ---------------------------------------------------------------------------
OPERATOR_SCHEMA(SparseLengthsSumFakeFP16AccFP16)
.NumInputs(SparseLengthsSumDef::ForwardOp::kNumInputs)
.NumOutputs(1)
.ValueKeyLengthInputFillers(
SparseLengthsSumOp::DATA,
SparseLengthsSumOp::INDICES,
SparseLengthsSumOp::LENGTHS)
.SetDoc(FormatDoc<SparseLengthsSumDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsSumDef::PopulateSchema)
.InheritOnnxSchema();
// Forward-only fake op: no gradient is registered.
NO_GRADIENT(SparseLengthsSumFakeFP16AccFP16);
OPERATOR_SCHEMA(SparseLengthsWeightedSumFakeFP16AccFP16)
.NumInputs(SparseLengthsWeightedSumDef::ForwardOp::kNumInputs)
.NumOutputs(1)
// Weighted variant additionally fills the WEIGHT input.
.WeightedValueKeyLengthInputFillers(
SparseLengthsWeightedSumOp::DATA,
SparseLengthsWeightedSumOp::INDICES,
SparseLengthsWeightedSumOp::LENGTHS,
SparseLengthsWeightedSumOp::WEIGHT)
.SetDoc(FormatDoc<SparseLengthsWeightedSumDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsWeightedSumDef::PopulateSchema)
.InheritOnnxSchema();
NO_GRADIENT(SparseLengthsWeightedSumFakeFP16AccFP16);
// Mean variant: no .InheritOnnxSchema(), consistent with the other Mean
// schemas in this file.
OPERATOR_SCHEMA(SparseLengthsMeanFakeFP16AccFP16)
.NumInputs(SparseLengthsMeanDef::ForwardOp::kNumInputs)
.NumOutputs(1)
.ValueKeyLengthInputFillers(
SparseLengthsMeanOp::DATA,
SparseLengthsMeanOp::INDICES,
SparseLengthsMeanOp::LENGTHS)
.SetDoc(FormatDoc<SparseLengthsMeanDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsMeanDef::PopulateSchema);
NO_GRADIENT(SparseLengthsMeanFakeFP16AccFP16);
// ---------------------------------------------------------------------------
// EmbeddingOnly variants. NumInputs/doc/schema population still come from the
// shared base Defs; the input-filler indices, however, are taken from the
// *EmbeddingOnlyOp aliases (their DATA/INDICES/LENGTHS enum values).
// ---------------------------------------------------------------------------
OPERATOR_SCHEMA(SparseLengthsSumFakeFP16EmbeddingOnly)
.NumInputs(SparseLengthsSumDef::ForwardOp::kNumInputs)
.NumOutputs(1)
.ValueKeyLengthInputFillers(
SparseLengthsSumFakeFP16EmbeddingOnlyOp::DATA,
SparseLengthsSumFakeFP16EmbeddingOnlyOp::INDICES,
SparseLengthsSumFakeFP16EmbeddingOnlyOp::LENGTHS)
.SetDoc(FormatDoc<SparseLengthsSumDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsSumDef::PopulateSchema)
.InheritOnnxSchema();
// Forward-only fake op: no gradient is registered.
NO_GRADIENT(SparseLengthsSumFakeFP16EmbeddingOnly);
OPERATOR_SCHEMA(SparseLengthsWeightedSumFakeFP16EmbeddingOnly)
.NumInputs(SparseLengthsWeightedSumDef::ForwardOp::kNumInputs)
.NumOutputs(1)
// Weighted variant additionally fills the WEIGHT input.
.WeightedValueKeyLengthInputFillers(
SparseLengthsWeightedSumFakeFP16EmbeddingOnlyOp::DATA,
SparseLengthsWeightedSumFakeFP16EmbeddingOnlyOp::INDICES,
SparseLengthsWeightedSumFakeFP16EmbeddingOnlyOp::LENGTHS,
SparseLengthsWeightedSumFakeFP16EmbeddingOnlyOp::WEIGHT)
.SetDoc(FormatDoc<SparseLengthsWeightedSumDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsWeightedSumDef::PopulateSchema)
.InheritOnnxSchema();
NO_GRADIENT(SparseLengthsWeightedSumFakeFP16EmbeddingOnly);
// Mean variant: no .InheritOnnxSchema(), consistent with the other Mean
// schemas in this file.
OPERATOR_SCHEMA(SparseLengthsMeanFakeFP16EmbeddingOnly)
.NumInputs(SparseLengthsMeanDef::ForwardOp::kNumInputs)
.NumOutputs(1)
.ValueKeyLengthInputFillers(
SparseLengthsMeanFakeFP16EmbeddingOnlyOp::DATA,
SparseLengthsMeanFakeFP16EmbeddingOnlyOp::INDICES,
SparseLengthsMeanFakeFP16EmbeddingOnlyOp::LENGTHS)
.SetDoc(FormatDoc<SparseLengthsMeanDef>())
.Output(0, "OUTPUT", "Aggregated tensor")
.FillUsing(SparseLengthsMeanDef::PopulateSchema);
NO_GRADIENT(SparseLengthsMeanFakeFP16EmbeddingOnly);
} // namespace caffe2