blob: fe78e30a7d3369d4c977efb3dbea082b5e30d5e8 [file] [log] [blame]
#include <caffe2/ideep/ideep_utils.h>
#include "caffe2/operators/expand_squeeze_dims_op.h"
namespace caffe2 {
class IDEEPSqueezeOp final : public IDEEPOperator {
public:
USE_IDEEP_DEF_ALIASES();
USE_IDEEP_OPERATOR_FUNCTIONS();
IDEEPSqueezeOp(const OperatorDef& operator_def, Workspace* ws)
: IDEEPOperator(operator_def, ws),
dims_(OperatorBase::GetRepeatedArgument<int>("dims")) {
auto originalSize = dims_.size();
CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
std::sort(dims_.begin(), dims_.end());
dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
if (dims_.size() < originalSize) {
LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
}
CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
}
virtual ~IDEEPSqueezeOp() {}
bool RunOnDevice() override {
const auto& X = Input(INPUT);
auto* Y = Output(OUTPUT);
CAFFE_ENFORCE_GT(
X.ndims(),
dims_.back(),
"Input needs at least ",
(dims_.back() + 1),
" dimensions.");
const auto& ideep_dims = X.get_dims();
vector<TIndex> dims(ideep_dims.begin(), ideep_dims.end());
const auto& new_dims = SqueezeOp<IDEEPContext>::ComputeDims(dims, dims_);
itensor::dims new_dims_ideep(new_dims.begin(), new_dims.end());
if (&X != Y) {
// Copy if not inplace
ideep::direct_copy::compute(X, *Y);
}
Y->reshape(new_dims_ideep);
return true;
}
private:
vector<int> dims_;
INPUT_TAGS(INPUT);
OUTPUT_TAGS(OUTPUT);
};
// Register IDEEPSqueezeOp as the IDEEP-device implementation of the
// `Squeeze` operator.
REGISTER_IDEEP_OPERATOR(Squeeze, IDEEPSqueezeOp);
} // namespace caffe2