|  | #pragma once | 
|  |  | 
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | namespace transform { | 
|  |  | 
|  | /** | 
|  | *  Graph representation of an operator. | 
|  | */ | 
|  | struct Node { | 
|  | public: | 
|  | // Empty constructor for resize | 
|  | Node() {} | 
|  |  | 
|  | // Alternate constructor | 
|  | Node( | 
|  | const OperatorDef& op, | 
|  | bool active, | 
|  | std::map<int, std::vector<string>> parents, | 
|  | std::map<int, std::vector<string>> children) | 
|  | : op(op), active(active), parents(parents), children(children) {} | 
|  |  | 
|  | // The OperatorDef which this node represents. | 
|  | OperatorDef op; | 
|  |  | 
|  | // Keeps track of if an operator has been deleted through a transformation. | 
|  | bool active = true; | 
|  |  | 
|  | // Stores a pair (idx, blob_list), | 
|  | //  idx = index of the child | 
|  | //  blob_list = a list of strings, containing the blobs that connect the nodes | 
|  | std::map<int, std::vector<string>> parents; | 
|  | std::map<int, std::vector<string>> children; | 
|  | }; | 
|  |  | 
|  | /** | 
|  | *  Graph representation of a Netdef. | 
|  | */ | 
|  | struct Graph { | 
|  | public: | 
|  | /** | 
|  | * Given a subgraph, gets all of the parents of the subgraph, as well as | 
|  | * their associated blob names. Sorted by blob names. | 
|  | * | 
|  | * <string, int> := (name of blob writing into subgraph, | 
|  | *                  index of node that writes into subgraph using that blob) | 
|  | */ | 
|  | const std::vector<std::pair<string, int>> GetSubgraphInput( | 
|  | const std::vector<int>& subgraph); | 
|  |  | 
|  | /** | 
|  | * Given a subgraph, gets all of the children of the subgraph, as well as | 
|  | * their associated blob names. Sorted by blob names. | 
|  | * | 
|  | * <string, int> := (name of blob reading from subgraph, | 
|  | *                  index of node that reads from subgraph using that blob) | 
|  | */ | 
|  | const std::vector<std::pair<string, int>> GetSubgraphOutput( | 
|  | const std::vector<int>& subgraph); | 
|  |  | 
|  | /** | 
|  | * Graph generation. | 
|  | * Given a netdef, returns a Graph. | 
|  | * | 
|  | * Each node represents an operator. | 
|  | * An edge exists between two nodes if the parent op writes to a blob, which | 
|  | * is the input of the child blob, with no other op writing to the blob in | 
|  | * between the execution order. | 
|  | * | 
|  | * Time Complexity: O(E), where E is the number of blobs | 
|  | */ | 
|  | explicit Graph(const NetDef& net_def); | 
|  |  | 
|  | /** | 
|  | * Generates a NetDef Representation for the current graph. | 
|  | * Nodes are visited in topological order, which is proper Opdef ordering. | 
|  | * TODO(benz): | 
|  | * There exists conflicts with repeated blob names, where topological sorting | 
|  | * is not sufficient for correct netdef representation, unless blobs are | 
|  | * renamed. | 
|  | * For example, if after a transformation, We have operator ancestry: | 
|  | * A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the | 
|  | * same blob name, then A, B, D, E, C is a correct topological ordering, | 
|  | * but D will write to the blob that C reads from, instead of B. | 
|  | * Currently believe that there will always be ambiguity unless blobs are | 
|  | * renamed. | 
|  | * This is solved by performing SSA on all transformed blob names. | 
|  | */ | 
|  | NetDef GetNetDef(); | 
|  |  | 
|  | /** | 
|  | * Deactivate a subgraph, and get rid of all edges into this subgraph. | 
|  | */ | 
|  | void DeactivateSubgraph(std::vector<int> subgraph); | 
|  |  | 
|  | size_t size() const { | 
|  | return nodes_.size(); | 
|  | } | 
|  |  | 
|  | void push_node(const Node& new_node) { | 
|  | return nodes_.push_back(new_node); | 
|  | } | 
|  |  | 
|  | void resize_nodes(size_t new_size) { | 
|  | nodes_.resize(new_size); | 
|  | } | 
|  |  | 
|  | // Index safe, less verbose way to access nodes | 
|  | inline const Node& node(size_t idx) const { | 
|  | return nodes_.at(idx); | 
|  | } | 
|  |  | 
|  | inline Node& node(size_t idx) { | 
|  | return nodes_.at(idx); | 
|  | } | 
|  |  | 
|  | inline bool is_node_active(size_t idx) { | 
|  | return node(idx).active; | 
|  | } | 
|  |  | 
|  | inline const std::set<string>& external_input() const { | 
|  | return external_input_; | 
|  | } | 
|  |  | 
|  | inline const std::set<string>& external_output() const { | 
|  | return external_output_; | 
|  | } | 
|  |  | 
|  | private: | 
|  | const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper( | 
|  | bool from_children, | 
|  | const std::vector<int>& match); | 
|  |  | 
|  | // Stores the netdef representation. Is updated upon calls to GetNetDef. | 
|  | NetDef netdef_; | 
|  |  | 
|  | // Stores which blobs the graph reads from, and writes to. | 
|  | std::set<string> external_input_; | 
|  | std::set<string> external_output_; | 
|  |  | 
|  | // Keeps track of all the Operators currently within graph, even if inactive. | 
|  | std::vector<Node> nodes_; | 
|  | }; | 
|  |  | 
|  | } // namespace transform | 
|  |  | 
|  | // Adds an operator def to a netdef. | 
|  | // Returns the ptr, if you want to add anything extra (such as device_option) | 
|  | OperatorDef* AddOp( | 
|  | NetDef* netdef_ptr, | 
|  | string op_type, | 
|  | std::vector<string> inputs, | 
|  | std::vector<string> outputs); | 
|  |  | 
|  | /** | 
|  | * This allows for the use of * and | to match operator types, | 
|  | * engines, or any other property that is represented by strings. | 
|  | * | 
|  | * For example, if we wanted to match an operator to Conv or FC, we can give: | 
|  | * "Conv|FC" as the type() of that op. | 
|  | */ | 
|  | bool MatchStrings(string p, string s); | 
|  |  | 
|  | /** | 
|  | * This ensures that each named arg that exists in the pattern exists in g_op, | 
|  | * is equal in value. | 
|  | */ | 
|  | bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op); | 
|  |  | 
|  | } // namespace caffe2 |