| #pragma once |
| |
#include "caffe2/core/common.h"
#include "caffe2/core/workspace.h"
#include "caffe2/opt/bound_shape_inferencer.h"
#include "caffe2/proto/caffe2_pb.h"

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
| |
| namespace caffe2 { |
// Argument names shared by the backend-transformer helpers in this header.
// NOTE(review): presumably kNetPos is the per-op argument written by
// BackendTransformerBase::annotateOpIndex() and kModelId is the argument read
// by BackendTransformerBase::getModelId() — confirm in the .cpp.
//
// Deliberately NOT wrapped in an unnamed namespace: unnamed namespaces do not
// belong in headers (every including TU gets its own distinct namespace).
// Namespace-scope constexpr variables are implicitly const and therefore
// already have internal linkage, so linkage and values are unchanged.
// NOTE(review): if the project is built as C++17, `inline constexpr` would
// give a single shared definition across TUs.
constexpr char kNetPos[] = "net_pos";
constexpr char kModelId[] = "model_id";
| |
| struct BackendTransformOptions { |
| explicit BackendTransformOptions() : bound_shape_spec(0, 0) {} |
| |
| // Enable debugging by dumping more intermediate graphs |
| bool debug{false}; |
| |
| // Minimum number of ops to create a backend op. If the subgraph is too |
| // small, it doesn't make sense to lower it to backend. |
| size_t min_ops{1}; |
| |
| // Bound shape spec |
| BoundShapeSpec bound_shape_spec; |
| }; |
| |
| // Wrap TensorShape into TensorProto |
| TensorProto wrapShapeInfoIntoTensorProto( |
| const std::string& name, |
| const ShapeInfo& shape_info); |
| |
| // Wrap Quantized TensorShape into QTensorProto |
| QTensorProto wrapShapeInfoIntoQTensorProto( |
| const std::string& name, |
| const ShapeInfo& shape_info); |
| |
| // This class contains some common functions for backend lowering and graph |
| // cutting |
| class BackendTransformerBase { |
| public: |
| BackendTransformerBase() {} |
| virtual ~BackendTransformerBase() {} |
| |
| const std::unordered_map<std::string, std::string>& input_mapping() const { |
| return input_mapping_; |
| } |
| |
| const std::unordered_map<std::string, std::string>& reverse_input_mapping() |
| const { |
| return reverse_input_mapping_; |
| } |
| |
| virtual void transform( |
| Workspace* ws, |
| NetDef* pred_net, |
| const std::vector<std::string>& weight_names, |
| const ShapeInfoMap& shape_hints, |
| const std::unordered_set<int>& blocklisted_ops) = 0; |
| |
| static void annotateOpIndex(NetDef* net); |
| |
| // Get model ID from the NetDef |
| static std::string getModelId(const NetDef& net); |
| |
| protected: |
| // add shape info to the net |
| void addShapeToNet(NetDef& shape_net, const ShapeInfoMap& shape_hints) const; |
| |
| // Dump the net with shape info |
| void dumpNet( |
| const NetDef& pred_net, |
| const ShapeInfoMap& map, |
| const std::string& fname) const; |
| |
| // SSA rewrite the net and return name mapping |
| ShapeInfoMap ssaRewriteAndMapNames( |
| Workspace* ws, |
| NetDef* pred_net, |
| const ShapeInfoMap& input_shape_hints); |
| |
| // Do bound shape inference and collect shape infos |
| ShapeInfoMap inferShapes( |
| Workspace* ws, |
| NetDef* pred_net, |
| const ShapeInfoMap& shape_hints_mapped, |
| const BoundShapeSpec& spec); |
| |
| // Input mapping of input name -> original input name |
| std::unordered_map<std::string, std::string> input_mapping_; |
| |
| // Input mapping of original input name -> input name |
| std::unordered_map<std::string, std::string> reverse_input_mapping_; |
| }; |
| } // namespace caffe2 |