blob: 8e7ede80c42ef7ce852d9d11c4771da646a4e2ab [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
#define CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
#include <numeric>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace {
// Maps a storage-order name to the index of the 'C' dimension in that
// layout: 3 for NHWC and 1 for NCHW. Any other order is reported through
// CAFFE_THROW together with the offending string.
inline int GetDimFromOrderString(const string& str) {
  const auto storage_order = StringToStorageOrder(str);
  if (storage_order == StorageOrder::NHWC) {
    return 3;
  }
  if (storage_order == StorageOrder::NCHW) {
    return 1;
  }
  CAFFE_THROW("Unsupported storage order: ", str);
  // Trailing return retained so every control path yields a value.
  return -1;
}
} // namespace
// Operator that splits its first input along one axis into OutputSize()
// tensors.
//
// Arguments read by the constructor:
//   "axis"  - index of the dimension to split, or
//   "order" - storage order string, translated to an axis by
//             GetDimFromOrderString. Exactly one of "axis"/"order" must be
//             present.
//   "split" - optional repeated int with the size of each output along the
//             split axis.
//
// Inputs: X, and optionally a second tensor holding the split sizes, which
// RunOnDevice reads as a CPU tensor.
template <class Context>
class SplitOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  SplitOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        split_(OperatorBase::GetRepeatedArgument<int>("split")) {
    CAFFE_ENFORCE(
        OperatorBase::HasArgument("axis") ^ OperatorBase::HasArgument("order"),
        "You should either specify the dim to split, or the order "
        "in the case of 4-D images.");
    // The enforce above guarantees that "order" is present whenever "axis"
    // is absent.
    axis_ = OperatorBase::HasArgument("axis")
        ? OperatorBase::GetSingleArgument<int>("axis", -1)
        : GetDimFromOrderString(
              OperatorBase::GetSingleArgument<string>("order", ""));
    CAFFE_ENFORCE_GE(axis_, 0);
  }

  bool RunOnDevice() override;

 protected:
  // Dimension along which the input is split; enforced to be >= 0.
  int axis_;
  // Per-output sizes taken from the "split" argument; may be empty.
  vector<int> split_;
};
// Operator that concatenates all of its inputs along one axis.
//
// Arguments read by the constructor:
//   "axis"  - index of the dimension to concatenate along, or
//   "order" - storage order string, translated to an axis by
//             GetDimFromOrderString. Exactly one of "axis"/"order" must be
//             present.
//
// Inputs: any number of tensors.
// Outputs: Y (the concatenated tensor) and split, a CPU tensor in which
// RunOnDevice records the size of every input along the concat axis.
template <class Context>
class ConcatOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  ConcatOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws) {
    // The message names the concat dimension; this operator concatenates
    // rather than splits.
    CAFFE_ENFORCE(
        OperatorBase::HasArgument("axis") ^ OperatorBase::HasArgument("order"),
        "You should either specify the dim to concat, or the order "
        "in the case of 4-D images.");
    if (OperatorBase::HasArgument("axis")) {
      axis_ = OperatorBase::GetSingleArgument<int>("axis", -1);
    } else {
      axis_ = GetDimFromOrderString(
          OperatorBase::GetSingleArgument<string>("order", ""));
    }
    CAFFE_ENFORCE_GE(axis_, 0);
  }

  bool RunOnDevice() override;

 protected:
  // Dimension along which the inputs are concatenated; enforced to be >= 0.
  int axis_;
};
// Implementations
// Splits Input(0) along axis_ into OutputSize() output tensors.
//
// The size of each output along axis_ comes from one of three places:
//   1. a second input tensor, read as a CPU tensor of int (the "split"
//      argument must then be empty);
//   2. an equal division of the axis when neither the second input nor the
//      "split" argument is given (the axis size must be divisible by the
//      number of outputs);
//   3. the "split" argument, which must have one entry per output.
// Every size must be non-negative and the sizes must sum to the input's
// size along axis_. Always returns true; violations are reported through
// the CAFFE_ENFORCE_* macros.
template <class Context>
bool SplitOp<Context>::RunOnDevice() {
  auto& input = Input(0);
  CAFFE_ENFORCE_LT(axis_, input.ndim(), "Axis not in input ndim range.");
  const int input_channels = input.dim32(axis_);
  const int* axis_data;
  vector<int> equal_split;
  if (InputSize() == 2) {
    // We obtain split from the input tensor.
    CAFFE_ENFORCE_EQ(
        split_.size(),
        0,
        "If you set split with an input blob, do not pass in "
        "split in the argument.");
    auto& split_tensor = OperatorBase::Input<TensorCPU>(1);
    CAFFE_ENFORCE_EQ(split_tensor.size(), OutputSize());
    axis_data = split_tensor.template data<int>();
  } else if (split_.size() == 0) {
    CAFFE_ENFORCE_EQ(
        input_channels % OutputSize(),
        0,
        "If you did not specify split explicitly, the number of "
        "input channels should be divisible by the output size.");
    equal_split.resize(OutputSize(), input_channels / OutputSize());
    axis_data = equal_split.data();
  } else {
    // We obtain split from the parameters.
    CAFFE_ENFORCE_EQ(
        split_.size(),
        OutputSize(),
        "The number of splits specified should be equal to the "
        "number of outputs.");
    axis_data = split_.data();
  }
  // A negative entry can still satisfy the sum check below (e.g. {-1, 5}
  // for an axis of size 4), so reject it explicitly before it is used as an
  // output dimension and copy width.
  for (int i = 0; i < OutputSize(); ++i) {
    CAFFE_ENFORCE_GE(
        axis_data[i],
        0,
        "Split sizes must be non-negative, got ",
        axis_data[i],
        " for output: ",
        i);
  }
  CAFFE_ENFORCE_EQ(
      std::accumulate(axis_data, axis_data + OutputSize(), 0),
      input_channels,
      "Sum of split dimensions do not match: should be ",
      input_channels);
  vector<TIndex> output_dims(input.dims());
  // before = product of dims preceding axis_, after = product of dims
  // following it.
  int before = 1, after = 1;
  for (int i = 0; i < axis_; ++i) {
    before *= input.dim32(i);
  }
  for (int i = axis_ + 1; i < input.ndim(); ++i) {
    after *= input.dim32(i);
  }
  // Number of elements in one input row; identical for every output.
  const int input_row_width = input_channels * after;
  // Byte offset of the current output's slice inside an input row. Kept as
  // size_t so the running byte count is not narrowed back to int.
  size_t input_offset = 0;
  for (int i = 0; i < OutputSize(); ++i) {
    auto* output = Output(i);
    output_dims[axis_] = axis_data[i];
    output->Resize(output_dims);
    const int output_row_width = axis_data[i] * after;
    // Copy `before` rows of `output_row_width` elements; the source rows
    // are `input_row_width` elements apart.
    math::CopyMatrix<Context>(
        input.itemsize(),
        before,
        output_row_width,
        static_cast<const char*>(input.raw_data()) + input_offset,
        input_row_width,
        output->raw_mutable_data(input.meta()),
        output_row_width,
        &context_);
    input_offset += static_cast<size_t>(output_row_width) * input.itemsize();
  }
  return true;
}
// Concatenates all inputs along axis_ into Output(0) and writes each
// input's size along axis_ into Output(1), a 1-D CPU tensor of int with
// InputSize() entries.
//
// All inputs must share the same element type and the same number of
// dimensions, and may differ in size only along axis_. Always returns true;
// violations are reported through the CAFFE_ENFORCE* macros.
template <class Context>
bool ConcatOp<Context>::RunOnDevice() {
  auto* output = Output(0);
  TensorCPU* split = OperatorBase::Output<TensorCPU>(1);
  split->Resize(vector<TIndex>(1, InputSize()));
  int* axis_data = split->template mutable_data<int>();
  auto& input_zero = Input(0);
  CAFFE_ENFORCE_LT(axis_, input_zero.ndim(), "Axis not in input ndim range.");
  for (int i = 1; i < InputSize(); ++i) {
    CAFFE_ENFORCE(
        Input(i).meta() == input_zero.meta(),
        "All inputs must have the same type, expected: ",
        input_zero.meta().name(),
        " but got: ",
        Input(i).meta().name(),
        " for input: ",
        i);
    // The loops below index every input with dimension indices derived from
    // input_zero (including axis_). Without this check an input of lower
    // rank would be indexed past its dimensions, and an input of higher
    // rank would have its extra dimensions silently ignored.
    CAFFE_ENFORCE(
        Input(i).ndim() == input_zero.ndim(),
        "All inputs must have the same number of dimensions, expected: ",
        input_zero.ndim(),
        " but got: ",
        Input(i).ndim(),
        " for input: ",
        i);
  }
  // before = product of dims preceding axis_, after = product of dims
  // following it (taken from input_zero; the other inputs must agree).
  int before = 1, after = 1;
  vector<TIndex> output_dims(input_zero.dims());
  for (int i = 0; i < input_zero.ndim(); ++i) {
    if (i == axis_) {
      continue;
    }
    int dim = input_zero.dim32(i);
    if (i < axis_) {
      before *= dim;
    } else { // i > axis_
      after *= dim;
    }
    // check the input dims are compatible.
    for (int j = 1; j < InputSize(); ++j) {
      int dim_j = Input(j).dim32(i);
      CAFFE_ENFORCE(
          dim == dim_j,
          "Expect dimension = ",
          dim,
          " got ",
          dim_j,
          " at axis = ",
          i,
          " for input: ",
          j,
          ". The input tensors can only have different dimensions "
          "along the axis = ",
          axis_,
          " <",
          Input(0).dims(),
          "> vs <",
          Input(j).dims(),
          ">.");
    }
  }
  // Record each input's extent along axis_ and accumulate the output extent.
  int output_channels = 0;
  for (int i = 0; i < InputSize(); ++i) {
    axis_data[i] = Input(i).dim32(axis_);
    output_channels += axis_data[i];
  }
  output_dims[axis_] = output_channels;
  output->Resize(output_dims);
  // Destination pointer and row width are the same for every input, so
  // obtain them once.
  char* output_data =
      static_cast<char*>(output->raw_mutable_data(input_zero.meta()));
  const int output_row_width = output_channels * after;
  // Byte offset of the current input's slice inside an output row. Kept as
  // size_t so the running byte count is not narrowed back to int.
  size_t output_offset = 0;
  for (int i = 0; i < InputSize(); ++i) {
    auto& input = Input(i);
    const int input_row_width = input.dim32(axis_) * after;
    // Copy `before` rows of `input_row_width` elements; the destination
    // rows are `output_row_width` elements apart.
    math::CopyMatrix<Context>(
        input.itemsize(),
        before,
        input_row_width,
        input.raw_data(),
        input_row_width,
        output_data + output_offset,
        output_row_width,
        &context_);
    output_offset += static_cast<size_t>(input_row_width) * input.itemsize();
  }
  return true;
}
} // namespace caffe2
#endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_