blob: cdc8d83720e4ccbb80560f75f115377444930b82 [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_
#define CAFFE2_OPERATORS_FLATTEN_OP_H_
#include "caffe2/core/operator.h"
namespace caffe2 {
/**
 * FlattenOp: reshapes its single input into a 2-D output and copies the data.
 *
 * The output shape is {input.size_to_dim(axis), input.size_from_dim(axis)}:
 * per the caffe2 Tensor API, the leading dimensions [0, axis) are collapsed
 * into the first output dimension and the trailing dimensions [axis, rank)
 * into the second. Element data is copied unchanged (same dtype, same
 * element order) on the operator's device.
 *
 * Operator arguments:
 *   axis (int, default 1): split point; must satisfy 0 <= axis <= rank.
 *
 * Enforce failures (CAFFE_ENFORCE) are raised when axis is negative or
 * greater than the input rank.
 */
template <class Context>
class FlattenOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Forwards all construction arguments to Operator<Context> and reads the
  // "axis" argument (default 1). std::forward comes from <utility>, which is
  // assumed to be reachable through caffe2/core/operator.h.
  template <class... Args>
  explicit FlattenOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        axis_(this->template GetSingleArgument<int>("axis", 1)) {}

  /**
   * Resizes Output(0) to the flattened 2-D shape and copies Input(0) into it.
   * @return true on success; invalid `axis` values fail via CAFFE_ENFORCE.
   */
  bool RunOnDevice() override {
    const auto& input = Input(0);
    auto* output = Output(0);
    // A negative axis would trivially pass the rank check below and then be
    // passed to size_to_dim()/size_from_dim(), which are not defined for
    // negative indices. Reject it here with an actionable message.
    CAFFE_ENFORCE_GE(
        axis_, 0, "Flatten axis must be non-negative, got ", axis_);
    // axis == rank is allowed: the result is {numel, 1}.
    CAFFE_ENFORCE_GE(
        input.dim(), axis_, "The rank of the tensor must be >= axis.");
    output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
    // The data layout is unchanged by flattening, so a raw element-wise copy
    // of the whole buffer is sufficient.
    context_.CopyItemsSameDevice(
        input.dtype(),
        input.numel(),
        input.raw_data(),
        output->raw_mutable_data(input.dtype()));
    return true;
  }

 private:
  // Split point between the leading and trailing dimensions.
  int axis_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_FLATTEN_OP_H_