| #pragma once |
| |
| #include "caffe2/core/logging.h" |
| #include "caffe2/opt/shape_info.h" |
| #include "caffe2/proto/caffe2_pb.h" |
| |
| #include <sstream> |
| #include <string> |
| #include <unordered_map> |
| #include <unordered_set> |
| |
| namespace caffe2 { |
| // This struct stores the max bound size for batch in the general sense. |
| // max_batch_size is the upper bound of batch_size. |
| // max_seq_size is the upper bound of length of every item in a batch. |
| // Upper bound of length of a batch of items should be max_batch_size * |
| // max_seq_size. |
| struct TORCH_API BoundShapeSpec { |
| explicit BoundShapeSpec(int64_t b, int64_t q) |
| : max_batch_size(b), |
| max_seq_size(q), |
| num_embeddings(0), |
| embedding_length(0) {} |
| explicit BoundShapeSpec(int64_t b, int64_t q, int64_t n, int64_t e) |
| : max_batch_size(b), |
| max_seq_size(q), |
| num_embeddings(n), |
| embedding_length(e) {} |
| int64_t max_batch_size; |
| int64_t max_seq_size; |
| // The following two parameters are for shape inference of UnPackRecords |
| int64_t num_embeddings; |
| int64_t embedding_length; |
| }; |
| |
| /// \class A class that does bound shape inference given a C2 net. Depending on |
| /// its type, each op have a maximum shape that it accepts. We define some |
| /// initial bound for certain dimension, for example max batch size or max |
| /// sequnce lookup size. And the inference will first infer the input size and |
| /// then propagates the bound shape down the network. For now the variable part |
| /// (bound part) is the first dimension of the shape, which usually corresponds |
| /// to the batch size or sequence lookup size. |
| class BoundShapeInferencerBase { |
| public: |
| explicit BoundShapeInferencerBase(const BoundShapeSpec& spec) : spec_(spec) { |
| CAFFE_ENFORCE_GE(spec_.max_batch_size, 0); |
| CAFFE_ENFORCE_GE(spec_.max_seq_size, 0); |
| } |
| |
| virtual ~BoundShapeInferencerBase() {} |
| |
| // Initializes BoundShapeInferencer and infers bound shape and type. |
| // info: shape information of some tensors, |
| // e.g. shape information of external input / output tensors; |
| // extract_feature_len: |
| // indicating whether to extract feature length from SigridTransform |
| // and other related operators. When enabled, |
| // extracted feature length information will be used to infer tensor shapes. |
| virtual void InferBoundShapeAndType( |
| const NetDef& net, |
| const ShapeInfoMap& info, |
| caffe2::Workspace* ws, |
| bool extract_feature_len = false) = 0; |
| |
| const ShapeInfoMap& shape_info() const { |
| return shape_info_; |
| } |
| |
| /// Print out all the shape info |
| std::string PrintShapeInfo() const { |
| std::stringstream ss; |
| for (const auto& kv : shape_info_) { |
| const auto& s = kv.second; |
| ss << s.shape.name() << ": dim_type: " << s.getDimType() << ", dims: ["; |
| for (const auto d : s.shape.dims()) { |
| ss << d << ", "; |
| } |
| ss << "], dtype: " << s.shape.data_type() << "\n"; |
| } |
| return ss.str(); |
| } |
| |
| protected: |
| const BoundShapeSpec spec_; |
| ShapeInfoMap shape_info_; |
| bool extract_feature_len_; |
| }; |
| |
| class TORCH_API BoundShapeInferencer : public BoundShapeInferencerBase { |
| public: |
| explicit BoundShapeInferencer(const BoundShapeSpec& spec) |
| : BoundShapeInferencerBase(spec) {} |
| |
| ~BoundShapeInferencer() override {} |
| void InferBoundShapeAndType( |
| const NetDef& net, |
| const ShapeInfoMap& info, |
| caffe2::Workspace* ws, |
| bool extract_feature_len = false) override; |
| |
| protected: |
| TensorShape& CheckAndSetTensorBoundShape( |
| const std::string& name, |
| const std::vector<TensorBoundShape::DimType>& t, |
| std::vector<int64_t> bound_dims, |
| TensorProto::DataType type, |
| bool is_quantized, |
| bool allow_existing_shape = false, |
| float scale = 1, |
| int offset = 0, |
| bool in_place_op = false); |
| |
| TensorShape& SetTensorBoundShapeIfNotExist( |
| const std::string& name, |
| const std::vector<TensorBoundShape::DimType>& t, |
| std::vector<int64_t> bound_dims, |
| TensorProto::DataType type, |
| bool is_quantized); |
| |
| virtual void InferOps(const OperatorDef& op, caffe2::Workspace* ws); |
| |
| void InferConcatInputs(const OperatorDef& op); |
| void InferInt8QuantizeInput(const OperatorDef& op); |
| void InferElementwiseOpInput(const OperatorDef& op); |
| |
| void InferElementwiseOp(const OperatorDef& op); |
| void InferGivenTensorFill(const OperatorDef& op); |
| void InferSparseLengthsSum(const OperatorDef& op); |
| void InferFC(const OperatorDef& op); |
| void InferConcat(const OperatorDef& op); |
| void InferShape(const OperatorDef& op); |
| void InferReshape(const OperatorDef& op); |
| void InferLengthsRangeFill(const OperatorDef& op); |
| void InferQuantizationTransformation(const OperatorDef& op); |
| void InferUnPackRecords(const OperatorDef& op); |
| void InferTile(const OperatorDef& op); |
| void InferSparseLengthsSumSparseLookup(const OperatorDef& op); |
| void InferSoftmax(const OperatorDef& op); |
| void InferBucketize(const OperatorDef& op); |
| void InferLpNorm(const OperatorDef& op); |
| void InferClip(const OperatorDef& op); |
| void InferMean(const OperatorDef& op); |
| void InferDiv(const OperatorDef& op); |
| void InferTranspose(const OperatorDef& op); |
| |
| // Standard shape/type inference using op schema registered shape inference |
| // function |
| void InferCommonOp(const OperatorDef& op, const OpSchema* schema = nullptr, bool bypass_input_check = false, bool in_place_op = false); |
| |
| // Initialize private parameters, such as shape_info, extract_feature_len_ |
| // This is called at the beginning of InferBoundShapeAndType() |
| virtual void Initialize(const ShapeInfoMap& info, bool extract_feature_len); |
| |
| void EnsureShapeNames(ShapeInfoMap* info) const; |
| |
| TensorBoundShape::DimType current_dim_type_{TensorBoundShape_DimType_BATCH}; |
| int64_t current_max_batch_size_{0}; |
| }; |
| |
| TORCH_API std::shared_ptr<BoundShapeInferencerBase> getBoundShapeInferencer( |
| const BoundShapeSpec& spec); |
| |
| C10_DECLARE_SHARED_REGISTRY( |
| BoundShapeInferencerRegistry, |
| BoundShapeInferencerBase, |
| const BoundShapeSpec&); |
| |
| } // namespace caffe2 |