blob: b4f4adb3c6226f74c9d4fe3cfbd5c94f69ced2a8 [file] [log] [blame]
#include "caffe2/operators/concat_split_op.h"
namespace caffe2 {
namespace {
// CPU kernels for the Split and Concat operators.
REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);

// Split: one required input (the tensor to split) plus an optional second
// input; produces one or more output tensors.
OPERATOR_SCHEMA(Split)
    .NumInputs(1, 2)
    .NumOutputs(1, INT_MAX)
    .Arg("axis", "Which axis to split on")
    .Arg("order", "Either NHWC or NCHW, will split on C axis")
    .SetDoc("Split a tensor into a list of tensors.");

// Concat: one or more inputs; always declares exactly two outputs. Output 0
// is the concatenated tensor; output 1 is the auxiliary output that
// GetConcatGradient feeds back into Split as its second input.
OPERATOR_SCHEMA(Concat)
    .NumInputs(1, INT_MAX)
    .NumOutputs(2)
    .Arg("axis", "Which axis to concat on")
    .Arg("order", "Either NHWC or NCHW, will concat on C axis")
    .SetDoc("Concatenate a list of tensors into a single tensor.");
// Backward compatibility names. DepthSplit / DepthConcat are registered with
// the very same CPU kernels as Split / Concat (SplitOp / ConcatOp), so their
// schemas mirror the arity and argument documentation of the primary ops.
REGISTER_CPU_OPERATOR(DepthSplit, SplitOp<CPUContext>);
REGISTER_CPU_OPERATOR(DepthConcat, ConcatOp<CPUContext>);
OPERATOR_SCHEMA(DepthSplit)
    .NumInputs(1, 2)
    .NumOutputs(1, INT_MAX)
    .Arg("axis", "Which axis to split on")
    .Arg("order", "Either NHWC or NCHW, will split on C axis")
    .SetDoc("Backward compatible operator name for Split.");
OPERATOR_SCHEMA(DepthConcat)
    .NumInputs(1, INT_MAX)
    .NumOutputs(2)
    .Arg("axis", "Which axis to concat on")
    .Arg("order", "Either NHWC or NCHW, will concat on C axis")
    .SetDoc("Backward compatible operator name for Concat.");
// Gradient maker shared by Split and DepthSplit.
//
// A split is undone by a concatenation. This maker therefore emits a single
// Concat op. The op joins every output gradient GO(i) that is present and
// writes the result to the input gradient GI(0). Concat declares two outputs,
// so its second output goes to a scratch blob named "_<GI(0)>_dims". No
// gradient is produced for Split's optional second input.
//
// NOTE(review): outputs whose gradient is empty are skipped, so Concat only
// sees the remaining pieces. If some (but not all) output gradients are
// missing, GI(0) would be shorter than the forward input along the split
// axis. Confirm callers never hit that case, or zero-fill the gaps.
class GetSplitGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    vector<string> present_grads;
    const int num_outputs = def_.output_size();
    for (int idx = 0; idx < num_outputs; ++idx) {
      if (GradOut(idx).IsEmpty()) {
        continue;  // this output received no gradient
      }
      present_grads.push_back(GO(idx));
    }
    // No gradient flowed into any output: no gradient op is needed.
    if (present_grads.empty()) {
      return {};
    }
    const string input_grad = GI(0);
    const string scratch_blob = "_" + input_grad + "_dims";
    return SingleGradientDef(
        "Concat", "", present_grads,
        vector<string>{input_grad, scratch_blob});
  }
};
REGISTER_GRADIENT(Split, GetSplitGradient);
REGISTER_GRADIENT(DepthSplit, GetSplitGradient);
// Gradient maker shared by Concat and DepthConcat.
//
// Emits a single Split op. The op cuts the output gradient GO(0) back into one
// piece per forward input and writes piece i to GI(i). The forward op's second
// output O(1) is passed to Split as its second input. Presumably it records the
// per-input extents needed to cut correctly; confirm against SplitOp/ConcatOp.
class GetConcatGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    // Without a gradient on the concatenated output there is nothing to split.
    if (GradOut(0).IsEmpty()) {
      return {};
    }
    const int num_inputs = def_.input_size();
    vector<string> input_grads;
    input_grads.reserve(num_inputs);
    for (int idx = 0; idx < num_inputs; ++idx) {
      input_grads.push_back(GI(idx));
    }
    const vector<string> split_inputs{GO(0), O(1)};
    return SingleGradientDef("Split", "", split_inputs, input_grads);
  }
};
REGISTER_GRADIENT(Concat, GetConcatGradient);
REGISTER_GRADIENT(DepthConcat, GetConcatGradient);
} // namespace
} // namespace caffe2