blob: 65044e828786ba63e3e7220c42ae5f481a1a4339 [file] [log] [blame]
#include "caffe2/operators/concat_split_op.h"
namespace caffe2 {
namespace {
// Register the CPU-context instantiations of the Split and Concat operators
// under their primary names.
REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);
// Schema for Split: takes one or two inputs and produces one or more outputs.
// The axis to split along is described by the "axis" and "order" arguments.
OPERATOR_SCHEMA(Split)
    .NumInputs(1, 2)
    .NumOutputs(1, INT_MAX)
    .Arg("axis", "Which axis to split on")
    // The channel-first storage order is spelled NCHW (was mistyped "NCWH").
    .Arg("order", "Either NHWC or NCHW, will split on C axis")
    .SetDoc("Split a tensor into a list of tensors.");
// Schema for Concat: takes one or more inputs and always produces two
// outputs, the concatenated tensor and a "split_info" blob describing the
// dimensions of the inputs.
OPERATOR_SCHEMA(Concat)
    .NumInputs(1, INT_MAX)
    .NumOutputs(2)
    .Arg("axis", "Which axis to concat on")
    // The channel-first storage order is spelled NCHW (was mistyped "HCWH").
    .Arg("order", "Either NHWC or NCHW, will concat on C axis")
    .SetDoc("Concatenate a list of tensors into a single tensor")
    .Output(0, "concat_result", "Concatenated tensor")
    .Output(1, "split_info", "The dimensions of the inputs.");
// Backward compatibility names: DepthSplit and DepthConcat are registered
// with the same CPU implementations as Split and Concat respectively.
REGISTER_CPU_OPERATOR(DepthSplit, SplitOp<CPUContext>);
REGISTER_CPU_OPERATOR(DepthConcat, ConcatOp<CPUContext>);
// Schema for the legacy DepthSplit name; input/output counts match the Split
// schema (1-2 inputs, at least one output). No arguments are documented here.
OPERATOR_SCHEMA(DepthSplit)
.NumInputs(1, 2)
.NumOutputs(1, INT_MAX)
.SetDoc("Backward compatible operator name for Split.");
// Schema for the legacy DepthConcat name; input/output counts match the
// Concat schema (at least one input, exactly two outputs). No arguments or
// outputs are documented here.
OPERATOR_SCHEMA(DepthConcat)
.NumInputs(1, INT_MAX)
.NumOutputs(2)
.SetDoc("Backward compatible operator name for Concat.");
// Gradient maker for Split (also registered for DepthSplit).
//
// Emits a single "Concat" op whose inputs are the gradients of those Split
// outputs that actually have a gradient, and whose outputs are the gradient
// of Split's first input plus an auxiliary blob named "_<input grad>_dims"
// (Concat always produces two outputs). If no output has a gradient, no
// gradient op is emitted.
//
// NOTE(review): outputs without a gradient are skipped rather than
// zero-filled, so the Concat only covers a subset when some are missing --
// confirm this is the intended behavior.
class GetSplitGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    // Gather the gradient blob names of every output that carries one.
    vector<string> available_grads;
    const int num_outputs = def_.output_size();
    for (int idx = 0; idx != num_outputs; ++idx) {
      if (GradOut(idx).IsEmpty()) {
        continue;
      }
      available_grads.push_back(GO(idx));
    }

    // Nothing flows back, so there is nothing to compute.
    if (available_grads.empty()) {
      return {};
    }

    // Outputs of the Concat: the input gradient and its companion dims blob.
    vector<string> concat_outputs;
    concat_outputs.push_back(GI(0));
    concat_outputs.push_back("_" + GI(0) + "_dims");

    return SingleGradientDef("Concat", "", available_grads, concat_outputs);
  }
};
// Use GetSplitGradient as the gradient maker for both Split and its legacy
// alias DepthSplit.
REGISTER_GRADIENT(Split, GetSplitGradient);
REGISTER_GRADIENT(DepthSplit, GetSplitGradient);
// Gradient maker for Concat (also registered for DepthConcat).
//
// Emits a single "Split" op that takes the gradient of the concatenated
// output together with Concat's second forward output (the "split_info"
// blob) and produces one gradient per Concat input. If the concatenated
// output has no gradient, no gradient op is emitted.
class GetConcatGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  vector<OperatorDef> GetGradientDefs() override {
    // Without a gradient on the concatenated result there is nothing to do.
    if (GradOut(0).IsEmpty()) {
      return {};
    }

    // One gradient blob per forward input, in input order.
    vector<string> input_grads;
    const int num_inputs = def_.input_size();
    for (int idx = 0; idx != num_inputs; ++idx) {
      input_grads.push_back(GI(idx));
    }

    // Inputs of the Split: the output gradient and the recorded split info.
    vector<string> split_inputs;
    split_inputs.push_back(GO(0));
    split_inputs.push_back(O(1));

    return SingleGradientDef("Split", "", split_inputs, input_grads);
  }
};
// Use GetConcatGradient as the gradient maker for both Concat and its legacy
// alias DepthConcat.
REGISTER_GRADIENT(Concat, GetConcatGradient);
REGISTER_GRADIENT(DepthConcat, GetConcatGradient);
} // namespace
} // namespace caffe2