blob: cc997883bce2fe14da7dc13d97668e3b989b438c [file] [log] [blame]
#pragma once
#include "caffe2/core/operator.h"
namespace caffe2 {
struct CAFFE2_API QShapeInfo {
QShapeInfo(float o = 0, float s = 1, uint32_t a = 1) {
offset.clear();
scale.clear();
offset.push_back(o);
scale.push_back(s);
axis = a;
}
uint32_t axis;
vector<float> offset;
vector<float> scale;
};
struct CAFFE2_API ShapeInfo {
enum DimType : int8_t { UNKNOWN = 0, CONSTANT = 1, BATCH = 2, SEQ = 3 };
ShapeInfo(bool q = false) : is_quantized(q) {}
ShapeInfo(DimType t, TensorShape&& s, bool q = false)
: dim_type(t), shape(std::move(s)), is_quantized(q) {}
ShapeInfo(DimType t, const TensorShape& s, bool q = false)
: dim_type(t), shape(s), is_quantized(q) {}
ShapeInfo(bool q, const QShapeInfo& info) : is_quantized(q), q_info(info) {}
ShapeInfo(DimType t, TensorShape&& s, bool q, const QShapeInfo& info)
: dim_type(t), shape(std::move(s)), is_quantized(q), q_info(info) {}
ShapeInfo(DimType t, const TensorShape& s, bool q, const QShapeInfo& info)
: dim_type(t), shape(s), is_quantized(q), q_info(info) {}
// type of the shape according its first dim
DimType dim_type{DimType::UNKNOWN};
TensorShape shape;
// quantization related information
bool is_quantized;
QShapeInfo q_info;
};
using ShapeInfoMap = std::unordered_map<std::string, ShapeInfo>;
// Generates ShapeInfo from Blob.
ShapeInfo getShapeInfoFromBlob(const Blob* blob);
bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs);
} // namespace caffe2