// blob: c2b1379cc47be7695d3de7fe705c84be8006db07 [file] [log] [blame]
#include "caffe2/opt/shape_info.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// Builds a ShapeInfo describing `blob`.
//
// The shape comes from GetTensorShapeOfBlob(); dim_type is UNKNOWN when that
// shape reports unknown_shape() and CONSTANT otherwise. Quantization fields
// are filled in for Int8TensorCPU blobs and, when C10_MOBILE is not defined,
// for any blob type that has an entry in ExternalTensorFunctionsBaseRegistry.
// For every other blob the quantization fields are left as default-constructed.
//
// `blob` is dereferenced unconditionally, so it must not be null.
ShapeInfo getShapeInfoFromBlob(const Blob* blob) {
  ShapeInfo info;
  info.shape = GetTensorShapeOfBlob(blob);
  if (info.shape.unknown_shape()) {
    info.dim_type = ShapeInfo::DimType::UNKNOWN;
  } else {
    info.dim_type = ShapeInfo::DimType::CONSTANT;
  }

  const auto type_id = blob->meta().id();

  // Int8 CPU tensors are handled directly, without consulting the registry.
  if (type_id == TypeMeta::Id<int8::Int8TensorCPU>()) {
    info.is_quantized = true;
    LoadInt8TensorInfoOfBlob(
        &info.q_info.scale, &info.q_info.offset, &info.q_info.axis, blob);
    return info;
  }

#ifndef C10_MOBILE
  // Other tensor types may have registered their own handlers; a null result
  // means nothing is registered for this type id.
  auto external_fns = ExternalTensorFunctionsBaseRegistry()->Create(type_id);
  if (external_fns) {
    info.is_quantized = external_fns->isQuantized();
    external_fns->LoadInfoOfBlob(
        blob, &info.q_info.scale, &info.q_info.offset, &info.q_info.axis);
  }
#endif

  return info;
}
// Equality for ShapeInfo: two values compare equal when their dim_type
// matches and their shapes serialize to the same string. The quantization
// fields (is_quantized, q_info) are not part of the comparison.
bool operator==(const ShapeInfo& lhs, const ShapeInfo& rhs) {
  const bool same_dim_type = lhs.dim_type == rhs.dim_type;
  if (!same_dim_type) {
    // Skip serializing the shapes when the dim types already differ.
    return false;
  }
  return lhs.shape.SerializeAsString() == rhs.shape.SerializeAsString();
}
} // namespace caffe2