| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // TODO(intel): Improve error handling in this file; instead of CHECK failing |
| // all over the place, we should log an error and execute the original graph. |
| #ifdef INTEL_MKL |
| |
| #include <algorithm> |
| #include <functional> |
| #include <memory> |
| #include <queue> |
| #include <set> |
| #include <unordered_set> |
| #include <utility> |
| #include <vector> |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/optimization_registry.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/graph/node_builder.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/gtl/array_slice.h" |
| #include "tensorflow/core/lib/gtl/map_util.h" |
| #include "tensorflow/core/lib/hash/hash.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/util/tensor_format.h" |
| |
| #include "tensorflow/core/graph/mkl_graph_util.h" |
| #include "tensorflow/core/graph/mkl_layout_pass.h" |
| |
| namespace tensorflow { |
| |
| #ifdef INTEL_MKL_ML_ONLY |
| |
// This pass implements rewriting of the graph to support the following
// scenarios:
| // (A) Merging nodes in the graph |
| // (B) Rewriting a node in the graph to a new node |
// Rewriting happens under the following 2 scenarios:
| // 1) Propagating Mkl layout as an additional output tensor |
| // (we will loosely call a tensor that carries Mkl layout as Mkl tensor |
| // henceforth.) from every Mkl supported NN layer. |
| // 2) Context-based rewrite: This is needed in order to optimize |
// gradient ops of Conv2D+BiasAdd. The gradient op of both Conv2D and
// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into a
// Conv2D-specific BiasAddGrad and a MatMul-specific BiasAddGrad.
| // This is context-specific optimization, where the context is the |
| // forward operator that the BiasAddGrad corresponds to. |
| // |
| // Example of A : Merging nodes in the graph |
| // ----------------------------------------- |
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
| // |
| // O = Conv2D(A, B) |
| // P = BiasAdd(O, C) |
| // |
| // We merge them into Conv2DWithBias as: |
| // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m) |
| // |
| // The meaning of A_m, B_m and C_m is explained in B.1. |
| // |
| // Merge rules: |
| // - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_ |
| // goes to BiasAdd. |
// - Also, the intersection of the attributes of both nodes must have the
//   same values.
// - Both nodes must have been assigned to the same device (if any).
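//
// For illustration only, the merge does not happen if the output of Conv2D
// feeds any node besides BiasAdd, e.g.:
//
//   O = Conv2D(A, B)
//   P = BiasAdd(O, C)
//   Q = Relu(O)   // O has a second consumer, so Conv2D and BiasAdd are
//                 // left unmerged.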
| // |
| // Example of B.1 : Rewriting nodes to Mkl nodes |
| // --------------------------------------------- |
| // Consider a Relu node. Current definition of Relu node looks like: |
| // |
| // O = Relu(A) |
| // |
| // Relu has 1 input (A), and 1 output (O). |
| // |
| // This rewrite pass will generate a new graph node for Relu (new node is |
| // called MklRelu) as: |
| // |
| // O, O_m = MklRelu(A, A_m) |
| // |
// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
// the same as input A of Relu; output O is the same as output O of Relu. O_m
// is the additional output tensor that will be set by MklRelu, and it
// represents the Mkl tensor corresponding to O -- in other words, O_m is some
// kind of metadata for O. A_m is an additional input of MklRelu, and it
// represents metadata for A -- just as O_m is metadata for O, A_m is metadata
// for A. MklRelu receives this metadata from the previous node in the graph.
| // |
| // When a previous node in the graph is an Mkl node, A_m will represent a valid |
| // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent |
| // a dummy Mkl tensor. |
| // |
| // Rewriting rules: |
| // - Selection of a node for rewriting happens by registering the op type of |
| // the node with the rewriting pass. If the op type is not registered, then |
//   no node of this op type will be rewritten.
| // - Number of inputs after rewriting: |
//   Since for every input Tensorflow tensor, the rewritten node gets Mkl
//   tensor(s), the rewritten node gets 2*N inputs, where N is the number of
//   inputs of the original node.
| // - Number of outputs after rewriting: |
| // Since for every output Tensorflow tensor, the rewritten node generates |
| // Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the |
| // number of outputs of the original node. |
| // - Ordering of Tensorflow tensors and Mkl tensors: |
| // Since every rewritten node generates twice the number of inputs and |
| // outputs, one could imagine various orderings among Tensorflow tensors |
| // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as |
| // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m |
| // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m |
| // order. Among N inputs one can get N! permutations. |
| // |
| // So the question is: which order do we follow? We support 2 types of |
| // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering |
| // follows an intuitive order where an Mkl tensor follows the |
| // corresponding Tensorflow tensor immediately. In the context of the |
| // above example, it will be: A, A_m, B, B_m. Note that the ordering rule |
| // applies to both the inputs and outputs. Contiguous ordering means |
| // all the Tensorflow tensors are contiguous followed by all the Mkl |
| // tensors. We use contiguous ordering as default. |
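//
//      As an illustration of the two orderings (using the Conv2D example
//      above), the rewritten op would take its inputs as:
//
//        interleaved: O, O_m = _MklConv2D(A, A_m, B, B_m)
//        contiguous : O, O_m = _MklConv2D(A, B, A_m, B_m)
//
//      and its outputs are ordered in the same way.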
| // |
| // Graph rewrite algorithm: |
| // Algorithm: Graph Rewrite |
| // Input: Graph G, Names of the nodes to rewrite and their new names |
| // Output: Modified Graph G' if the nodes are modified, G otherwise. |
| // Start: |
| // N = Topological_Sort(G) // N is a set of nodes in toposort order. |
| // foreach node n in N |
| // do |
| // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input. |
| // then |
| // E = set of <incoming edge and its src_output slot> of n |
| // E' = {} // a new set of edges for rewritten node |
| // foreach <e,s> in E |
| // do |
//         E' = E' U {<e,s>}  // First copy the edge which generates the
//                            // Tensorflow tensor as it is.
| // m = Source node of edge e |
| // if Is_Rewritten(m) // Did we rewrite this node in this pass? |
| // then |
//           E' = E' U {<m,s+1>}  // If yes, then m will generate an Mkl
//                                // tensor as an additional output.
| // else |
| // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy |
| // // Mkl tensor. |
//           E' = E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
| // fi |
| // done |
| // n' = Build_New_Node(G,new_name,E') |
| // Mark_Rewritten(n') // Mark the new node as being rewritten. |
| // fi |
| // done |
| // |
| // Explanation: |
//      For graph rewrite, we visit nodes of the input graph in
//      topological sort order. With this ordering, we visit nodes in
//      top-to-bottom fashion. We need this order because while visiting a
//      node we want all of its input nodes to have been visited (and
//      rewritten if applicable) already. This is because if we need to
//      rewrite a given node then all of its input nodes need to be fixed
//      (in other words, they cannot be deleted later).
| // |
| // While visiting a node, we first check if the op type of the node is |
| // an Mkl op. If it is, then we rewrite that node after constructing |
//      new inputs to the node. If the op type of the node is not an Mkl op,
| // then we do not rewrite that node. |
| // |
| // Handling workspace propagation for certain ops: |
| // |
| // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require |
| // passing of a workspace from their respective forward ops. Workspace |
| // tensors provide memory for storing results of intermediate operations |
| // which are helpful in backward propagation. TensorFlow does not have |
| // a notion of a workspace and as a result does not allow producing |
| // additional outputs from these forward ops. For these ops, we need |
| // to add 2 extra edges between forward ops and their corresponding |
| // backward ops - the first extra edge carries a workspace tensor and |
| // the second one carries an Mkl tensor for the workspace tensor. |
| // |
| // Example: |
| // |
| // Typical graph for MaxPool and its gradient looks like: |
| // |
| // A = MaxPool(T) |
| // B = MaxPoolGrad(X, A, Y) |
| // |
| // We will transform this graph to propagate the workspace as: |
| // (with the contiguous ordering) |
| // |
| // A, W, A_m, W_m = MklMaxPool(T, T_m) |
| // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m) |
| // |
| // Here W is the workspace tensor. Transformed tensor names with the |
| // suffix _m are Mkl tensors, and this transformation has been done |
| // using the algorithm discussed earlier. The transformation for |
| // workspace propagation only adds extra outputs (W, W_m) for a forward |
| // op and connects them to the corresponding backward ops. |
| // |
| // Terms: |
| // |
| // Forward op name = name of the op in the forward pass |
| // where a workspace tensor originates (MaxPool in this example) |
| // Backward op name = name of the op in the backward pass that receives |
| // a workspace tensor from the forward op (MaxPoolGrad in the example) |
| // Slot = Position of the output or input slot that will be |
| // used by the workspace tensor (1 for MklMaxPool as W is the 2nd |
| // output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad) |
| // |
| // Question: |
| // |
//  How do we associate a backward op with a forward op? There can be more
//  than one op of the exact same type.
| // |
| // In this example, we associate MaxPoolGrad with MaxPool. But there |
//  could be more than one MaxPool op. To solve this problem, we look
//  for a _direct_ edge between a forward op and a backward op (tensor A
//  flows along this edge in the example).
| // |
| // How do we transform forward and backward ops when there is no direct |
//  edge between them? In such a case, we generate dummy tensors for the
//  workspace tensors. For the example, the transformation of MaxPool will
//  be exactly the same as it would be when there is a direct edge between
//  the forward and the backward op --- it is just that MaxPool won't
//  generate any workspace tensor. For MaxPoolGrad, the transformation
//  will also be the same, but instead of connecting W and W_m with the
//  outputs of MaxPool, we will produce dummy tensors for them, and we
//  will set the workspace_enabled attribute to false.
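//
//  For illustration only (DW and DW_m below stand for the dummy workspace
//  tensor and its dummy Mkl metadata tensor), the transformed graph in the
//  no-direct-edge case would look like:
//
//  A, W, A_m, W_m = MklMaxPool(T, T_m)
//  B, B_m = MklMaxPoolGrad(X, A, Y, DW, X_m, A_m, Y_m, DW_m)
//
//  with the workspace_enabled attribute of MklMaxPoolGrad set to false.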
| // |
| // Example of B.2 : Context-based node rewrite |
| // ------------------------------------------- |
| // Consider BiasAddGrad op as: |
| // |
| // O = _MklConv2D(A, B, C, A_m, B_m, C_m) |
| // P = BiasAddGrad(O) |
| // |
| // Then we rewrite it as: |
| // |
| // P = Conv2DWithBiasBackpropBias(O, O_m) |
| // |
| // Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending |
| // on the matching 'context'. The term context is loosely related to which |
// forward op is _associated_ with BiasAddGrad. If it is _MklConv2DWithBias then
| // we consider it Conv2D context; if it is MatMul, then it is MatMul context. |
| |
| class MklLayoutRewritePass : public GraphOptimizationPass { |
| public: |
| MklLayoutRewritePass() { |
| // NOTE: names are alphabetically sorted. |
| csinfo_.addn = "AddN"; |
| csinfo_.avg_pool = "AvgPool"; |
| csinfo_.avg_pool_grad = "AvgPoolGrad"; |
| csinfo_.bias_add = "BiasAdd"; |
| csinfo_.bias_add_grad = "BiasAddGrad"; |
| csinfo_.concat = "Concat"; |
| csinfo_.concatv2 = "ConcatV2"; |
| csinfo_.conv2d = "Conv2D"; |
| csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; |
| csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; |
| csinfo_.fused_batch_norm = "FusedBatchNorm"; |
| csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; |
| csinfo_.identity = "Identity"; |
| csinfo_.lrn = "LRN"; |
| csinfo_.lrn_grad = "LRNGrad"; |
| csinfo_.matmul = "MatMul"; |
| csinfo_.max_pool = "MaxPool"; |
| csinfo_.max_pool_grad = "MaxPoolGrad"; |
| csinfo_.mkl_conv2d = "_MklConv2D"; |
| csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput"; |
| csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; |
| csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; |
| csinfo_.mkl_conv2d_with_bias_backprop_bias = |
| "_MklConv2DWithBiasBackpropBias"; |
| csinfo_.relu = "Relu"; |
| csinfo_.relu_grad = "ReluGrad"; |
| csinfo_.reshape = "Reshape"; |
| csinfo_.split = "Split"; |
// Element-wise ops. Ensure you also add any new ops to IsOpElementWise
// in mkl_util.h (IsMklElementWiseOp method) to ensure that the
// MklInputConversion op is added before it.
| csinfo_.add = "Add"; |
| csinfo_.maximum = "Maximum"; |
| csinfo_.mul = "Mul"; |
| csinfo_.squared_difference = "SquaredDifference"; |
| csinfo_.sub = "Sub"; |
| // End - element-wise ops. See note above. |
| |
| // NOTE: names are alphabetically sorted. |
| rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), |
| CopyAttrsAddN, AddNRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), |
| CopyAttrsDataType, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.avg_pool, |
| mkl_op_registry::GetMklOpName(csinfo_.avg_pool), |
| CopyAttrsPooling, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.avg_pool_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), |
| CopyAttrsPooling, AlwaysRewrite, nullptr}); |
// BiasAddGrad gets rewritten into Conv2DWithBiasBackpropBias depending
// on whether the context contains Conv2D.
| rinfo_.push_back({csinfo_.bias_add_grad, |
| csinfo_.mkl_conv2d_with_bias_backprop_bias, |
| CopyAttrsBiasAddGrad, ContextMatchRewrite, |
| &biasaddgrad_conv2dwithbias_context_}); |
// BiasAddGrad gets rewritten into BiasAddGrad depending on whether the
// context contains MatMul.
| rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul, |
| CopyAttrsBiasAddGrad, ContextMatchRewrite, |
| &biasaddgrad_matmul_context_}); |
| rinfo_.push_back({csinfo_.concat, |
| mkl_op_registry::GetMklOpName(csinfo_.concat), |
| CopyAttrsConcat, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.concatv2, |
| mkl_op_registry::GetMklOpName(csinfo_.concatv2), |
| CopyAttrsConcatV2, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.conv2d, |
| mkl_op_registry::GetMklOpName(csinfo_.conv2d), |
| CopyAttrsConv2D, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.conv2d_grad_filter, |
| mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), |
| CopyAttrsConv2D, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.conv2d_grad_input, |
| mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), |
| CopyAttrsConv2D, AlwaysRewrite, nullptr}); |
| |
| rinfo_.push_back({csinfo_.fused_batch_norm, |
| mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), |
| CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); |
| rinfo_.push_back( |
| {csinfo_.fused_batch_norm_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), |
| CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.identity, |
| mkl_op_registry::GetMklOpName(csinfo_.identity), |
| CopyAttrsIdentity, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), |
| CopyAttrsLRN, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.lrn_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), |
| CopyAttrsLRN, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.max_pool, |
| mkl_op_registry::GetMklOpName(csinfo_.max_pool), |
| CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.max_pool_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), |
| CopyAttrsPooling, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.maximum, |
| mkl_op_registry::GetMklOpName(csinfo_.maximum), |
| CopyAttrsDataType, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), |
| CopyAttrsDataType, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), |
| CopyAttrsDataType, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.relu_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.relu_grad), |
| CopyAttrsDataType, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.reshape, |
| mkl_op_registry::GetMklOpName(csinfo_.reshape), |
| CopyAttrsReshape, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.squared_difference, |
| mkl_op_registry::GetMklOpName(csinfo_.squared_difference), |
| CopyAttrsDataType, AlwaysRewrite, nullptr}); |
| rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), |
| CopyAttrsDataType, AlwaysRewrite, nullptr}); |
| |
| // Add info about which ops to add workspace edge to and the slots. |
| wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); |
| wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); |
| |
| // Add a rule for merging nodes |
| minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0, |
| csinfo_.mkl_conv2d_with_bias}); |
| |
| biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, |
| IsBiasAddGradInMatMulContext}; |
| |
| biasaddgrad_conv2dwithbias_context_ = { |
| csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, |
| IsBiasAddGradInConv2DWithBiasContext}; |
| |
| cinfo_.push_back(&biasaddgrad_matmul_context_); |
| cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); |
| } |
| |
| // Standard interface to run pass |
| Status Run(const GraphOptimizationPassOptions& options); |
| |
// Helper function which does most of the heavy lifting for rewriting
// nodes into Mkl nodes so that they propagate an Mkl tensor as an
// additional output.
| // |
| // Extracts common functionality between Run public interface and |
| // test interface. |
| // |
| // @return true, if and only if graph is mutated; false otherwise. |
| bool RunPass(std::unique_ptr<Graph>* g); |
| |
| /// Structure to specify the context information used in a node rewrite rule |
| typedef struct { |
| string node; // Name of the node to be rewritten |
| string fwd; // Name of the node in the forward pass that this node |
| // corresponds to |
| std::function<bool(const Node*, const Node**, void* c)> context_match_fn; |
| } ContextInfo; |
| |
| /// Structure to specify the name of an original node, its new name after |
| /// rewrite, the number of inputs to the original node, the function to |
| /// be used to copy attributes for the op, and the rule (if any) which |
| /// must hold for rewriting the node |
| typedef struct { |
| string name; // Original name of op of the node in the graph |
| string new_name; // New name of the op of the node in the graph |
| // A function handler to copy attributes from an old node to a new node. |
| std::function<void(const Node*, NodeBuilder*)> copy_attrs; |
| // A rule under which to rewrite this node |
| std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule; |
| // ContextInfo, if any, to be used for rewrite |
| ContextInfo* context; |
| } RewriteInfo; |
| |
| /// Structure to specify a forward op, a backward op, and the slot numbers |
| /// in the forward and backward ops where we will add a workspace edge. |
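/// For example, the MaxPool entry added to wsinfo_ in the constructor is
/// {fwd_op: "MaxPool", bwd_op: "MaxPoolGrad", fwd_slot: 0, bwd_slot: 1,
/// ws_fwd_slot: 1, ws_bwd_slot: 3}.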
| typedef struct { |
| string fwd_op; // Name of a forward op in the graph |
| string bwd_op; // Name of a backward op in the graph |
| int fwd_slot; // Output slot in the forward op node where actual |
| // output tensor resides |
| int bwd_slot; // Input slot in the backward op node where actual |
| // input tensor resides |
| int ws_fwd_slot; // Output slot in the forward op node where workspace |
| // edge is added |
| int ws_bwd_slot; // Input slot in the backward op node where workspace |
| // edge is added |
| } WorkSpaceInfo; |
| |
| /// Structure to specify information used in node merge |
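/// For example, the only entry added to minfo_ in the constructor merges a
/// _MklConv2D node (pred) with a BiasAdd node (succ) that consumes the
/// Conv2D output at its input slot 0 (op), producing an _MklConv2DWithBias
/// node (new_node).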
| typedef struct { |
| string pred; // Predecessor node string |
| string succ; // Successor node string |
int op; // The input slot of the successor node that the output of
// the predecessor node feeds into
| string new_node; // Name of the node after merge |
| } MergeInfo; |
| |
| /// Structure to store all constant strings |
| /// NOTE: names are alphabetically sorted. |
| typedef struct { |
| string addn; |
| string add; |
| string avg_pool; |
| string avg_pool_grad; |
| string bias_add; |
| string bias_add_grad; |
| string concat; |
| string concatv2; |
| string conv2d; |
| string conv2d_grad_input; |
| string conv2d_grad_filter; |
| string fused_batch_norm; |
| string fused_batch_norm_grad; |
| string identity; |
| string lrn; |
| string lrn_grad; |
| string matmul; |
| string max_pool; |
| string max_pool_grad; |
| string maximum; |
| string mkl_conv2d; |
| string mkl_conv2d_grad_input; |
| string mkl_conv2d_grad_filter; |
| string mkl_conv2d_with_bias; |
| string mkl_conv2d_with_bias_backprop_bias; |
| string mul; |
| string relu; |
| string relu_grad; |
| string reshape; |
| string split; |
| string squared_difference; |
| string sub; |
| } ConstStringsInfo; |
| |
| private: |
| /// Maintain info about nodes to rewrite |
| std::vector<RewriteInfo> rinfo_; |
| |
| /// Maintain info about nodes to add workspace edge |
| std::vector<WorkSpaceInfo> wsinfo_; |
| |
| /// Maintain info about nodes to be merged |
| std::vector<MergeInfo> minfo_; |
| |
/// Maintain info about rewrite contexts
| static std::vector<ContextInfo*> cinfo_; |
| |
| /// Maintain structure of constant strings |
| static ConstStringsInfo csinfo_; |
| |
| /// Context variables used in referencing rules |
| static ContextInfo biasaddgrad_matmul_context_; |
| static ContextInfo biasaddgrad_conv2dwithbias_context_; |
| |
| private: |
| // Is OpDef::ArgDef a list type? It could be N * T or list(type). |
// Refer to op_def.proto for details of list types.
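// For example, an arg declared in REGISTER_OP as "N * T" carries a
// number_attr ("N"), and one declared as "list(type)" carries a
// type_list_attr, so both are list types; a plain "T" arg is not.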
| inline bool ArgIsList(const OpDef::ArgDef& arg) const { |
| return !arg.type_list_attr().empty() || !arg.number_attr().empty(); |
| } |
| |
| // Get length of a list in 'n' if 'arg' is of list type. Refer to |
| // description of ArgIsList for definition of list type. |
| inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { |
| CHECK_EQ(ArgIsList(arg), true); |
| int N = 0; |
| const string attr_name = !arg.type_list_attr().empty() |
| ? arg.type_list_attr() |
| : arg.number_attr(); |
| if (!arg.type_list_attr().empty()) { |
| std::vector<DataType> value; |
| TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); |
| N = value.size(); |
| } else { |
| TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); |
| } |
| return N; |
| } |
| |
// Can the op represented by node 'n' run on DEVICE_CPU?
// The op can run on CPU with MKL if the runtime-assigned device or the
// user-requested device contains "CPU", or if both are empty.
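// For illustration, a node placed on a device such as "/device:GPU:0"
// returns false here, while an empty device string or one containing "CPU"
// returns true.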
| bool CanOpRunOnCPUDevice(const Node* n) { |
| bool result = true; |
| string reason; |
| |
| // Substring that should be checked for in device name for CPU device. |
| const char* const kCPUDeviceSubStr = "CPU"; |
| |
| // If Op has been specifically assigned to a non-CPU device, then No. |
| if (!n->assigned_device_name().empty() && |
| !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) { |
| result = false; |
| reason = "Op has been assigned a runtime device that is not CPU."; |
| } |
| |
| // If user has specifically assigned this op to a non-CPU device, then No. |
| if (!n->def().device().empty() && |
| !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) { |
| result = false; |
| reason = "User has assigned a device that is not CPU."; |
| } |
| |
| if (result == false) { |
| VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " |
| << n->type_string() << ", reason: " << reason; |
| } |
| |
| // Otherwise Yes. |
| return result; |
| } |
| |
| // Return a node that can be merged with input node 'n' |
| // |
| // @return pointer to the node if we can find such a |
| // node. Otherwise, it returns nullptr. |
| Node* CheckForNodeMerge(const Node* n) const; |
| |
| // Merge predecessor node with its successor. |
| // Currently, we merge Conv2D with BiasAdd only. |
| // |
| // Input nodes succ and pred may be deleted if the call to |
// this function is successful. Attempting to use the pointers
// after the call may result in undefined behavior.
| // |
| // @input g - input graph, succ - successor node, pred - predecessor node |
| // @return Status::OK(), if merging is successful and supported. |
| // Returns appropriate Status error code otherwise. |
| // Graph is updated in case nodes are merged. Otherwise, it is |
| // not updated. |
| Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred); |
| |
| // Check if the node 'n' has any applicable rewrite rule |
| // We check for 2 scenarios for rewrite. |
| // |
| // @return RewriteInfo* for the applicable rewrite rule |
| const RewriteInfo* CheckForNodeRewrite(const Node* n) const; |
| |
| // Default rewrite rule to be used in scenario 1 for rewrite. |
| // @return - true (since we want to always rewrite) |
| static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) { |
| return true; |
| } |
| |
// Check if we are performing pooling over the depth or batch dimension.
// If so, then we do not rewrite the MaxPool node to its Mkl version.
// @return - true (if it is not a depth/batch-wise pooling case);
//           false otherwise.
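// For illustration, with data_format "NHWC", ksize = [1, 2, 2, 1] and
// strides = [1, 2, 2, 1] describe spatial pooling and the node is rewritten,
// whereas ksize = [1, 1, 1, 2] pools over the depth (C) dimension and the
// node is left unchanged.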
| static bool NonDepthBatchWisePoolRewrite(const Node* n, |
| const ContextInfo* c) { |
| CHECK_NOTNULL(n); |
| |
| string data_format_str; |
| TensorFormat data_format; |
| std::vector<int32> ksize, strides; |
| CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); |
| CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); |
| CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); |
| CHECK_EQ(FormatFromString(data_format_str, &data_format), true); |
| |
| // Condition that specifies non-batch-wise and non-depth-wise pooling. |
| if (GetTensorDim(ksize, data_format, 'N') == 1 && |
| GetTensorDim(strides, data_format, 'N') == 1 && |
| GetTensorDim(ksize, data_format, 'C') == 1 && |
| GetTensorDim(strides, data_format, 'C') == 1) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
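// Rewrite rule for AddN: we currently rewrite an AddN node to its Mkl
// version only when it has exactly 2 inputs (attribute 'N' == 2).
// @return - true if 'N' == 2; false otherwise.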
| static bool AddNRewrite(const Node* n, const ContextInfo* c) { |
| CHECK_NOTNULL(n); |
| |
| int num; |
| CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true); |
| |
// Rewrite only if AddN has exactly 2 inputs.
| if (num == 2) { |
| return true; |
| } |
| |
| return false; |
| } |

// Is the BiasAddGrad node 'n' associated with the Conv2DWithBias node
// specified in contextinfo 'ci'? The function updates fwd_node to point
// to the Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
| // |
| // Association checks for one of the following graphs: |
| // |
| // Graph A: |
| // |
| // _ = Conv2DWithBias(F, I, _) |
| // .. |
| // _ = Conv2DBackpropFilter(F, _, G) |
| // _ = Conv2DBackpropInput(_, I, G) |
| // _ = BiasAddGrad(G) |
| // |
| // OR |
| // |
| // Graph B: |
| // |
| // _ = Conv2DWithBias(F, _, _) |
| // .. |
| // _ = Conv2DBackpropFilter(F, _, G) |
| // _ = BiasAddGrad(G) |
| // |
| // Here F, G, and I are graph nodes; _ represents graph nodes that we |
// don't care about here.
| // |
| // @return - true (if BiasAddGrad is associated with Conv2DWithBias); |
| // false otherwise. |
| static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n, |
| const Node** fwd_node, |
| void* ci) { |
| CHECK_NOTNULL(n); |
| CHECK_NOTNULL(fwd_node); |
| CHECK_NOTNULL(ci); |
| *fwd_node = nullptr; |
| |
| CHECK_EQ(n->type_string(), csinfo_.bias_add_grad); |
| |
// Get the single input (index 0) of BiasAddGrad.
| CHECK_EQ(n->num_inputs(), 1); |
| const Node* bias_add_grad_inp = nullptr; |
| TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp)); |
| CHECK_NOTNULL(bias_add_grad_inp); |
| |
| // Check if this input also goes to BackpropFilter and BackpropInput |
| // as 3rd input. |
| bool found_backprop_input = false; |
| bool found_backprop_filter = false; |
| Node* backprop_filter_node = nullptr; |
| Node* backprop_input_node = nullptr; |
| |
| for (const Edge* e : bias_add_grad_inp->out_edges()) { |
| Node* third_input = nullptr; |
| if (e->dst()->type_string() == csinfo_.conv2d_grad_input || |
| e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) { |
| // Third input (index 2) of BackpropInput |
| TF_CHECK_OK(e->dst()->input_node(2, &third_input)); |
| // Third input (index 2) of BackpropInput must be same as the input |
| // of BiasAddGrad. |
| if (third_input == bias_add_grad_inp) { |
| found_backprop_input = true; |
| backprop_input_node = e->dst(); |
| } |
| } |
| |
| if (e->dst()->type_string() == csinfo_.conv2d_grad_filter || |
| e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) { |
| // Third input (index 2) of BackpropFilter |
| TF_CHECK_OK(e->dst()->input_node(2, &third_input)); |
| // Third input (index 2) of BackpropFilter must be same as the input |
| // of BiasAddGrad. |
| if (third_input == bias_add_grad_inp) { |
| found_backprop_filter = true; |
| backprop_filter_node = e->dst(); |
| } |
| } |
| |
| // If we found both the nodes, then we can stop the search. |
| if (found_backprop_input && found_backprop_filter) { |
| break; |
| } |
| } |
| |
| // If BackpropFilter node is not found, then this is not |
| // Conv2DWithBias context. For 2nd graph in the example above, only |
| // BackpropFilter would be present. |
| if (!found_backprop_filter) { |
| return false; |
| } |
| |
| // Otherwise, we found the nodes. |
| CHECK_NOTNULL(backprop_filter_node); |
| if (found_backprop_input) { |
| CHECK_NOTNULL(backprop_input_node); |
| } |
| |
| // Now that we confirmed that this is Conv2DWithBias context, we need to |
| // get access to the forward node (Conv2DWithBias). 2nd input of |
| // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st |
| // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter |
| // (This comes from definition of gradient computation for Conv2D). |
| if (found_backprop_input) { |
| // Graph A in the example. |
| Node* second_inp_of_input = nullptr; |
| Node* first_inp_of_filter = nullptr; |
| TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input)); |
| TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); |
| CHECK_NOTNULL(second_inp_of_input); |
| CHECK_NOTNULL(first_inp_of_filter); |
| |
// Now we need to find the Conv2DWithBias node from these input nodes.
// The Conv2DWithBias node is the node that accepts both the nodes
// second_inp_of_input and first_inp_of_filter in its 2nd and 1st input
// slots, respectively.
| for (const Edge* fe : first_inp_of_filter->out_edges()) { |
| if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && |
| fe->dst_input() == 0) { |
| for (const Edge* ie : second_inp_of_input->out_edges()) { |
| if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && |
| ie->dst_input() == 1 && fe->dst() == ie->dst()) { |
| VLOG(1) << "MklLayoutRewritePass: found " |
| << fe->dst()->DebugString() |
| << " as the forward node for matching context, backward" |
| << " node is: " << n->DebugString(); |
| *fwd_node = fe->dst(); |
| return true; |
| } |
| } |
| } |
| } |
| } else { |
| // We did not find BackpropInput, so we work with BackpropFilter only. |
| // Graph B in the example. |
| Node* first_inp_of_filter = nullptr; |
| TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); |
| CHECK_NOTNULL(first_inp_of_filter); |
| |
// Now we need to find the Conv2DWithBias node from the first input of
// BackpropFilter. The Conv2DWithBias node is the node that accepts
// first_inp_of_filter in its 1st input slot.
| for (const Edge* fe : first_inp_of_filter->out_edges()) { |
| if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && |
| fe->dst_input() == 0) { |
| VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString() |
| << " as the forward node for matching context, backward" |
| << " node is: " << n->DebugString(); |
| *fwd_node = fe->dst(); |
| return true; |
| } |
| } |
| } |
| |
| return false; |
| } |
| |
// Is the BiasAddGrad node 'n' associated with the MatMul node
// specified in contextinfo 'ci'? The function does not update fwd_node.
| // |
| // @return - true (if BiasAddGrad is associated with MatMul); |
| // false otherwise. |
| static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node, |
| void* ci) { |
| return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); |
| } |
| |
| // Rewrite rule that uses context-information for matching, |
| // used in scenario 2. |
| // |
| // @input - Node 'n' for which to search for matching context |
| // @input - The context 'c' under which to rewrite |
| // @return - true if we can rewrite node under context 'c'; |
| // false otherwise. |
| static bool ContextMatchRewrite(const Node* n, const ContextInfo* c); |
| |
| // Helper function that searches the matching contextinfo for the node. |
| // |
| // @input n - Node (gradient op) whose contextinfo is to be searched, |
| // fwd_node - pointer to node from the forward pass that this node |
| // belongs to. fwd_node cannot be NULL. |
| // @return Matching contextinfo in case a match is found; null otherwise. |
| // Also updates *fwd_node with pointer to forward node that this |
| // context matches. |
| static const ContextInfo* SearchMatchingContext(const Node* n, |
| const Node** fwd_node); |
| |
| // Rewrites input node to a new node specified by its matching rewrite info. |
| // |
| // Method first searches matching rewrite info for input node and then |
| // uses that info to rewrite. |
| // |
// Input node may be deleted in case of rewrite. Attempting to use the node
// after the call can result in undefined behavior.
| // |
| // @input g - input graph, n - Node to be rewritten, |
| // ri - matching rewriteinfo |
| // @return Status::OK(), if the input node is rewritten; |
| // Returns appropriate Status error code otherwise. |
| // Graph is updated in case the input node is rewritten. |
| // Otherwise, it is not updated. |
| Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri); |
| |
| // Get nodes that will feed a list of TF tensors to the new |
| // node that we are constructing. |
| // |
| // @input g - input graph, |
| // @input inputs - inputs to old node that we are using for constructing |
| // new inputs, |
| // @input input_idx - the index in the 'inputs' vector pointing to the |
| // current input that we have processed so far |
| // @output input_idx - index will be incremented by the number of nodes |
| // from 'inputs' that are processed |
| // @input list_length - The expected length of list of TF tensors |
| // @output output_nodes - the list of new nodes creating TF tensors |
| // |
| // @return None |
| void GetNodesProducingTFTensorList( |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
| int* input_idx, int list_length, |
| std::vector<NodeBuilder::NodeOut>* output_nodes); |
| |
| // Get nodes that will feed a list of Mkl tensors to the new |
| // node that we are constructing. |
| // |
| // @input g - input graph, |
| // @input orig_node - Original node that we are rewriting |
| // @input inputs - inputs to old node that we are using for constructing |
| // new inputs, |
| // @input input_idx - the index in the 'inputs' vector pointing to the |
| // current input that we have processed so far |
| // @output input_idx - index will be incremented by the number of nodes |
| // from 'inputs' that are processed |
| // @input list_length - The expected length of list of Mkl tensors |
| // @output output_nodes - the list of new nodes creating Mkl tensors |
| // |
| // @return None |
| void GetNodesProducingMklTensorList( |
| std::unique_ptr<Graph>* g, Node* orig_node, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
| int* input_idx, int list_length, |
| std::vector<NodeBuilder::NodeOut>* output_nodes); |
| |
| // Get a node that will feed an Mkl tensor to the new |
| // node that we are constructing. The output node could be (1) 'n' |
// if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
| // if 'n' is not an Mkl layer. |
| // |
| // @input g - input graph, |
| // @input orig_node - Original node that we are rewriting, |
| // @input n - Node based on which we are creating Mkl node, |
| // @input n_output_slot - the output slot of node 'n' |
| // which is feeding to the node that we are constructing |
| // @output mkl_node - the new node that will feed Mkl tensor |
| // @output mkl_node_output_slot - the slot number of mkl_node that |
| // will feed the tensor |
| // @return None |
| void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, |
| Node* n, int n_output_slot, Node** mkl_node, |
| int* mkl_node_output_slot); |
| |
| // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' |
| // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are |
| // set up in contiguous fashion. 'workspace_tensors' carry graph nodes |
| // producing workspace edges if 'are_workspace_tensors_available' is true. |
| // Otherwise, 'workspace_tensors' is empty vector. |
| // |
| // For details, refer to 'Ordering of inputs after rewriting' section in the |
| // documentation above. |
| // |
// Returns the number of input slots that were set up for the new node
// in 'nb'.
| int SetUpContiguousInputs( |
| std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
| NodeBuilder* nb, Node* old_node, |
| std::vector<NodeBuilder::NodeOut>* workspace_tensors, |
| bool are_workspace_tensors_available); |
| |
| // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' |
| // in graph 'g'. Original node is input in 'orig_node'. |
| // |
| // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors' |
| // section in the documentation above. |
| // |
| // Returns Status::OK() if setting up inputs is successful, otherwise |
| // returns appropriate status code. |
| Status SetUpInputs(std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
| NodeBuilder* nb, Node* orig_node); |
| |
// Add a workspace edge on the input or output side of Node 'orig_node' by
// using NodeBuilder 'nb' for the new node provided. If 'orig_node' does not
// dictate adding a workspace edge then do not add it. The workspace
// Tensorflow and Mkl tensors, if they need to be added, are returned in
// 'ws_tensors'. If workspace tensors are set, '*are_ws_tensors_added' is
// set to true.
| void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node, |
| NodeBuilder* nb, |
| std::vector<NodeBuilder::NodeOut>* ws_tensors, |
| bool* are_ws_tensors_added); |
| |
// Functions specific to operators to copy attributes.
// We need operator-specific functions to copy attributes because the
// framework does not provide any generic function for it.
| // NOTE: names are alphabetically sorted. |
| static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); |
| |
| // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, |
// using the original node 'orig_node', and return it in '*out'.
| // TODO(nhasabni) We should move this to mkl_util.h |
| void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out, |
| Node* orig_node); |
| void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out, |
| Node* orig_node); |
| }; |
| |
| MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; |
| MklLayoutRewritePass::ContextInfo |
| MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; |
| MklLayoutRewritePass::ContextInfo |
| MklLayoutRewritePass::biasaddgrad_matmul_context_; |
| std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; |
| |
// We register the Mkl rewrite pass for phase 1 in the post-partitioning
// group.
| // We register it here so that we get a complete picture of all users of Mkl |
| // nodes. Do not change the ordering of the Mkl passes. |
| const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = |
| OptimizationPassRegistry::POST_PARTITIONING; |
| REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions for creating new node |
| ////////////////////////////////////////////////////////////////////////// |
| |
| static void FillInputs(const Node* n, |
| gtl::InlinedVector<Node*, 4>* control_edges, |
| gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { |
| control_edges->clear(); |
| for (const Edge* e : n->in_edges()) { |
| if (e->IsControlEdge()) { |
| control_edges->push_back(e->src()); |
| } else { |
| (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); |
| } |
| } |
| std::sort(control_edges->begin(), control_edges->end()); |
| if (n->op_def().is_commutative()) { |
| // For commutative inputs, we sort the input by the input Node* |
| // to get a canonical ordering (so that add(a,b) and add(b, a) will |
| // hash to the same value if is_commutative is true for 'add'). |
| std::sort(in->begin(), in->end()); |
| } |
| } |
| |
| void MklLayoutRewritePass::GetNodesProducingTFTensorList( |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, |
| int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { |
| CHECK_LT(*input_idx, inputs.size()); |
| CHECK_GT(list_length, 0); |
| CHECK_NOTNULL(output_nodes); |
| output_nodes->reserve(list_length); |
| |
| while (list_length != 0) { |
| CHECK_GT(list_length, 0); |
| CHECK_LT(*input_idx, inputs.size()); |
| Node* n = inputs[*input_idx].first; |
| int slot = inputs[*input_idx].second; |
| // If input node 'n' is just producing a single tensor at |
| // output slot 'slot' then we just add that single node. |
| output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); |
| (*input_idx)++; |
| list_length--; |
| } |
| } |
| |
| // TODO(nhasabni) We should move this to mkl_util.h. |
| void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, |
| Node** out, Node* orig_node) { |
| // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent |
| // dummy Mkl tensor. 8 = 2*size_t. |
| const DataType dt = DataTypeToEnum<uint8>::v(); |
| TensorProto proto; |
| proto.set_dtype(dt); |
| uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; |
| proto.set_tensor_content(string(reinterpret_cast<const char*>(zero), 8)); |
| TensorShape dummy_shape({8}); |
| dummy_shape.AsProto(proto.mutable_tensor_shape()); |
| TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") |
| .Attr("value", proto) |
| .Attr("dtype", dt) |
| .Device(orig_node->def().device()) // We place this node on |
| // the same device as the |
| // device of the original |
| // node. |
| .Finalize(&**g, out)); |
| CHECK_NOTNULL(*out); // Make sure we got a valid object before using it |
| |
| // If number of inputs to the original node is > 0, then we add |
| // control dependency between 1st input (index 0) of the original node and |
| // the dummy Mkl node. This is needed because control-flow ops such as Enter, |
| // Merge, etc, require frame_name of the dummy Mkl node to be same as the |
| // rewritten node. Adding control edge between 1st input of the original node |
| // and the dummy Mkl node ensures that the dummy node is in the same frame |
| // as the original node. Choosing 1st input is not necessary - any input of |
| // the original node is fine because all the inputs of a node are always in |
| // the same frame. |
| if (orig_node->num_inputs() > 0) { |
| Node* orig_input0 = nullptr; |
| TF_CHECK_OK( |
| orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); |
| CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); |
| } |
| |
| (*out)->set_assigned_device_name(orig_node->assigned_device_name()); |
| } |
| |
| void MklLayoutRewritePass::GetNodesProducingMklTensorList( |
| std::unique_ptr<Graph>* g, Node* orig_node, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, |
| int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { |
| CHECK_LT(*input_idx, inputs.size()); |
| CHECK_GT(list_length, 0); |
| CHECK_NOTNULL(output_nodes); |
| output_nodes->reserve(list_length); |
| |
| while (list_length != 0) { |
| CHECK_GT(list_length, 0); |
| CHECK_LT(*input_idx, inputs.size()); |
| Node* n = inputs[*input_idx].first; |
| int slot = inputs[*input_idx].second; |
| // If 'n' is producing a single tensor, then create a single Mkl tensor |
| // node. |
| Node* mkl_node = nullptr; |
| int mkl_node_output_slot = 0; |
| GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, |
| &mkl_node_output_slot); |
| output_nodes->push_back( |
| NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); |
| (*input_idx)++; |
| list_length--; |
| } |
| } |
| |
| // Get an input node that will feed Mkl tensor to the new |
| // node that we are constructing. An input node could be (1) 'n' |
// if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
| // if 'n' is not an Mkl layer. |
| void MklLayoutRewritePass::GetNodeProducingMklTensor( |
| std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, |
| Node** mkl_node, int* mkl_node_output_slot) { |
| CHECK_NOTNULL(n); |
| CHECK_NOTNULL(mkl_node); |
| CHECK_NOTNULL(mkl_node_output_slot); |
| |
| // If this is an MKL op, then it will create extra output for MKL layout. |
| DataType T; |
| if (GetNodeAttr(n->def(), "T", &T).ok() && |
| mkl_op_registry::IsMklOp(n->type_string(), T)) { |
// If this is an MKL op, then it produces the Mkl tensor on an additional
// output from which the new node can receive it.
// The output slot number for the Mkl tensor would be N + the slot number
// of the TensorFlow tensor, where N is the total number of TensorFlow
// tensors.
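// For illustration (assuming contiguous ordering): if 'n' has 4 outputs in
// total (2 Tensorflow tensors followed by their 2 Mkl tensors) and
// n_output_slot is 0, then the Mkl tensor for that output is at slot 2.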
| *mkl_node = n; |
| *mkl_node_output_slot = |
| GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); |
| } else { |
| // If we have not visited the node and rewritten it, then we need |
| // to create a dummy node that will feed a dummy Mkl tensor to this node. |
| // DummyMklTensor node has no input and generates only 1 output |
| // (dummy Mkl tensor) as output slot number 0. |
| GetDummyMklTensorNode(g, mkl_node, orig_node); |
| CHECK_NOTNULL(*mkl_node); |
| *mkl_node_output_slot = 0; |
| } |
| } |
| |
| int MklLayoutRewritePass::SetUpContiguousInputs( |
| std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
| NodeBuilder* nb, Node* old_node, |
| std::vector<NodeBuilder::NodeOut>* workspace_tensors, |
| bool are_workspace_tensors_available) { |
| CHECK_NOTNULL(workspace_tensors); |
| CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |
| |
| // TODO(nhasabni): Temporary solution to connect filter input of |
| // BackpropInput with the converted filter from Conv2D. |
| bool do_connect_conv2d_backprop_input_filter = false; |
| Node* conv2d_node = nullptr; |
| // Filter node is 2nd input (slot index 1) of Conv2D. |
| int kConv2DFilterInputSlotIdx = 1; |
| int kConv2DBackpropInputFilterInputSlotIdx = 1; |
| int kConv2DFilterOutputSlotIdx = 1; |
| if (old_node->type_string() == csinfo_.conv2d_grad_input) { |
| // We need to find Conv2D node from Conv2DBackpropInput. |
| // For that let's first find filter node that is 2nd input (slot 1) |
| // of BackpropInput. |
| Node* filter_node = nullptr; |
| TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, |
| &filter_node)); |
| CHECK_NOTNULL(filter_node); |
| |
| // Now check which nodes receive from filter_node. Filter feeds as |
| // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias. |
| for (const Edge* e : filter_node->out_edges()) { |
| if (e->dst()->type_string() == csinfo_.mkl_conv2d && |
| e->dst_input() == kConv2DFilterInputSlotIdx |
| /* filter is 2nd input of Conv2D and _MklConv2D. */) { |
| if (conv2d_node != nullptr) { |
| VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" |
| << " feeding multiple Conv2D nodes: " |
| << filter_node->DebugString(); |
| // We will not connect filter input of Conv2DBackpropInput |
| // to be safe here. |
| do_connect_conv2d_backprop_input_filter = false; |
| break; |
| } else { |
| conv2d_node = e->dst(); |
| do_connect_conv2d_backprop_input_filter = true; |
| } |
| } |
| } |
| } |
| |
| // Number of input slots to original op |
| // Input slots are represented by .Input() calls in REGISTER_OP. |
| int old_node_input_slots = old_node->op_def().input_arg_size(); |
| // Actual number of inputs can be greater than or equal to number |
| // of Input slots because inputs of type list could be unfolded. |
| CHECK_GE(old_node_inputs.size(), old_node_input_slots); |
| int nn_slot_idx = 0; // slot index for inputs of new node |
| |
| // Let's copy all inputs (TF tensors) of original node to new node. |
| int iidx = 0; |
| for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { |
| // An input slot could be a single tensor or a list. We need |
| // to handle this case accordingly. |
| CHECK_LT(iidx, old_node_inputs.size()); |
| const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); |
| if (ArgIsList(arg)) { |
| std::vector<NodeBuilder::NodeOut> new_node_inputs; |
| int N = GetTensorListLength(arg, old_node); |
| GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, |
| &new_node_inputs); |
| nb->Input(new_node_inputs); |
| nn_slot_idx++; |
| } else { |
| // Special case for connecting filter input of Conv2DBackpropInput |
| if (do_connect_conv2d_backprop_input_filter && |
| iidx == kConv2DBackpropInputFilterInputSlotIdx) { |
| nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); |
| } else { |
| nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); |
| } |
| iidx++; |
| nn_slot_idx++; |
| } |
| } |
| |
| // If workspace tensors are available for this op and we are using |
| // contiguous ordering then we need to add Tensorflow tensor for |
| // workspace here because Tensorflow tensor for workspace is the |
| // last tensor in the list of Tensorflow tensors. |
| if (are_workspace_tensors_available) { |
| CHECK_EQ(workspace_tensors->size(), 2); |
| // Tensorflow tensor |
| nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); |
| nn_slot_idx++; |
| } |
| |
| // Let's now setup all Mkl inputs to new node. |
| // Number of Mkl inputs must be same as number of TF inputs. |
| iidx = 0; |
| for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { |
| // An input slot could be a single tensor or a list. We need |
| // to handle this case accordingly. |
| CHECK_LT(iidx, old_node_inputs.size()); |
| const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); |
| if (ArgIsList(arg)) { |
| std::vector<NodeBuilder::NodeOut> new_node_inputs; |
| int N = GetTensorListLength(arg, old_node); |
| GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, |
| &new_node_inputs); |
| nb->Input(new_node_inputs); |
| nn_slot_idx++; |
| } else { |
| Node* mkl_node = nullptr; |
| int mkl_node_output_slot = 0; |
| // Special case for connecting filter input of Conv2DBackpropInput |
| if (do_connect_conv2d_backprop_input_filter && |
| iidx == kConv2DBackpropInputFilterInputSlotIdx) { |
| GetNodeProducingMklTensor(g, old_node, conv2d_node, |
| kConv2DFilterOutputSlotIdx, &mkl_node, |
| &mkl_node_output_slot); |
| } else { |
| GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, |
| old_node_inputs[iidx].second, &mkl_node, |
| &mkl_node_output_slot); |
| } |
| nb->Input(mkl_node, mkl_node_output_slot); |
| iidx++; |
| nn_slot_idx++; |
| } |
| } |
| |
| // If workspace tensors are available for this op and we are using |
| // contiguous ordering then we need to add Mkl tensor for |
| // workspace here because Mkl tensor for workspace is the |
| // last tensor in the list of Mkl tensors. |
| if (are_workspace_tensors_available) { |
| CHECK_EQ(workspace_tensors->size(), 2); |
| // Mkl tensor |
| nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index); |
| nn_slot_idx++; |
| } |
| |
| return nn_slot_idx; |
| } |
| |
| Status MklLayoutRewritePass::SetUpInputs( |
| std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
| NodeBuilder* nb, Node* old_node) { |
| // Let's check if we need to add workspace tensors for this node. |
| // We add workspace edge only for MaxPool, LRN and BatchNorm. |
| std::vector<NodeBuilder::NodeOut> workspace_tensors; |
| bool are_workspace_tensors_available = false; |
| AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors, |
| &are_workspace_tensors_available); |
| |
| int new_node_input_slots = 0; |
| if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { |
// TODO(nhasabni): implement this function just for the sake of completeness.
| // We do not use interleaved ordering right now. |
| return Status( |
| error::Code::UNIMPLEMENTED, |
| "Interleaved ordering of tensors is currently not supported."); |
| } else { |
| CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |
| new_node_input_slots = SetUpContiguousInputs( |
| g, old_node_inputs, nb, old_node, &workspace_tensors, |
| are_workspace_tensors_available); |
| } |
| |
| // Sanity check |
| int old_node_input_slots = old_node->op_def().input_arg_size(); |
| if (!are_workspace_tensors_available) { |
| // If we are not adding workspace tensors for this op, then the total |
| // number of input slots to the new node _must_ be 2 times the number |
| // of input slots to the original node: N original Tensorflow tensors and |
// N Mkl tensors corresponding to the Tensorflow tensors.
| CHECK_EQ(new_node_input_slots, old_node_input_slots * 2); |
| } else { |
// If we are adding workspace tensors for this op, then the total
// number of input slots to the new node _must_ be 2 times the number
// of input slots to the original node plus 2: N original Tensorflow
// tensors, N Mkl tensors corresponding to the Tensorflow tensors, and 2
// more (the workspace Tensorflow tensor and the workspace Mkl tensor).
| CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2); |
| } |
| |
| return Status::OK(); |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions related to workspace pass |
| ////////////////////////////////////////////////////////////////////////// |
| |
| // TODO(nhasabni) We should move this to mkl_util.h. |
| void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( |
| std::unique_ptr<Graph>* g, Node** out, Node* orig_node) { |
| // We use a tensor of shape {1} and value 0 to represent |
| // dummy float tensor. We need this as a dummy workspace tensor. |
| // Workspace tensor has type float. |
| const DataType dt = DataTypeToEnum<float>::v(); |
| TensorProto proto; |
| proto.set_dtype(dt); |
| float zero[1] = {0}; |
| proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 4)); |
| TensorShape dummy_shape({1}); |
| dummy_shape.AsProto(proto.mutable_tensor_shape()); |
| TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") |
| .Attr("value", proto) |
| .Attr("dtype", dt) |
| .Device(orig_node->def().device()) // We place this node on the |
| // same device as the original node. |
| .Finalize(&**g, out)); |
| CHECK_NOTNULL(*out); // Make sure we got a valid object before using it |
| |
| // If the number of inputs to the original node is > 0, then we add a |
| // control dependency between the 1st input (index 0) of the original node |
| // and the dummy Mkl node. This is needed because control-flow ops such as |
| // Enter and Merge require the frame_name of the dummy Mkl node to be the |
| // same as that of the rewritten node. Adding a control edge between the |
| // 1st input of the original node and the dummy Mkl node ensures that the |
| // dummy node is in the same frame as the original node. Choosing the 1st |
| // input is not strictly necessary - any input of the original node works, |
| // because all the inputs of a node are always in the same frame. |
| if (orig_node->num_inputs() > 0) { |
| Node* orig_input0 = nullptr; |
| TF_CHECK_OK( |
| orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); |
| CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); |
| } |
| |
| (*out)->set_assigned_device_name(orig_node->assigned_device_name()); |
| } |
| |
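| // Decide whether a workspace edge is needed for 'orig_node', set the |
| // 'workspace_enabled' attribute on 'nb' accordingly, and, for backward |
| // ops, append the workspace tensor and its Mkl tensor to 'ws_tensors' |
| // (setting '*are_ws_tensors_added' to true). |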
| void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( |
| std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb, |
| std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) { |
| bool workspace_edge_added = false; // Default initializer |
| CHECK_NOTNULL(are_ws_tensors_added); |
| *are_ws_tensors_added = false; // Default initializer |
| |
| DataType T; |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| for (auto ws : wsinfo_) { |
| if (orig_node->type_string() == ws.fwd_op && |
| mkl_op_registry::IsMklOp( |
| mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { |
| // If this op is a fwd op, then we need to check if there is an |
| // edge from this node's fwd_slot to the bwd op's bwd_slot. If there |
| // is such an edge, then we just set the 'workspace_enabled' attribute |
| // on this node to true. We don't add an actual workspace edge in |
| // this node. The actual workspace edge gets added in the backward |
| // op for this node. |
| for (const Edge* e : orig_node->out_edges()) { |
| if (e->src_output() == ws.fwd_slot && |
| e->dst()->type_string() == ws.bwd_op && |
| e->dst_input() == ws.bwd_slot) { |
| nb->Attr("workspace_enabled", true); |
| VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " |
| << orig_node->type_string(); |
| workspace_edge_added = true; |
| // We found the edge that we were looking for, so break. |
| break; |
| } |
| } |
| |
| if (!workspace_edge_added) { |
| // If we are here, then we did not find the backward operator for this |
| // node. |
| nb->Attr("workspace_enabled", false); |
| } |
| } else if (orig_node->type_string() == ws.bwd_op && |
| mkl_op_registry::IsMklOp( |
| mkl_op_registry::GetMklOpName(orig_node->type_string()), |
| T)) { |
| // If this op is a bwd op, then we need to add the workspace edge and |
| // its Mkl tensor edge between its corresponding fwd op and this |
| // op. The corresponding fwd op is specified in the 'fwd_op' field of |
| // the workspace info. fwd_slot and bwd_slot in the workspace info |
| // specify which slots of the forward and backward ops the edge connects. |
| // Once all these criteria match, we add a workspace edge between |
| // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is |
| // determined by interleaved/contiguous ordering. Function |
| // DataIndexToMetaDataIndex tells us the location of Mkl tensor |
| // from the location of the Tensorflow tensor. |
| for (const Edge* e : orig_node->in_edges()) { |
| if (e->src_output() == ws.fwd_slot && |
| // We would have rewritten the forward op, so we need to use |
| // GetMklOpName call to get its Mkl name. |
| e->src()->type_string() == |
| mkl_op_registry::GetMklOpName(ws.fwd_op) && |
| e->dst_input() == ws.bwd_slot) { |
| nb->Attr("workspace_enabled", true); |
| CHECK_NOTNULL(ws_tensors); |
| // Add workspace edge between fwd op and bwd op. |
| ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); |
| // Add Mkl tensor edge for workspace edge between fwd op and bwd op. |
| ws_tensors->push_back(NodeBuilder::NodeOut( |
| e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, |
| e->src()->num_outputs()))); |
| *are_ws_tensors_added = true; |
| // In terms of input ordering, the workspace edge (and its Mkl tensor) |
| // is the last input of both the fwd op and the bwd op, so all inputs |
| // before the workspace tensor are added by the SetUpInputs function. |
| VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " |
| << orig_node->type_string(); |
| workspace_edge_added = true; |
| // We found the edge that we were looking for, so break. |
| break; |
| } |
| } |
| |
| // If we are here, it means we did not find a fwd op that feeds this |
| // bwd op. In this case, we need to generate a dummy tensor for the |
| // workspace input and a dummy Mkl tensor for the workspace, and set |
| // workspace_enabled to false. |
| if (!workspace_edge_added) { |
| nb->Attr("workspace_enabled", false); |
| Node* dmt_ws = nullptr; // Dummy tensor for workspace |
| Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace |
| GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); |
| GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); |
| CHECK_NOTNULL(dmt_ws); |
| CHECK_NOTNULL(dmt_mkl_ws); |
| CHECK_NOTNULL(ws_tensors); |
| // We add dummy tensor as workspace tensor. |
| ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); |
| // We add dummy tensor as Mkl tensor for workspace tensor. |
| ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); |
| *are_ws_tensors_added = true; |
| VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " |
| << orig_node->type_string(); |
| } |
| } else { |
| // If this node does not match any workspace info, then we do not |
| // do anything special for workspace propagation for it. |
| } |
| } |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Op-specific functions to copy attributes from old node to new node |
| ////////////////////////////////////////////////////////////////////////// |
| |
| void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| string padding; |
| std::vector<int32> strides; |
| bool use_cudnn_on_gpu; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| TF_CHECK_OK( |
| GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("strides", strides); |
| nb->Attr("padding", padding); |
| nb->Attr("data_format", data_format); |
| nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int N; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("N", N); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| std::vector<int32> strides; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("strides", strides); |
| nb->Attr("data_format", data_format); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int depth_radius; |
| float bias; |
| float alpha; |
| float beta; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("depth_radius", depth_radius); |
| nb->Attr("bias", bias); |
| nb->Attr("alpha", alpha); |
| nb->Attr("beta", beta); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| string padding; |
| std::vector<int32> ksize, strides; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("ksize", ksize); |
| nb->Attr("strides", strides); |
| nb->Attr("padding", padding); |
| nb->Attr("data_format", data_format); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| DataType Tshape; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("Tshape", Tshape); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| int num_split; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("num_split", num_split); |
| nb->Attr("data_format", data_format); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int N; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("N", N); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int N; |
| DataType tidx; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("N", N); |
| nb->Attr("Tidx", tidx); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| float epsilon; |
| string data_format; |
| bool is_training; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("epsilon", epsilon); |
| nb->Attr("data_format", data_format); |
| nb->Attr("is_training", is_training); |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions related to node merge pass |
| ////////////////////////////////////////////////////////////////////////// |
| |
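| // Find a node that can be merged with input node 'a' according to the |
| // registered merge infos (minfo_). Returns the candidate node, or nullptr |
| // if no such node exists or the control edges of the two nodes differ. |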
| Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { |
| // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite |
| // once we support BiasAddGrad as Mkl layer. |
| |
| // Search for all matching mergeinfo. |
| // We allow more than one match for extensibility. |
| std::vector<const MergeInfo*> matching_mi; |
| for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { |
| if (a->type_string() == mi->succ) { |
| matching_mi.push_back(&*mi); |
| } |
| } |
| |
| for (const MergeInfo* mi : matching_mi) { |
| const int N_in = a->num_inputs(); |
| if (mi->op >= N_in) { |
| continue; |
| } |
| |
| // Get the control edges and input of node |
| gtl::InlinedVector<Node*, 4> a_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); |
| FillInputs(a, &a_control_edges, &a_in); |
| |
| // Get operand op of the operator |
| Node* b = nullptr; |
| b = a_in[mi->op].first; |
| if (b == nullptr || (b->type_string() != mi->pred)) { |
| // NOTE: Should the first check be assert? |
| continue; |
| } |
| |
| const int B_in = b->num_inputs(); |
| gtl::InlinedVector<Node*, 4> b_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); |
| FillInputs(b, &b_control_edges, &b_in); |
| |
| // Shouldn't merge if a and b have different control edges. |
| if (a_control_edges != b_control_edges) { |
| continue; |
| } else { |
| // We found a match. |
| return b; |
| } |
| } |
| |
| return nullptr; |
| } |
| |
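| // Merge 'pred' (Conv2D) and 'succ' (BiasAdd) into a single |
| // _MklConv2DWithBias node: set up its inputs and attributes, transfer |
| // control and data edges, and remove the original nodes from the graph. |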
| Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, |
| Node* pred) { |
| CHECK_NOTNULL(succ); |
| CHECK_NOTNULL(pred); |
| |
| if (succ->type_string() == csinfo_.bias_add && |
| pred->type_string() == csinfo_.mkl_conv2d) { |
| // 1. Get all attributes from input nodes. |
| DataType T_pred, T_succ; |
| string padding; |
| std::vector<int32> strides; |
| string data_format_pred, data_format_succ; |
| bool use_cudnn_on_gpu; |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred)); |
| TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); |
| TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); |
| TF_CHECK_OK( |
| GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); |
| // We check to ensure that data formats of both succ and pred are same. |
| // We expect them to be same, so we can enforce this as assert. |
| // But assert can be too strict, so we enforce this as a check. |
| // If the check fails, then we do not merge two nodes. |
| // We also do same check for devices. |
| if (data_format_pred != data_format_succ || T_pred != T_succ || |
| pred->assigned_device_name() != succ->assigned_device_name() || |
| pred->def().device() != succ->def().device()) { |
| return Status(error::Code::INVALID_ARGUMENT, |
| "data_format or T attribute or devices of Conv2D and " |
| "BiasAdd do not match. Will skip node merge optimization"); |
| } |
| |
| const int succ_num = succ->num_inputs(); |
| gtl::InlinedVector<Node*, 4> succ_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); |
| FillInputs(succ, &succ_control_edges, &succ_in); |
| |
| const int pred_num = pred->num_inputs(); |
| gtl::InlinedVector<Node*, 4> pred_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); |
| FillInputs(pred, &pred_control_edges, &pred_in); |
| |
| // We need to ensure that there is only 1 edge between Conv2D and BiasAdd. |
| // Otherwise, merging is semantically incorrect. |
| if (pred->out_edges().size() != 1) { |
| return Status(error::Code::INVALID_ARGUMENT, |
| "Conv2D has multiple outputs." |
| "Will skip node merge optimization"); |
| } |
| |
| for (const Edge* e : pred->out_edges()) { |
| if (e->dst() != succ) { |
| return Status(error::Code::INVALID_ARGUMENT, |
| "Conv2D does not feed to BiasAdd." |
| "Will skip node merge optimization"); |
| } |
| } |
| |
| // 2. Get inputs from both the nodes. |
| // Find the 2 inputs of Conv2D and the bias input of BiasAdd. |
| // Get operands 0 and 1 of Conv2D and their Mkl tensors. |
| CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs. |
| // Get operand 1 of add_bias |
| // BiasAdd must have 2 inputs: Conv, bias |
| CHECK_EQ(succ->in_edges().size(), 2); |
| Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3 |
| int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0. |
| GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node |
| // as BiasAdd does not have Mkl tensor as input. |
| CHECK_NOTNULL(oper3_mkl); |
| |
| // Build the new node. We use the node name of BiasAdd as the name of the |
| // new node, and only change the op name. |
| NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias); |
| if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { |
| nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D |
| // pred_in[1] will be Mkl tensor for In1 if we follow interleaved |
| // ordering, and it will be 2nd Tensorflow tensor for Conv2D if |
| // we follow contiguous ordering. |
| nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1 |
| nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D |
| nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 |
| nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd |
| nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd |
| } else { |
| CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |
| nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D |
| // pred_in[1] will be Mkl tensor for In1 if we follow interleaved |
| // ordering, and it will be 2nd Tensorflow tensor for Conv2D if |
| // we follow contiguous ordering. |
| nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D |
| nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd |
| nb.Input(pred_in[2].first, pred_in[2].second); // Mkl for In1 of Conv2D |
| nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 of Conv2D |
| nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd |
| } |
| |
| // Copy attributes from Conv2D to Conv2DWithBias. |
| CopyAttrsConv2D(const_cast<const Node*>(pred), &nb); |
| |
| // Copy the device assigned to old node to new node. |
| nb.Device(succ->def().device()); |
| |
| // Create node. |
| Node* new_node; |
| TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
| CHECK_NOTNULL(new_node); |
| |
| // Set the Mkl layer label for this op. |
| new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); |
| |
| // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' |
| // node are already copied in BuildNode. We handle control edges now. |
| for (const Edge* e : pred->in_edges()) { |
| if (e->IsControlEdge()) { |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); |
| } |
| } |
| for (const Edge* e : succ->in_edges()) { |
| if (e->IsControlEdge()) { |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); |
| } |
| } |
| |
| // Incoming edges are fixed; we will fix the outgoing edges now. |
| // First, we will fix outgoing control edges from 'pred' node. |
| // We don't need to handle outgoing data edges from 'pred' node |
| // because pred has only 1 output going to succ node (we enforced |
| // this check for merge already). |
| for (const Edge* e : pred->out_edges()) { |
| if (e->IsControlEdge()) { |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); |
| } |
| } |
| |
| // Second, we will fix outgoing control and data edges from 'succ' node. |
| for (const Edge* e : succ->out_edges()) { |
| if (e->IsControlEdge()) { |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); |
| } else { |
| CHECK_NOTNULL( |
| (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input())); |
| } |
| } |
| |
| // Copy device assigned to old node to new node. |
| // It's ok to use pred or succ as we have enforced a check that |
| // both have same device assigned. |
| new_node->set_assigned_device_name(pred->assigned_device_name()); |
| |
| VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() |
| << ", and node: " << succ->DebugString() |
| << ", into node:" << new_node->DebugString(); |
| |
| (*g)->RemoveNode(succ); |
| (*g)->RemoveNode(pred); |
| |
| return Status::OK(); |
| } |
| |
| return Status(error::Code::UNIMPLEMENTED, |
| "Unimplemented case for node merge optimization."); |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions for node rewrite |
| ////////////////////////////////////////////////////////////////////////// |
| |
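| // Rewrite 'orig_node' into a new Mkl node as specified by 'ri': set up |
| // its Tensorflow and Mkl tensor inputs, copy attributes (using the |
| // forward node's attributes for a context-based rewrite), transfer |
| // control and data edges, and remove the original node from the graph. |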
| Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, |
| Node* orig_node, |
| const RewriteInfo* ri) { |
| CHECK_NOTNULL(ri); |
| CHECK_NOTNULL(orig_node); |
| |
| VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString(); |
| |
| // Check if this is scenario 2 (context-based rewrite). |
| // Get the matching ContextInfo if it is. |
| const Node* fwd_node = nullptr; |
| const ContextInfo* ci = nullptr; |
| bool is_context_based_rewrite = false; |
| if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) { |
| is_context_based_rewrite = true; |
| |
| // Sanity checks for context-based rewrite (if any) |
| if (orig_node->type_string() == csinfo_.bias_add_grad && |
| ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { |
| CHECK_NOTNULL(fwd_node); |
| DataType orig_T, ctx_T; |
| string orig_data_format, ctx_data_format; |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T)); |
| TF_CHECK_OK( |
| GetNodeAttr(orig_node->def(), "data_format", &orig_data_format)); |
| TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T)); |
| TF_CHECK_OK( |
| GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format)); |
| |
| if (orig_data_format != ctx_data_format || orig_T != ctx_T || |
| orig_node->assigned_device_name() != |
| fwd_node->assigned_device_name() || |
| orig_node->def().device() != fwd_node->def().device()) { |
| return Status( |
| error::Code::INVALID_ARGUMENT, |
| "data_format or T attribute or devices of BiasAddGrad and " |
| "Conv2D do not match. Will skip node rewrite optimization"); |
| } |
| } else if (orig_node->type_string() == csinfo_.bias_add_grad && |
| ri->new_name == csinfo_.matmul) { |
| // When BiasAddGrad has MatMul in its context, we do not do any rewrite |
| // and leave BiasAddGrad as it is. But this condition is already checked |
| // by the node rewrite rule, so we should never reach this point for a |
| // MatMul context. Fail if we do. |
| return Status( |
| error::Code::INVALID_ARGUMENT, |
| "No rewrite is required for BiasAddGrad for MatMul context."); |
| } |
| } |
| |
| // Get all inputs. |
| int num_inputs = orig_node->in_edges().size(); |
| |
| // Drop count for control edges from inputs |
| for (const Edge* e : orig_node->in_edges()) { |
| if (e->IsControlEdge()) { |
| num_inputs--; |
| } |
| } |
| |
| gtl::InlinedVector<Node*, 4> control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); |
| FillInputs(orig_node, &control_edges, &inputs); |
| |
| // Build the new node. We use the same name as the original node, but |
| // change the op name. |
| NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); |
| // Copy user-specified device assigned to original node to new node. |
| nb.Device(orig_node->def().device()); |
| // Set up new inputs to the rewritten node. |
| Status s = SetUpInputs(g, inputs, &nb, orig_node); |
| if (s != Status::OK()) { |
| return s; |
| } |
| |
| // Copy attributes from original node to new node (for scenario 1). |
| // For context-based rewrite, we use context to copy the attributes. |
| if (is_context_based_rewrite) { |
| if (orig_node->type_string() == csinfo_.bias_add_grad && |
| ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { |
| CHECK_NOTNULL(fwd_node); |
| ri->copy_attrs(fwd_node, &nb); |
| } else { |
| return Status(error::Code::UNIMPLEMENTED, |
| "Unimplemented case for node rewrite optimization."); |
| } |
| } else { |
| ri->copy_attrs(const_cast<const Node*>(orig_node), &nb); |
| } |
| // Set the Mkl layer label for this op. |
| nb.Attr("_kernel", mkl_op_registry::kMklOpLabel); |
| |
| // Finalize graph and get new node. |
| Node* new_node = nullptr; |
| TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
| CHECK_NOTNULL(new_node); |
| |
| // Incoming data edges from 'orig_node' node to new 'new_node' node are |
| // already copied in BuildNode. We need to handle control edges now. |
| for (const Edge* e : orig_node->in_edges()) { |
| if (e->IsControlEdge()) { |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); |
| } |
| } |
| |
| // Copy outgoing edges from the 'orig_node' node to the new 'new_node' |
| // node, since the outputs also follow the same ordering among |
| // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow |
| // tensors appropriately. Specifically, nth output of the original node |
| // will become 2*nth output of the Mkl node for the interleaved ordering |
| // of the tensors. For the contiguous ordering of the tensors, it will be n. |
| // GetTensorDataIndex provides this mapping function. |
| for (const Edge* e : orig_node->out_edges()) { |
| if (e->IsControlEdge()) { |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); |
| } else { |
| CHECK_NOTNULL((*g)->AddEdge( |
| new_node, |
| GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), |
| e->dst(), e->dst_input())); |
| } |
| } |
| |
| // Copy the runtime device assigned to the original node to the new node. |
| new_node->set_assigned_device_name(orig_node->assigned_device_name()); |
| |
| // Delete original node and mark new node as rewritten. |
| (*g)->RemoveNode(orig_node); |
| |
| VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); |
| return Status::OK(); |
| } |
| |
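| // Search the registered context infos (cinfo_) for one matching node 'n'. |
| // On a match, '*fwd_node' is set to the corresponding forward node and |
| // the matching ContextInfo is returned; otherwise nullptr is returned. |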
| const MklLayoutRewritePass::ContextInfo* |
| MklLayoutRewritePass::SearchMatchingContext(const Node* n, |
| const Node** fwd_node) { |
| CHECK_NOTNULL(n); |
| CHECK_NOTNULL(fwd_node); |
| *fwd_node = nullptr; |
| |
| // Search for a matching contextinfo based on the node name and call the |
| // callback function using the matching contextinfo. |
| // There could be more than one matching contextinfo, but whichever |
| // matches first is returned. |
| for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) { |
| if (n->type_string() == (*ci)->node && |
| (*ci)->context_match_fn(n, fwd_node, *ci)) { |
| VLOG(1) << "Found context as matching: " << (*ci)->fwd; |
| return *ci; |
| } |
| } |
| return nullptr; |
| } |
| |
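| // Rewrite rule that holds when the context found for node 'n' by |
| // SearchMatchingContext is exactly the given ContextInfo 'c'. |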
| bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n, |
| const ContextInfo* c) { |
| const Node* fwd_node = nullptr; |
| return SearchMatchingContext(n, &fwd_node) == c; |
| } |
| |
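| // Check if node 'n' should be rewritten into an Mkl node. Returns the |
| // matching RewriteInfo if so, and nullptr if the node should be left |
| // as it is. |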
| const MklLayoutRewritePass::RewriteInfo* |
| MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { |
| CHECK_NOTNULL(n); |
| |
| // First check if the node, along with its type, is supported by the MKL |
| // layer. We do not want to rewrite an op into an Mkl op if its type is |
| // not supported. E.g., MklRelu does not support INT32, so we cannot |
| // rewrite Relu to MklRelu if the type is INT32. |
| DataType T; |
| if (!GetNodeAttr(n->def(), "T", &T).ok()) { |
| return nullptr; |
| } |
| |
| // BiasAddGrad is not an Mkl layer, so we make an exception for it. |
| if (n->type_string() != csinfo_.bias_add_grad) { |
| if (!mkl_op_registry::IsMklOp( |
| mkl_op_registry::GetMklOpName(n->type_string()), T)) { |
| return nullptr; |
| } |
| } |
| |
| // For an elementwise node, we reuse the Eigen implementation and pass the |
| // MKL metadata tensor through so we can avoid conversions. However, if all |
| // incoming edges are in TF format, we don't need all this overhead, so we |
| // replace the elementwise node only if at least one of its parents is an |
| // MKL node. |
| // |
| // TODO(vrane): Add implementation for element-wise ops that doesn't reuse |
| // eigen code to reduce cross-library dependency. |
| if (mkl_op_registry::IsMklElementWiseOp( |
| mkl_op_registry::GetMklOpName(n->type_string()), T)) { |
| bool incoming_mkl_edge = false; |
| for (auto parent : n->in_edges()) { |
| if (mkl_op_registry::IsMklOp( |
| mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) { |
| incoming_mkl_edge = true; |
| break; |
| } else { |
| VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string(); |
| } |
| } |
| if (incoming_mkl_edge == false) { |
| VLOG(1) << "Skipping replacement of elementwise node which has no MKL " |
| "parents."; |
| return nullptr; |
| } |
| } |
| |
| // We support 2 types of node rewrites: |
| // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context. |
| // 2. Always rewriting an op to its Mkl op. |
| // We return the matching RewriteInfo if either of these conditions is met. |
| |
| // Find matching RewriteInfo and then check that rewrite rule applies. |
| for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { |
| if (n->type_string().compare(ri->name) == 0 && |
| ri->rewrite_rule(n, ri->context)) { |
| // If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context, |
| // then we just return directly. |
| if (n->type_string() == csinfo_.bias_add_grad && |
| ri->context->fwd == csinfo_.matmul && |
| ri->new_name == csinfo_.bias_add_grad) { |
| return nullptr; |
| } |
| return &*ri; |
| } |
| } |
| |
| // Else return not found. |
| return nullptr; |
| } |
| |
| /////////////////////////////////////////////////////////////////////////////// |
| // Run function for the pass |
| /////////////////////////////////////////////////////////////////////////////// |
| |
| bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { |
| bool result = false; |
| CHECK_NOTNULL(g); |
| |
| DumpGraph("Before running MklLayoutRewritePass", &**g); |
| |
| std::vector<Node*> order; |
| GetReversePostOrder(**g, &order); // This will give us topological sort. |
| |
| for (Node* n : order) { |
| // If node is not an op or it cannot run on CPU device, then skip. |
| if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
| continue; |
| } |
| |
| const RewriteInfo* ri = nullptr; |
| Node* predn = nullptr; |
| // We will first search if node is to be rewritten |
| if ((ri = CheckForNodeRewrite(n)) != nullptr) { |
| string node_name = n->name(); |
| string op_name = n->type_string(); |
| |
| VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name |
| << " with op " << op_name << " for rewrite using" |
| << " layout optimization."; |
| |
| if (RewriteNode(g, n, ri) == Status::OK()) { |
| VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name |
| << " with op " << op_name << " for Mkl layout optimization."; |
| result = true; |
| } |
| } else if ((predn = CheckForNodeMerge(n)) != nullptr) { |
| // Otherwise, we will check if the node is to be merged. |
| string n1_name = n->name(); |
| string n2_name = predn->name(); |
| |
| VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and " |
| << n2_name << " for merging"; |
| |
| if (MergeNode(g, n, predn) == Status::OK()) { |
| VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and " |
| << n2_name; |
| result = true; |
| } |
| } |
| } |
| |
| DumpGraph("After running MklLayoutRewritePass", &**g); |
| |
| return result; |
| } |
| |
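| // Convenience entry point that constructs the pass and runs it on the |
| // given graph; returns true if and only if the graph was mutated. |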
| bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { |
| return MklLayoutRewritePass().RunPass(g); |
| } |
| |
| Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { |
| if (options.graph == nullptr && options.partition_graphs == nullptr) { |
| return Status::OK(); |
| } |
| |
| auto process_graph = [&](std::unique_ptr<Graph>* g) { |
| // Get the ownership of a graph |
| std::unique_ptr<Graph>* ng = std::move(g); |
| RunPass(ng); |
| // Return the ownership of a graph back |
| g->reset(ng->release()); |
| }; |
| |
| if (kMklLayoutRewritePassGroup != |
| OptimizationPassRegistry::POST_PARTITIONING) { |
| // For any pre-partitioning phase, a graph is stored in options.graph. |
| process_graph(options.graph); |
| } else { |
| // For post partitioning phase, graphs are stored in |
| // options.partition_graphs. |
| for (auto& pg : *options.partition_graphs) { |
| process_graph(&pg.second); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| #else // INTEL_MKL_ML_ONLY |
| |
| // This pass implements rewriting of graph to support following scenarios: |
| // (A) Merging nodes in the graph |
| // (B) Rewriting a node in the graph to a new node |
| // Rewrite happens under the following scenario: |
| // - Propagating Mkl layout as an additional output tensor |
| // (we will loosely call a tensor that carries Mkl layout as Mkl tensor |
| // henceforth.) from every Mkl supported NN layer. |
| // |
| // Example of A : Merging nodes in the graph |
| // ----------------------------------------- |
| // Currently, we merge Conv2D+AddBias together. Consider Conv2D and BiasAdd as: |
| // |
| // O = Conv2D(A, B) |
| // P = BiasAdd(O, C) |
| // |
| // We merge them into Conv2DWithBias as: |
| // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m) |
| // |
| // The meaning of A_m, B_m and C_m is explained in B.1. |
| // |
| // Merge rules: |
| // - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_ |
| // goes to BiasAdd. |
| // - Also, the intersection of attributes of both the nodes must have same |
| // values. |
| // - Both the nodes must have been assigned to same device (if any). |
| // |
| // Example of B.1 : Rewriting nodes to Mkl nodes |
| // --------------------------------------------- |
| // Consider a Relu node. Current definition of Relu node looks like: |
| // |
| // O = Relu(A) |
| // |
| // Relu has 1 input (A), and 1 output (O). |
| // |
| // This rewrite pass will generate a new graph node for Relu (new node is |
| // called MklRelu) as: |
| // |
| // O, O_m = MklRelu(A, A_m) |
| // |
| // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is |
| // same as input A of Relu; output O is same as output O of Relu. O_m is the |
| // additional output tensor that will be set by MklRelu, and it represents |
| // Mkl tensor corresponding to O -- in other words, O_m is some kind of |
| // metadata for O. A_m is additional input of Relu, and it represents metadata |
| // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives |
| // this metadata from previous node in the graph. |
| // |
| // When a previous node in the graph is an Mkl node, A_m will represent a valid |
| // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent |
| // a dummy Mkl tensor. |
| // |
| // Rewriting rules: |
| // - Selection of a node for rewriting happens by registering the op type of |
| // the node with the rewriting pass. If the op type is not registered, then |
| // no nodes of this op type will be rewritten. |
| // - Number of inputs after rewriting: |
| // Since for every input Tensorflow tensor, the rewritten node gets Mkl |
| // tensor(s), rewritten node gets 2*N inputs, where N is the number of |
| // inputs for the original node. |
| // - Number of outputs after rewriting: |
| // Since for every output Tensorflow tensor, the rewritten node generates |
| // Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the |
| // number of outputs of the original node. |
| // - Ordering of Tensorflow tensors and Mkl tensors: |
| // Since every rewritten node generates twice the number of inputs and |
| // outputs, one could imagine various orderings among Tensorflow tensors |
| // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as |
| // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m |
| // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m |
| // order. Among N inputs one can get N! permutations. |
| // |
| // So the question is: which order do we follow? We support 2 types of |
| // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering |
| // follows an intuitive order where an Mkl tensor follows the |
| // corresponding Tensorflow tensor immediately. In the context of the |
| // above example, it will be: A, A_m, B, B_m. Note that the ordering rule |
| // applies to both the inputs and outputs. Contiguous ordering means |
| // all the Tensorflow tensors are contiguous followed by all the Mkl |
| // tensors. We use contiguous ordering as default. |
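| // For example, with contiguous ordering the Mkl tensor for the i-th |
| // Tensorflow input (out of N inputs) occupies input slot N + i, whereas |
| // with interleaved ordering it occupies slot 2*i + 1. |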
| // |
| // Graph rewrite algorithm: |
| // Algorithm: Graph Rewrite |
| // Input: Graph G, Names of the nodes to rewrite and their new names |
| // Output: Modified Graph G' if the nodes are modified, G otherwise. |
| // Start: |
| // N = Topological_Sort(G) // N is a set of nodes in toposort order. |
| // foreach node n in N |
| // do |
| // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input. |
| // then |
| // E = set of <incoming edge and its src_output slot> of n |
| // E' = {} // a new set of edges for rewritten node |
| // foreach <e,s> in E |
| // do |
| // E' U {<e,s>} // First copy edge which generates Tensorflow |
| // // tensor as it is |
| // m = Source node of edge e |
| // if Is_Rewritten(m) // Did we rewrite this node in this pass? |
| // then |
| // E' U {<m,s+1>} // If yes, then m will generate an Mkl |
| // // tensor as an additional output. |
| // else |
| // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy |
| // // Mkl tensor. |
| // E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot. |
| // fi |
| // done |
| // n' = Build_New_Node(G,new_name,E') |
| // Mark_Rewritten(n') // Mark the new node as being rewritten. |
| // fi |
| // done |
| // |
| // Explanation: |
| // For graph rewrite, we visit nodes of the input graph in the |
| // topological sort order. With this ordering, we visit nodes in the |
| // top-to-bottom fashion. We need this order because while visiting a |
| // node we want that all of its input nodes are visited and rewritten if |
| // applicable. This is because if we need to rewrite a given node |
| // then all of its input nodes need to be fixed (in other words they |
| // cannot be deleted later.) |
| // |
| // While visiting a node, we first check if the op type of the node is |
| // an Mkl op. If it is, then we rewrite that node after constructing |
| // new inputs to the node. If the op type of the node is not Mkl op, |
| // then we do not rewrite that node. |
| // |
| // Handling workspace propagation for certain ops: |
| // |
| // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require |
| // passing of a workspace from their respective forward ops. Workspace |
| // tensors provide memory for storing results of intermediate operations |
| // which are helpful in backward propagation. TensorFlow does not have |
| // a notion of a workspace and as a result does not allow producing |
| // additional outputs from these forward ops. For these ops, we need |
| // to add 2 extra edges between forward ops and their corresponding |
| // backward ops - the first extra edge carries a workspace tensor and |
| // the second one carries an Mkl tensor for the workspace tensor. |
| // |
| // Example: |
| // |
| // Typical graph for MaxPool and its gradient looks like: |
| // |
| // A = MaxPool(T) |
| // B = MaxPoolGrad(X, A, Y) |
| // |
| // We will transform this graph to propagate the workspace as: |
| // (with the contiguous ordering) |
| // |
| // A, W, A_m, W_m = MklMaxPool(T, T_m) |
| // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m) |
| // |
| // Here W is the workspace tensor. Transformed tensor names with the |
| // suffix _m are Mkl tensors, and this transformation has been done |
| // using the algorithm discussed earlier. The transformation for |
| // workspace propagation only adds extra outputs (W, W_m) for a forward |
| // op and connects them to the corresponding backward ops. |
| // |
| // Terms: |
| // |
| // Forward op name = name of the op in the forward pass |
| // where a workspace tensor originates (MaxPool in this example) |
| // Backward op name = name of the op in the backward pass that receives |
| // a workspace tensor from the forward op (MaxPoolGrad in the example) |
| // Slot = Position of the output or input slot that will be |
| // used by the workspace tensor (1 for MklMaxPool as W is the 2nd |
| // output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad) |
| // |
| // Question: |
| // |
| // How do we associate a backward op to a forward op? There can be more |
| // than one op with the exact same name. |
| // |
| // In this example, we associate MaxPoolGrad with MaxPool. But there |
| // could be more than one MaxPool ops. To solve this problem, we look |
| // for _direct_ edge between a forward op and a backward op (tensor A is |
| // flowing along this edge in the example). |
| // |
| // How do we transform forward and backward ops when there is no direct |
| // edge between them? In such a case, we generate dummy tensors for |
| // workspace tensors. For the example, transformation of MaxPool will |
| // be exactly same as it would be when there is a direct edge between |
| // the forward and the backward op --- it is just that MaxPool won't |
| // generate any workspace tensor. For MaxPoolGrad, the transformation |
| // will also be same, but instead of connecting W and W_m with the |
| // outputs of MaxPool, we will produce dummy tensors for them, and we |
| // will set workspace_enabled attribute to false. |
| // |
| class MklLayoutRewritePass : public GraphOptimizationPass { |
| public: |
| MklLayoutRewritePass() { |
| // NOTE: names are alphabetically sorted. |
| csinfo_.addn = "AddN"; |
| csinfo_.avg_pool = "AvgPool"; |
| csinfo_.avg_pool_grad = "AvgPoolGrad"; |
| csinfo_.avg_pool3d = "AvgPool3D"; |
| csinfo_.avg_pool3d_grad = "AvgPool3DGrad"; |
| csinfo_.bias_add = "BiasAdd"; |
| csinfo_.bias_add_grad = "BiasAddGrad"; |
| csinfo_.concat = "Concat"; |
| csinfo_.concatv2 = "ConcatV2"; |
| csinfo_.conv2d = "Conv2D"; |
| csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias"; |
| csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; |
| csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; |
| csinfo_.conv2d_grad_filter_with_bias = |
| "__MklDummyConv2DBackpropFilterWithBias"; |
| csinfo_.conv3d = "Conv3D"; |
| csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2"; |
| csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2"; |
| csinfo_.fused_batch_norm = "FusedBatchNorm"; |
| csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; |
| csinfo_.identity = "Identity"; |
| csinfo_.lrn = "LRN"; |
| csinfo_.lrn_grad = "LRNGrad"; |
| csinfo_.matmul = "MatMul"; |
| csinfo_.max_pool = "MaxPool"; |
| csinfo_.max_pool_grad = "MaxPoolGrad"; |
| csinfo_.max_pool3d = "MaxPool3D"; |
| csinfo_.max_pool3d_grad = "MaxPool3DGrad"; |
| csinfo_.mkl_conv2d = "_MklConv2D"; |
| csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput"; |
| csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; |
| csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; |
| csinfo_.mkl_conv2d_grad_filter_with_bias = |
| "_MklConv2DBackpropFilterWithBias"; |
| csinfo_.relu = "Relu"; |
| csinfo_.relu_grad = "ReluGrad"; |
| csinfo_.tanh = "Tanh"; |
| csinfo_.tanh_grad = "TanhGrad"; |
| csinfo_.reshape = "Reshape"; |
| csinfo_.softmax = "Softmax"; |
| csinfo_.split = "Split"; |
| // Element-wise ops. Ensure you also add any new ops to IsOpElementWise |
| // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the |
| // MklInputConversion op is added before it. |
| csinfo_.add = "Add"; |
| csinfo_.maximum = "Maximum"; |
| csinfo_.mul = "Mul"; |
| csinfo_.squared_difference = "SquaredDifference"; |
| csinfo_.sub = "Sub"; |
| // End - element-wise ops. See note above. |
| |
| // NOTE: names are alphabetically sorted. |
| rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), |
| CopyAttrsAddN, AddNRewrite}); |
| rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), |
| CopyAttrsDataType, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.avg_pool, |
| mkl_op_registry::GetMklOpName(csinfo_.avg_pool), |
| CopyAttrsPooling, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.avg_pool_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), |
| CopyAttrsPooling, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.avg_pool3d, |
| mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d), |
| CopyAttrsPooling, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.avg_pool3d_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad), |
| CopyAttrsPooling, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.concat, |
| mkl_op_registry::GetMklOpName(csinfo_.concat), |
| CopyAttrsConcat, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.concatv2, |
| mkl_op_registry::GetMklOpName(csinfo_.concatv2), |
| CopyAttrsConcatV2, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv2d, |
| mkl_op_registry::GetMklOpName(csinfo_.conv2d), |
| CopyAttrsConv, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, |
| CopyAttrsConv, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv2d_grad_filter, |
| mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), |
| CopyAttrsConv, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, |
| csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv, |
| AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv2d_grad_input, |
| mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), |
| CopyAttrsConv, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv3d, |
| mkl_op_registry::GetMklOpName(csinfo_.conv3d), |
| CopyAttrsConv, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv3d_grad_filter, |
| mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter), |
| CopyAttrsConv, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.conv3d_grad_input, |
| mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input), |
| CopyAttrsConv, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.fused_batch_norm, |
| mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), |
| CopyAttrsFusedBatchNorm, AlwaysRewrite}); |
| rinfo_.push_back( |
| {csinfo_.fused_batch_norm_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), |
| CopyAttrsFusedBatchNorm, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.identity, |
| mkl_op_registry::GetMklOpName(csinfo_.identity), |
| CopyAttrsDataType, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), |
| CopyAttrsLRN, LrnRewrite}); |
| rinfo_.push_back({csinfo_.lrn_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), |
| CopyAttrsLRN, LrnGradRewrite}); |
| rinfo_.push_back({csinfo_.max_pool, |
| mkl_op_registry::GetMklOpName(csinfo_.max_pool), |
| CopyAttrsPooling, NonDepthBatchWisePoolRewrite}); |
| rinfo_.push_back({csinfo_.max_pool_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), |
| CopyAttrsPooling, MaxpoolGradRewrite}); |
| rinfo_.push_back({csinfo_.max_pool3d, |
| mkl_op_registry::GetMklOpName(csinfo_.max_pool3d), |
| CopyAttrsPooling, NonDepthBatchWisePoolRewrite}); |
| rinfo_.push_back({csinfo_.max_pool3d_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad), |
| CopyAttrsPooling, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.maximum, |
| mkl_op_registry::GetMklOpName(csinfo_.maximum), |
| CopyAttrsDataType, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.mul, |
| mkl_op_registry::GetMklOpName(csinfo_.mul), |
| CopyAttrsDataType, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), |
| CopyAttrsDataType, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.relu_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.relu_grad), |
| CopyAttrsDataType, AlwaysRewrite}); |
| /* |
| rinfo_.push_back({csinfo_.tanh, |
| mkl_op_registry::GetMklOpName(csinfo_.tanh), |
| CopyAttrsDataType, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.tanh_grad, |
| mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), |
| CopyAttrsDataType, AlwaysRewrite}); |
| */ |
| rinfo_.push_back({csinfo_.reshape, |
| mkl_op_registry::GetMklOpName(csinfo_.reshape), |
| CopyAttrsReshape, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.softmax, |
| mkl_op_registry::GetMklOpName(csinfo_.softmax), |
| CopyAttrsDataType, AlwaysRewrite}); |
| |
| rinfo_.push_back({csinfo_.squared_difference, |
| mkl_op_registry::GetMklOpName(csinfo_.squared_difference), |
| CopyAttrsDataType, AlwaysRewrite}); |
| rinfo_.push_back({csinfo_.sub, |
| mkl_op_registry::GetMklOpName(csinfo_.sub), |
| CopyAttrsDataType, AlwaysRewrite}); |
| |
| // Add info about which ops to add workspace edge to and the slots. |
| wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); |
| wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); |
| wsinfo_.push_back( |
| {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3}); |
| |
| // Add a rule for merging nodes |
| minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, |
| csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); |
| |
| minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad, |
| csinfo_.conv2d_grad_filter_with_bias, |
| GetConv2DBackpropFilterOrBiasAddGrad}); |
| } |
| |
| // Standard interface to run pass |
| Status Run(const GraphOptimizationPassOptions& options); |
| |
| // Helper function which does most of heavy lifting for rewriting |
| // Mkl nodes to propagate Mkl tensor as additional output |
| // |
| // Extracts common functionality between Run public interface and |
| // test interface. |
| // |
| // @return true, if and only if graph is mutated; false otherwise. |
| bool RunPass(std::unique_ptr<Graph>* g); |
| |
| /// Structure to specify the name of an original node, its new name after |
| /// rewrite, the function to be used to copy attributes for the op, and |
| /// the rule (if any) which must hold for rewriting the node |
| typedef struct { |
| string name; // Original name of op of the node in the graph |
| string new_name; // New name of the op of the node in the graph |
| // A function handler to copy attributes from an old node to a new node. |
| std::function<void(const Node*, NodeBuilder*)> copy_attrs; |
| // A rule under which to rewrite this node |
| std::function<bool(const Node*)> rewrite_rule; |
| } RewriteInfo; |
| |
| /// Structure to specify a forward op, a backward op, and the slot numbers |
| /// in the forward and backward ops where we will add a workspace edge. |
| typedef struct { |
| string fwd_op; // Name of a forward op in the graph |
| string bwd_op; // Name of a backward op in the graph |
| int fwd_slot; // Output slot in the forward op node where actual |
| // output tensor resides |
| int bwd_slot; // Input slot in the backward op node where actual |
| // input tensor resides |
| int ws_fwd_slot; // Output slot in the forward op node where workspace |
| // edge is added |
| int ws_bwd_slot; // Input slot in the backward op node where workspace |
| // edge is added |
| } WorkSpaceInfo; |
| |
| /// Structure to specify information used in node merge of 2 operators |
| typedef struct { |
| string op1; // Node string for one operator. |
| string op2; // Node string for second operator. |
| string new_node; // Name of the node after merge |
| // Function that enables user of the node merger to specify how to find |
| // second operator given the first operator. |
| std::function<Node*(const Node*)> get_node_to_be_merged; |
| } MergeInfo; |
| |
| /// Structure to store all constant strings |
| /// NOTE: names are alphabetically sorted. |
| typedef struct { |
| string addn; |
| string add; |
| string avg_pool; |
| string avg_pool_grad; |
| string avg_pool3d; |
| string avg_pool3d_grad; |
| string bias_add; |
| string bias_add_grad; |
| string concat; |
| string concatv2; |
| string conv2d; |
| string conv2d_with_bias; |
| string conv2d_grad_input; |
| string conv2d_grad_filter; |
| string conv2d_grad_filter_with_bias; |
| string conv3d; |
| string conv3d_grad_input; |
| string conv3d_grad_filter; |
| string fused_batch_norm; |
| string fused_batch_norm_grad; |
| string identity; |
| string lrn; |
| string lrn_grad; |
| string matmul; |
| string max_pool; |
| string max_pool_grad; |
| string max_pool3d; |
| string max_pool3d_grad; |
| string maximum; |
| string mkl_conv2d; |
| string mkl_conv2d_grad_input; |
| string mkl_conv2d_grad_filter; |
| string mkl_conv2d_grad_filter_with_bias; |
| string mkl_conv2d_with_bias; |
| string mul; |
| string relu; |
| string relu_grad; |
| string tanh; |
| string tanh_grad; |
| string reshape; |
| string softmax; |
| string split; |
| string squared_difference; |
| string sub; |
| } ConstStringsInfo; |
| |
| private: |
| /// Maintain info about nodes to rewrite |
| std::vector<RewriteInfo> rinfo_; |
| |
| /// Maintain info about nodes to add workspace edge |
| std::vector<WorkSpaceInfo> wsinfo_; |
| |
| /// Maintain info about nodes to be merged |
| std::vector<MergeInfo> minfo_; |
| |
| /// Maintain structure of constant strings |
| static ConstStringsInfo csinfo_; |
| |
| private: |
| // Is OpDef::ArgDef a list type? It could be N * T or list(type). |
| // Refer to opdef.proto for details of list type. |
| inline bool ArgIsList(const OpDef::ArgDef& arg) const { |
| return !arg.type_list_attr().empty() || !arg.number_attr().empty(); |
| } |
| |
| // Get length of a list in 'n' if 'arg' is of list type. Refer to |
| // description of ArgIsList for definition of list type. |
| inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { |
| CHECK_EQ(ArgIsList(arg), true); |
| int N = 0; |
| const string attr_name = !arg.type_list_attr().empty() |
| ? arg.type_list_attr() |
| : arg.number_attr(); |
| if (!arg.type_list_attr().empty()) { |
| std::vector<DataType> value; |
| TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); |
| N = value.size(); |
| } else { |
| TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); |
| } |
| return N; |
| } |
| |
| // Can op represented by node 'n' run on DEVICE_CPU? |
| // Op can run on CPU with MKL if the runtime assigned device or the |
| // user requested device contains device CPU, or both are empty. |
| bool CanOpRunOnCPUDevice(const Node* n) { |
| bool result = true; |
| string reason; |
| |
| // Substring that should be checked for in device name for CPU device. |
| const char* const kCPUDeviceSubStr = "CPU"; |
| |
| // If Op has been specifically assigned to a non-CPU device, then No. |
| if (!n->assigned_device_name().empty() && |
| !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) { |
| result = false; |
| reason = "Op has been assigned a runtime device that is not CPU."; |
| } |
| |
| // If user has specifically assigned this op to a non-CPU device, then No. |
| if (!n->def().device().empty() && |
| !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) { |
| result = false; |
| reason = "User has assigned a device that is not CPU."; |
| } |
| |
| if (result == false) { |
| VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " |
| << n->type_string() << ", reason: " << reason; |
| } |
| |
| // Otherwise Yes. |
| return result; |
| } |
| |
| // Return a node that can be merged with input node 'n' |
| // |
| // @return pointer to the node if we can find such a |
| // node. Otherwise, it returns nullptr. |
| Node* CheckForNodeMerge(const Node* n) const; |
| |
| // Merge node 'm' with node 'n'. |
| // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with |
| // Conv2DBackpropFilter. |
| // |
| // Input nodes m and n may be deleted if the call to |
| // this function is successful. Attempting to use the pointers |
| // after the call may result in undefined behavior. |
| // |
| // @input g - input graph, m - graph node, n - graph node to be merged with m |
| // @return Status::OK(), if merging is successful and supported. |
| // Returns appropriate Status error code otherwise. |
| // Graph is updated in case nodes are merged. Otherwise, it is |
| // not updated. |
| Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n); |
| |
| // Helper function to merge different nodes |
| Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n); |
| Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g, |
| Node* m, Node* n); |
| |
| // Find BiasAdd or Conv2D node that can be merged with input node 'm'. |
| // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be |
| // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd |
| // node that can be merged with 'm'. |
| static Node* GetConv2DOrBiasAdd(const Node* m) { |
| CHECK_NOTNULL(m); |
| Node* n = nullptr; |
| |
| if (m->type_string() == csinfo_.bias_add) { |
| // If 'm' is BiasAdd, then Conv2D is the 0th input of BiasAdd. |
| TF_CHECK_OK(m->input_node(0, &n)); |
| } else { |
| CHECK_EQ(m->type_string(), csinfo_.conv2d); |
| // Go over all output edges and search for BiasAdd Node. |
| // 0th input of BiasAdd is Conv2D. |
| for (const Edge* e : m->out_edges()) { |
| if (!e->IsControlEdge() && |
| e->dst()->type_string() == csinfo_.bias_add && |
| e->dst_input() == 0) { |
| n = e->dst(); |
| break; |
| } |
| } |
| } |
| |
| if (n == nullptr) { |
| VLOG(1) << "MklLayoutRewritePass: Could not find matching " |
| << "Conv2D and BiasAdd node for merging. Input node: " |
| << m->DebugString(); |
| } |
| |
| return n; |
| } |
| |
| // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input |
| // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists |
| // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad, |
| // then check if there exists Conv2DBackpropFilter node that can be merged |
| // with 'm'. |
| // |
| // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad |
| // would look like: |
| // |
| // _ = Conv2DBackpropFilter(F, _, G) |
| // _ = BiasAddGrad(G) |
| // |
| // So 1st input of BiasAddGrad connects with 3rd input of |
| // Conv2DBackpropFilter and vice versa. |
| static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) { |
| CHECK_NOTNULL(m); |
| Node* n = nullptr; |
| |
| if (m->type_string() == csinfo_.bias_add_grad) { |
| // Get 1st input 'g' of BiasAddGrad. |
| Node* g = nullptr; |
| TF_CHECK_OK(m->input_node(0, &g)); |
| // Now traverse all outgoing edges from g that have destination node as |
| // Conv2DBackpropFilter. |
| for (const Edge* e : g->out_edges()) { |
| if (!e->IsControlEdge() && |
| e->dst()->type_string() == csinfo_.conv2d_grad_filter && |
| e->dst_input() == 2 /* 3rd input of BackpropFilter */) { |
| n = e->dst(); |
| break; |
| } |
| } |
| } else { |
| CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter); |
| // Get 3rd input 'g' of Conv2DBackpropFilter. |
| Node* g = nullptr; |
| TF_CHECK_OK(m->input_node(2, &g)); |
| // Now traverse all outgoing edges from g that have destination node as |
| // BiasAddGrad. |
| for (const Edge* e : g->out_edges()) { |
| if (!e->IsControlEdge() && |
| e->dst()->type_string() == csinfo_.bias_add_grad && |
| e->dst_input() == 0 /* 1st input of BiasAddGrad */) { |
| n = e->dst(); |
| break; |
| } |
| } |
| } |
| |
| if (n == nullptr) { |
| VLOG(1) << "MklLayoutRewritePass: Could not find matching " |
| << "Conv2DBackpropFilter and BiasAddGrad node for merging. " |
| << "Input node: " << m->DebugString(); |
| } |
| return n; |
| } |
| |
| // Check if the node 'n' has any applicable rewrite rule |
| // We check for 2 scenarios for rewrite. |
| // |
| // @return RewriteInfo* for the applicable rewrite rule |
| const RewriteInfo* CheckForNodeRewrite(const Node* n) const; |
| |
| // Default rewrite rule to be used in scenario 1 for rewrite. |
| // @return - true (since we want to always rewrite) |
| static bool AlwaysRewrite(const Node* n) { return true; } |
| |
| // Check if we are performing pooling over the depth or batch dimension. |
| // If so, we do not rewrite the MaxPool node to its Mkl version. |
| // @return - true if it is not a depth/batch-wise pooling case; |
| //           false otherwise. |
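| // For example (illustrative values), with NHWC data format, ksize = |
| // {1, 2, 2, 1} and strides = {1, 2, 2, 1} pool spatially and allow the |
| // rewrite, whereas ksize = {1, 1, 1, 2} pools over depth ('C') and does not. |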
| static bool NonDepthBatchWisePoolRewrite(const Node* n) { |
| CHECK_NOTNULL(n); |
| |
| string data_format_str; |
| TensorFormat data_format; |
| std::vector<int32> ksize, strides; |
| CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); |
| CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); |
| CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); |
| CHECK_EQ(FormatFromString(data_format_str, &data_format), true); |
| |
| // Condition that specifies non-batch-wise and non-depth-wise pooling. |
| if (GetTensorDim(ksize, data_format, 'N') == 1 && |
| GetTensorDim(strides, data_format, 'N') == 1 && |
| GetTensorDim(ksize, data_format, 'C') == 1 && |
| GetTensorDim(strides, data_format, 'C') == 1) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
| // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized |
| // path, which is slow. Thus we don't rewrite the node and use the default |
| // Eigen op. For depth_radius=2, the MKL DNN optimized path is taken, i.e., |
| // the Eigen node is rewritten to an MKL DNN node. |
| static bool LrnRewrite(const Node* n) { |
| CHECK_NOTNULL(n); |
| |
| int depth_radius; |
| CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true); |
| |
| // If the depth_radius of LRN is not 2, don't rewrite the node with MKL DNN; |
| // use the Eigen node instead. |
| if (depth_radius == 2) { |
| return true; |
| } |
| VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which" |
| << "case is not optimized by Intel MKL, thus using Eigen op" |
| << "for LRN "; |
| |
| return false; |
| } |
| |
| static bool LrnGradRewrite(const Node* n) { |
| CHECK_NOTNULL(n); |
| bool do_rewrite = false; |
| |
| for (const Edge* e : n->in_edges()) { |
| // Rewrite only if there is a corresponding LRN, i.e., a workspace is available. |
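| // Note that we iterate over n's in_edges, so e->dst() is 'n' itself; the |
| // e->dst() checks simply verify that 'n' is LRNGrad and that this in-edge |
| // arrives at the expected input slot. |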
| if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 && |
| e->src()->type_string() == |
| mkl_op_registry::GetMklOpName(csinfo_.lrn) && |
| e->src_output() == 0) { |
| do_rewrite = true; |
| break; |
| } |
| } |
| return do_rewrite; |
| } |
| |
| static bool MaxpoolGradRewrite(const Node* n) { |
| CHECK_NOTNULL(n); |
| bool do_rewrite = false; |
| for (const Edge* e : n->in_edges()) { |
| // Rewrite only if there is a corresponding MaxPool, i.e., a workspace is |
| // available. |
| if (e->dst()->type_string() == csinfo_.max_pool_grad && |
| e->dst_input() == 1 && |
| e->src()->type_string() == |
| mkl_op_registry::GetMklOpName(csinfo_.max_pool) && |
| e->src_output() == 0) { |
| do_rewrite = true; |
| break; |
| } |
| } |
| return do_rewrite; |
| } |
| |
| static bool AddNRewrite(const Node* n) { |
| CHECK_NOTNULL(n); |
| |
| int num; |
| CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true); |
| |
| // Rewrite AddN only if it has exactly two inputs. |
| if (num == 2) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
| // Rewrites input node to a new node specified by its matching rewrite info. |
| // |
| // Method first searches matching rewrite info for input node and then |
| // uses that info to rewrite. |
| // |
| // Input node may be deleted in case of rewrite. Attempt to use the node |
| // after the call can result in undefined behaviors. |
| // |
| // @input g - input graph, n - Node to be rewritten, |
| // ri - matching rewriteinfo |
| // @return Status::OK(), if the input node is rewritten; |
| // Returns appropriate Status error code otherwise. |
| // Graph is updated in case the input node is rewritten. |
| // Otherwise, it is not updated. |
| Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri); |
| |
| // Get nodes that will feed a list of TF tensors to the new |
| // node that we are constructing. |
| // |
| // @input g - input graph, |
| // @input inputs - inputs to old node that we are using for constructing |
| // new inputs, |
| // @input input_idx - the index in the 'inputs' vector pointing to the |
| // current input that we have processed so far |
| // @output input_idx - index will be incremented by the number of nodes |
| // from 'inputs' that are processed |
| // @input list_length - The expected length of list of TF tensors |
| // @output output_nodes - the list of new nodes creating TF tensors |
| // |
| // @return None |
| void GetNodesProducingTFTensorList( |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
| int* input_idx, int list_length, |
| std::vector<NodeBuilder::NodeOut>* output_nodes); |
| |
| // Get nodes that will feed a list of Mkl tensors to the new |
| // node that we are constructing. |
| // |
| // @input g - input graph, |
| // @input orig_node - Original node that we are rewriting |
| // @input inputs - inputs to old node that we are using for constructing |
| // new inputs, |
| // @input input_idx - the index in the 'inputs' vector pointing to the |
| // current input that we have processed so far |
| // @output input_idx - index will be incremented by the number of nodes |
| // from 'inputs' that are processed |
| // @input list_length - The expected length of list of Mkl tensors |
| // @output output_nodes - the list of new nodes creating Mkl tensors |
| // |
| // @return None |
| void GetNodesProducingMklTensorList( |
| std::unique_ptr<Graph>* g, Node* orig_node, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
| int* input_idx, int list_length, |
| std::vector<NodeBuilder::NodeOut>* output_nodes); |
| |
| // Get a node that will feed an Mkl tensor to the new |
| // node that we are constructing. The output node could be (1) 'n' |
| // if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor |
| // if 'n' is not an Mkl layer. |
| // |
| // @input g - input graph, |
| // @input orig_node - Original node that we are rewriting, |
| // @input n - Node based on which we are creating Mkl node, |
| // @input n_output_slot - the output slot of node 'n' |
| // which is feeding to the node that we are constructing |
| // @output mkl_node - the new node that will feed Mkl tensor |
| // @output mkl_node_output_slot - the slot number of mkl_node that |
| // will feed the tensor |
| // @return None |
| void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, |
| Node* n, int n_output_slot, Node** mkl_node, |
| int* mkl_node_output_slot); |
| |
| // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' |
| // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are |
| // set up in contiguous fashion. 'workspace_tensors' carry graph nodes |
| // producing workspace edges if 'are_workspace_tensors_available' is true. |
| // Otherwise, 'workspace_tensors' is an empty vector. |
| // |
| // For details, refer to 'Ordering of inputs after rewriting' section in the |
| // documentation above. |
| // |
| // Returns the number of input slots that were set up on 'nb' for the new |
| // node. |
| int SetUpContiguousInputs( |
| std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
| NodeBuilder* nb, Node* old_node, |
| std::vector<NodeBuilder::NodeOut>* workspace_tensors, |
| bool are_workspace_tensors_available); |
| |
| // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' |
| // in graph 'g'. Original node is input in 'orig_node'. |
| // |
| // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors' |
| // section in the documentation above. |
| // |
| // Returns Status::OK() if setting up inputs is successful, otherwise |
| // returns appropriate status code. |
| Status SetUpInputs(std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
| NodeBuilder* nb, Node* orig_node); |
| |
| // Add a workspace edge on the input or output side of Node 'orig_node' using |
| // NodeBuilder 'nb' for the new node. If 'orig_node' does not require a |
| // workspace edge, none is added. If workspace Tensorflow and Mkl tensors |
| // need to be added, they are returned in 'ws_tensors' and |
| // '*are_ws_tensors_added' is set to true. |
| void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node, |
| NodeBuilder* nb, |
| std::vector<NodeBuilder::NodeOut>* ws_tensors, |
| bool* are_ws_tensors_added); |
| |
| // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge |
| // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph |
| // 'g'. Returns true if fixup was done; otherwise, it returns false. |
| bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, |
| const Edge* e_data, const Edge* e_metadata); |
| |
| // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly |
| // connected? If not, then fix them. This is needed because a graph may have |
| // some input Mkl metadata edges incorrectly setup after node merge and |
| // rewrite passes. This could happen because GetReversePostOrder function may |
| // not provide topologically sorted order if a graph contains cycles. The |
| // function returns true if at least one Mkl metadata edge for node 'n' was |
| // fixed. Otherwise, it returns false. |
| // |
| // Example: |
| // |
| // X = MklConv2D(_, _, _) |
| // Y = MklConv2DWithBias(_, _, _, _, _, _) |
| // Z = MklAdd(X, Y, DummyMklTensor, Y:1) |
| // |
| // For a graph such as shown above, note that 3rd argument of MklAdd contains |
| // DummyMklTensor. Actually, it should be getting the Mkl metadata from |
| // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible |
| // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X |
| // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl |
| // metadata edges only - it does not rewrite nodes nor does it modify the Mkl |
| // data edges (1st and 2nd arguments of MklAdd). |
| bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n); |
| |
| // Operator-specific functions to copy attributes from the old node to the |
| // new node. We need operator-specific functions because the framework does |
| // not provide a generic function for this. |
| // NOTE: names are alphabetically sorted. |
| static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); |
| static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); |
| |
| // Generate a node in graph 'g' representing a dummy Mkl tensor node, placed |
| // on the same device and frame as original node 'orig_node', and return it |
| // in '*out'. |
| // TODO(nhasabni) We should move this to mkl_util.h |
| void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out, |
| Node* orig_node); |
| void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out, |
| Node* orig_node); |
| }; |
| |
| MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; |
| |
| // We register Mkl rewrite pass for phase 1 in post partitioning group. |
| // We register it here so that we get a complete picture of all users of Mkl |
| // nodes. Do not change the ordering of the Mkl passes. |
| const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = |
| OptimizationPassRegistry::POST_PARTITIONING; |
| REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions for creating new node |
| ////////////////////////////////////////////////////////////////////////// |
| |
| static void FillInputs(const Node* n, |
| gtl::InlinedVector<Node*, 4>* control_edges, |
| gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { |
| control_edges->clear(); |
| for (const Edge* e : n->in_edges()) { |
| if (e->IsControlEdge()) { |
| control_edges->push_back(e->src()); |
| } else { |
| (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); |
| } |
| } |
| std::sort(control_edges->begin(), control_edges->end()); |
| if (n->op_def().is_commutative()) { |
| // For commutative inputs, we sort the input by the input Node* |
| // to get a canonical ordering (so that add(a,b) and add(b, a) will |
| // hash to the same value if is_commutative is true for 'add'). |
| std::sort(in->begin(), in->end()); |
| } |
| } |
| |
| void MklLayoutRewritePass::GetNodesProducingTFTensorList( |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, |
| int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { |
| CHECK_LT(*input_idx, inputs.size()); |
| CHECK_GT(list_length, 0); |
| CHECK_NOTNULL(output_nodes); |
| output_nodes->reserve(list_length); |
| |
| while (list_length != 0) { |
| CHECK_GT(list_length, 0); |
| CHECK_LT(*input_idx, inputs.size()); |
| Node* n = inputs[*input_idx].first; |
| int slot = inputs[*input_idx].second; |
| // If input node 'n' is just producing a single tensor at |
| // output slot 'slot' then we just add that single node. |
| output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); |
| (*input_idx)++; |
| list_length--; |
| } |
| } |
| |
| // TODO(nhasabni) We should move this to mkl_util.h. |
| void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, |
| Node** out, Node* orig_node) { |
| // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent |
| // dummy Mkl tensor. 8 = 2*size_t. |
| const DataType dt = DataTypeToEnum<uint8>::v(); |
| TensorProto proto; |
| proto.set_dtype(dt); |
| uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; |
| proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8)); |
| TensorShape dummy_shape({8}); |
| dummy_shape.AsProto(proto.mutable_tensor_shape()); |
| TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") |
| .Attr("value", proto) |
| .Attr("dtype", dt) |
| .Device(orig_node->def().device()) // We place this node on |
| // the same device as the |
| // device of the original |
| // node. |
| .Finalize(&**g, out)); |
| CHECK_NOTNULL(*out); // Make sure we got a valid object before using it |
| |
| // If number of inputs to the original node is > 0, then we add |
| // control dependency between 1st input (index 0) of the original node and |
| // the dummy Mkl node. This is needed because control-flow ops such as Enter, |
| // Merge, etc, require frame_name of the dummy Mkl node to be same as the |
| // rewritten node. Adding control edge between 1st input of the original node |
| // and the dummy Mkl node ensures that the dummy node is in the same frame |
| // as the original node. Choosing 1st input is not necessary - any input of |
| // the original node is fine because all the inputs of a node are always in |
| // the same frame. |
| if (orig_node->num_inputs() > 0) { |
| Node* orig_input0 = nullptr; |
| TF_CHECK_OK( |
| orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); |
| } |
| |
| (*out)->set_assigned_device_name(orig_node->assigned_device_name()); |
| } |
| |
| void MklLayoutRewritePass::GetNodesProducingMklTensorList( |
| std::unique_ptr<Graph>* g, Node* orig_node, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, |
| int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { |
| CHECK_LT(*input_idx, inputs.size()); |
| CHECK_GT(list_length, 0); |
| CHECK_NOTNULL(output_nodes); |
| output_nodes->reserve(list_length); |
| |
| while (list_length != 0) { |
| CHECK_GT(list_length, 0); |
| CHECK_LT(*input_idx, inputs.size()); |
| Node* n = inputs[*input_idx].first; |
| int slot = inputs[*input_idx].second; |
| // If 'n' is producing a single tensor, then create a single Mkl tensor |
| // node. |
| Node* mkl_node = nullptr; |
| int mkl_node_output_slot = 0; |
| GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, |
| &mkl_node_output_slot); |
| output_nodes->push_back( |
| NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); |
| (*input_idx)++; |
| list_length--; |
| } |
| } |
| |
| // Get an input node that will feed an Mkl tensor to the new |
| // node that we are constructing. An input node could be (1) 'n' |
| // if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor |
| // if 'n' is not an Mkl layer. |
| void MklLayoutRewritePass::GetNodeProducingMklTensor( |
| std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, |
| Node** mkl_node, int* mkl_node_output_slot) { |
| CHECK_NOTNULL(n); |
| CHECK_NOTNULL(mkl_node); |
| CHECK_NOTNULL(mkl_node_output_slot); |
| |
| // If this is an MKL op, then it will create extra output for MKL layout. |
| DataType T; |
| if (GetNodeAttr(n->def(), "T", &T).ok() && |
| mkl_op_registry::IsMklOp(n->type_string(), T)) { |
| // If this is an MKL op, then it will generate an edge that will receive |
| // Mkl tensor from a node. |
| // The output slot number for the Mkl tensor is N + (slot number of the |
| // TensorFlow tensor), where N is the total number of TensorFlow tensors. |
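| // For example (assuming contiguous ordering; illustrative numbers), if 'n' |
| // has 4 outputs in total (2 TensorFlow + 2 Mkl), TensorFlow output slot 1 |
| // maps to Mkl metadata slot 3. |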
| *mkl_node = n; |
| *mkl_node_output_slot = |
| GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); |
| } else { |
| // If we have not visited the node and rewritten it, then we need |
| // to create a dummy node that will feed a dummy Mkl tensor to this node. |
| // DummyMklTensor node has no input and generates only 1 output |
| // (dummy Mkl tensor) as output slot number 0. |
| GetDummyMklTensorNode(g, mkl_node, orig_node); |
| CHECK_NOTNULL(*mkl_node); |
| *mkl_node_output_slot = 0; |
| } |
| } |
| |
| int MklLayoutRewritePass::SetUpContiguousInputs( |
| std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
| NodeBuilder* nb, Node* old_node, |
| std::vector<NodeBuilder::NodeOut>* workspace_tensors, |
| bool are_workspace_tensors_available) { |
| CHECK_NOTNULL(workspace_tensors); |
| CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |
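| // With contiguous ordering, the inputs of the new node are laid out as all |
| // TensorFlow tensors followed by all Mkl metadata tensors, e.g. |
| // A, B, C, A_m, B_m, C_m (illustrative). |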
| |
| // TODO(nhasabni): Temporary solution to connect filter input of |
| // BackpropInput with the converted filter from Conv2D. |
| bool do_connect_conv2d_backprop_input_filter = false; |
| Node* conv2d_node = nullptr; |
| // Filter node is 2nd input (slot index 1) of Conv2D. |
| int kConv2DFilterInputSlotIdx = 1; |
| int kConv2DBackpropInputFilterInputSlotIdx = 1; |
| int kConv2DFilterOutputSlotIdx = 1; |
| if (old_node->type_string() == csinfo_.conv2d_grad_input) { |
| // We need to find Conv2D node from Conv2DBackpropInput. |
| // For that let's first find filter node that is 2nd input (slot 1) |
| // of BackpropInput. |
| Node* filter_node = nullptr; |
| TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, |
| &filter_node)); |
| CHECK_NOTNULL(filter_node); |
| |
| // Now check which nodes receive from filter_node. Filter feeds as |
| // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias. |
| for (const Edge* e : filter_node->out_edges()) { |
| if ((e->dst()->type_string() == csinfo_.mkl_conv2d || |
| e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias) && |
| e->dst_input() == kConv2DFilterInputSlotIdx |
| /* filter is 2nd input of Conv2D and _MklConv2D. */) { |
| if (conv2d_node != nullptr) { |
| VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" |
| << " feeding multiple Conv2D nodes: " |
| << filter_node->DebugString(); |
| // We will not connect filter input of Conv2DBackpropInput |
| // to be safe here. |
| do_connect_conv2d_backprop_input_filter = false; |
| break; |
| } else { |
| conv2d_node = e->dst(); |
| do_connect_conv2d_backprop_input_filter = true; |
| } |
| } |
| } |
| } |
| |
| // Number of input slots to original op |
| // Input slots are represented by .Input() calls in REGISTER_OP. |
| int old_node_input_slots = old_node->op_def().input_arg_size(); |
| // Actual number of inputs can be greater than or equal to number |
| // of Input slots because inputs of type list could be unfolded. |
| CHECK_GE(old_node_inputs.size(), old_node_input_slots); |
| int nn_slot_idx = 0; // slot index for inputs of new node |
| |
| // Let's copy all inputs (TF tensors) of original node to new node. |
| int iidx = 0; |
| for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { |
| // An input slot could be a single tensor or a list. We need |
| // to handle this case accordingly. |
| CHECK_LT(iidx, old_node_inputs.size()); |
| const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); |
| if (ArgIsList(arg)) { |
| std::vector<NodeBuilder::NodeOut> new_node_inputs; |
| int N = GetTensorListLength(arg, old_node); |
| GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, |
| &new_node_inputs); |
| nb->Input(new_node_inputs); |
| nn_slot_idx++; |
| } else { |
| // Special case for connecting filter input of Conv2DBackpropInput |
| if (do_connect_conv2d_backprop_input_filter && |
| iidx == kConv2DBackpropInputFilterInputSlotIdx) { |
| nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); |
| } else { |
| nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); |
| } |
| iidx++; |
| nn_slot_idx++; |
| } |
| } |
| |
| // If workspace tensors are available for this op and we are using |
| // contiguous ordering then we need to add Tensorflow tensor for |
| // workspace here because Tensorflow tensor for workspace is the |
| // last tensor in the list of Tensorflow tensors. |
| if (are_workspace_tensors_available) { |
| CHECK_EQ(workspace_tensors->size(), 2); |
| // Tensorflow tensor |
| nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); |
| nn_slot_idx++; |
| } |
| |
| // Let's now setup all Mkl inputs to a new node. |
| // Number of Mkl inputs must be same as number of TF inputs. |
| iidx = 0; |
| for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { |
| // An input slot could be a single tensor or a list. We need |
| // to handle this case accordingly. |
| CHECK_LT(iidx, old_node_inputs.size()); |
| const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); |
| if (ArgIsList(arg)) { |
| std::vector<NodeBuilder::NodeOut> new_node_inputs; |
| int N = GetTensorListLength(arg, old_node); |
| GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, |
| &new_node_inputs); |
| nb->Input(new_node_inputs); |
| nn_slot_idx++; |
| } else { |
| Node* mkl_node = nullptr; |
| int mkl_node_output_slot = 0; |
| // Special case for connecting filter input of Conv2DBackpropInput |
| if (do_connect_conv2d_backprop_input_filter && |
| iidx == kConv2DBackpropInputFilterInputSlotIdx) { |
| GetNodeProducingMklTensor(g, old_node, conv2d_node, |
| kConv2DFilterOutputSlotIdx, &mkl_node, |
| &mkl_node_output_slot); |
| } else { |
| GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, |
| old_node_inputs[iidx].second, &mkl_node, |
| &mkl_node_output_slot); |
| } |
| nb->Input(mkl_node, mkl_node_output_slot); |
| iidx++; |
| nn_slot_idx++; |
| } |
| } |
| |
| // If workspace tensors are available for this op and we are using |
| // contiguous ordering then we need to add Mkl tensor for |
| // workspace here because Mkl tensor for workspace is the |
| // last tensor in the list of Mkl tensors. |
| if (are_workspace_tensors_available) { |
| CHECK_EQ(workspace_tensors->size(), 2); |
| // Mkl tensor |
| nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index); |
| nn_slot_idx++; |
| } |
| |
| return nn_slot_idx; |
| } |
| |
| Status MklLayoutRewritePass::SetUpInputs( |
| std::unique_ptr<Graph>* g, |
| const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
| NodeBuilder* nb, Node* old_node) { |
| // Let's check if we need to add workspace tensors for this node. |
| // We add workspace edge only for MaxPool, LRN and BatchNorm. |
| std::vector<NodeBuilder::NodeOut> workspace_tensors; |
| bool are_workspace_tensors_available = false; |
| AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors, |
| &are_workspace_tensors_available); |
| |
| int new_node_input_slots = 0; |
| if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { |
| // TODO(nhasabni): implement this case just for the sake of completeness. |
| // We do not use interleaved ordering right now. |
| return Status( |
| error::Code::UNIMPLEMENTED, |
| "Interleaved ordering of tensors is currently not supported."); |
| } else { |
| CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |
| new_node_input_slots = SetUpContiguousInputs( |
| g, old_node_inputs, nb, old_node, &workspace_tensors, |
| are_workspace_tensors_available); |
| } |
| |
| // Sanity check |
| int old_node_input_slots = old_node->op_def().input_arg_size(); |
| if (!are_workspace_tensors_available) { |
| // If we are not adding workspace tensors for this op, then the total |
| // number of input slots to the new node _must_ be 2 times the number |
| // of input slots to the original node: N original Tensorflow tensors and |
| // N Mkl tensors corresponding to each Tensorflow tensor. |
| CHECK_EQ(new_node_input_slots, old_node_input_slots * 2); |
| } else { |
| // If we are adding workspace tensors for this op, then the total number of |
| // input slots to the new node _must_ be 2 times the number of input slots |
| // to the original node plus 2: N original Tensorflow tensors, N Mkl tensors |
| // corresponding to each Tensorflow tensor, plus the workspace Tensorflow |
| // tensor and the workspace Mkl tensor. |
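| // For example (illustrative), MaxPoolGrad has 3 input slots, so its |
| // rewritten Mkl version with workspace tensors ends up with 3 * 2 + 2 = 8 |
| // inputs. |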
| CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2); |
| } |
| |
| return Status::OK(); |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions related to workspace pass |
| ////////////////////////////////////////////////////////////////////////// |
| |
| // TODO(nhasabni) We should move this to mkl_util.h. |
| void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( |
| std::unique_ptr<Graph>* g, Node** out, Node* orig_node) { |
| // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent |
| // workspace tensor. |
| GetDummyMklTensorNode(g, out, orig_node); |
| } |
| |
| void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( |
| std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb, |
| std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) { |
| bool workspace_edge_added = false; // Default initializer |
| CHECK_NOTNULL(are_ws_tensors_added); |
| *are_ws_tensors_added = false; // Default initializer |
| |
| DataType T; |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| for (auto ws : wsinfo_) { |
| if (orig_node->type_string() == ws.fwd_op && |
| mkl_op_registry::IsMklOp( |
| mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { |
| // If this op is a fwd op, then we need to check if there is an |
| // edge from this node's fwd_slot to bwdop's bwd_slot. If there is |
| // an edge, then we just add an attribute on this node for setting |
| // workspace_passed to true. We don't add actual workspace edge |
| // in this node. Actual workspace edge gets added in the backward |
| // op for this node. |
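| // For example (illustrative), if a MaxPool (fwd) output feeds a MaxPoolGrad |
| // (bwd) on the expected slots, the rewritten MaxPool only gets |
| // workspace_enabled=true here; the actual workspace edge is added when |
| // MaxPoolGrad is rewritten. |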
| for (const Edge* e : orig_node->out_edges()) { |
| if (e->src_output() == ws.fwd_slot && |
| e->dst()->type_string() == ws.bwd_op && |
| e->dst_input() == ws.bwd_slot) { |
| nb->Attr("workspace_enabled", true); |
| VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " |
| << orig_node->type_string(); |
| workspace_edge_added = true; |
| // We found the edge that we were looking for, so break. |
| break; |
| } |
| } |
| |
| if (!workspace_edge_added) { |
| // If we are here, then we did not find backward operator for this |
| // node. |
| nb->Attr("workspace_enabled", false); |
| } |
| } else if (orig_node->type_string() == ws.bwd_op && |
| mkl_op_registry::IsMklOp( |
| mkl_op_registry::GetMklOpName(orig_node->type_string()), |
| T)) { |
| // If this op is a bwd op, then we need to add workspace edge and |
| // its Mkl tensor edge between its corresponding fwd op and this |
| // op. Corresponding fwd op is specified in 'fwd_op' field of |
| // workspace info. fwd_slot and bwd_slot in workspace info specify |
| // an edge between which slots connect forward and backward op. |
| // Once all these criteria match, we add a workspace edge between |
| // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is |
| // determined by interleaved/contiguous ordering. Function |
| // DataIndexToMetaDataIndex tells us the location of Mkl tensor |
| // from the location of the Tensorflow tensor. |
| for (const Edge* e : orig_node->in_edges()) { |
| if (e->src_output() == ws.fwd_slot && |
| // We would have rewritten the forward op, so we need to use |
| // GetMklOpName call to get its Mkl name. |
| e->src()->type_string() == |
| mkl_op_registry::GetMklOpName(ws.fwd_op) && |
| e->dst_input() == ws.bwd_slot) { |
| nb->Attr("workspace_enabled", true); |
| CHECK_NOTNULL(ws_tensors); |
| // Add workspace edge between fwd op and bwd op. |
| ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); |
| // Add Mkl tensor edge for workspace edge between fwd op and bwd op. |
| ws_tensors->push_back(NodeBuilder::NodeOut( |
| e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, |
| e->src()->num_outputs()))); |
| *are_ws_tensors_added = true; |
| // In terms of input ordering, we add these calls to add Input |
| // here because workspace edge (and its Mkl tensor) is the last |
| // edge in the fwdop and bwdop. So all inputs before workspace |
| // tensor have been added by SetUpInputs function. |
| VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " |
| << orig_node->type_string(); |
| workspace_edge_added = true; |
| // We found the edge that we were looking for, so break. |
| break; |
| } |
| } |
| |
| // If we are here, it means we did not find a fwd op that feeds this |
| // bwd op. So in this case, we need to generate dummy tensors for |
| // workspace input and Mkl tensor for workspace, and set |
| // workspace_enabled to false. |
| if (!workspace_edge_added) { |
| nb->Attr("workspace_enabled", false); |
| Node* dmt_ws = nullptr; // Dummy tensor for workspace |
| Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace |
| GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); |
| GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); |
| CHECK_NOTNULL(dmt_ws); |
| CHECK_NOTNULL(dmt_mkl_ws); |
| CHECK_NOTNULL(ws_tensors); |
| // We add dummy tensor as workspace tensor. |
| ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); |
| // We add dummy tensor as Mkl tensor for workspace tensor. |
| ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); |
| *are_ws_tensors_added = true; |
| VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " |
| << orig_node->type_string(); |
| } |
| } else { |
| // If this node does not match any workspace info, then we do not |
| // do anything special for workspace propagation for it. |
| } |
| } |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Op-specific functions to copy attributes from old node to new node |
| ////////////////////////////////////////////////////////////////////////// |
| |
| void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| string padding; |
| std::vector<int32> strides; |
| std::vector<int32> dilations; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("strides", strides); |
| nb->Attr("dilations", dilations); |
| nb->Attr("padding", padding); |
| nb->Attr("data_format", data_format); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int N; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("N", N); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| std::vector<int32> strides; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("strides", strides); |
| nb->Attr("data_format", data_format); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int depth_radius; |
| float bias; |
| float alpha; |
| float beta; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("depth_radius", depth_radius); |
| nb->Attr("bias", bias); |
| nb->Attr("alpha", alpha); |
| nb->Attr("beta", beta); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| string padding; |
| std::vector<int32> ksize, strides; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("ksize", ksize); |
| nb->Attr("strides", strides); |
| nb->Attr("padding", padding); |
| nb->Attr("data_format", data_format); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| DataType Tshape; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("Tshape", Tshape); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| string data_format; |
| int num_split; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("num_split", num_split); |
| nb->Attr("data_format", data_format); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int N; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("N", N); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| int N; |
| DataType tidx; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("N", N); |
| nb->Attr("Tidx", tidx); |
| } |
| |
| void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, |
| NodeBuilder* nb) { |
| DataType T; |
| float epsilon; |
| string data_format; |
| bool is_training; |
| |
| // Get all attributes from old node. |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); |
| TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training)); |
| |
| // Add attributes to new node. |
| nb->Attr("T", T); |
| nb->Attr("epsilon", epsilon); |
| nb->Attr("data_format", data_format); |
| nb->Attr("is_training", is_training); |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions related to node merge pass |
| ////////////////////////////////////////////////////////////////////////// |
| |
| Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { |
| // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite |
| // once we support BiasAddGrad as Mkl layer. |
| |
| // Search for all matching mergeinfo. |
| // We allow more than one match for extensibility. |
| std::vector<const MergeInfo*> matching_mi; |
| for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { |
| if (a->type_string() == mi->op1 || a->type_string() == mi->op2) { |
| matching_mi.push_back(&*mi); |
| } |
| } |
| |
| for (const MergeInfo* mi : matching_mi) { |
| // Get the operand with which 'a' can be merged. |
| Node* b = nullptr; |
| if ((b = mi->get_node_to_be_merged(a)) == nullptr) { |
| continue; |
| } |
| |
| // Get the control edges and input of node |
| const int N_in = a->num_inputs(); |
| gtl::InlinedVector<Node*, 4> a_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); |
| FillInputs(a, &a_control_edges, &a_in); |
| |
| const int B_in = b->num_inputs(); |
| gtl::InlinedVector<Node*, 4> b_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); |
| FillInputs(b, &b_control_edges, &b_in); |
| |
| // Shouldn't merge if a and b have different control edges. |
| if (a_control_edges != b_control_edges) { |
| continue; |
| } else { |
| // We found a match. |
| return b; |
| } |
| } |
| |
| return nullptr; |
| } |
| |
| Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, |
| Node* m, Node* n) { |
| CHECK_EQ(((m->type_string() == csinfo_.bias_add && |
| n->type_string() == csinfo_.conv2d)) || |
| ((n->type_string() == csinfo_.bias_add && |
| m->type_string() == csinfo_.conv2d)), |
| true); |
| |
| // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd, |
| // BiasAdd is successor node, and Conv2D predecessor node. |
| Node* pred = m->type_string() == csinfo_.bias_add ? n : m; |
| Node* succ = m->type_string() == csinfo_.bias_add ? m : n; |
| |
| // 1. Get all attributes from input nodes. |
| DataType T_pred, T_succ; |
| string padding; |
| std::vector<int32> strides; |
| std::vector<int32> dilations; |
| string data_format_pred, data_format_succ; |
| bool use_cudnn_on_gpu; |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred)); |
| TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); |
| TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); |
| TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); |
| // We check to ensure that the data formats of both succ and pred are the |
| // same. We expect them to match, so we could enforce this with an assert. |
| // But an assert can be too strict, so we treat it as a check instead. |
| // If the check fails, then we do not merge the two nodes. |
| // We do the same check for devices. |
| if (data_format_pred != data_format_succ || T_pred != T_succ || |
| pred->assigned_device_name() != succ->assigned_device_name() || |
| pred->def().device() != succ->def().device()) { |
| return Status(error::Code::INVALID_ARGUMENT, |
| "data_format or T attribute or devices of Conv2D and " |
| "BiasAdd do not match. Will skip node merge optimization"); |
| } |
| |
| const int succ_num = succ->num_inputs(); |
| gtl::InlinedVector<Node*, 4> succ_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); |
| FillInputs(succ, &succ_control_edges, &succ_in); |
| |
| const int pred_num = pred->num_inputs(); |
| gtl::InlinedVector<Node*, 4> pred_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); |
| FillInputs(pred, &pred_control_edges, &pred_in); |
| |
| // We need to ensure that the output of Conv2D feeds only BiasAdd (no other |
| // operator consumes the output of Conv2D). If this is not the case, then we |
| // cannot merge Conv2D with BiasAdd. |
| const int kFirstOutputSlot = 0; |
| for (const Edge* e : pred->out_edges()) { |
| if (e->src_output() == kFirstOutputSlot && e->dst() != succ) { |
| return Status(error::Code::INVALID_ARGUMENT, |
| "Conv2D does not feed to BiasAdd, or " |
| "it feeds BiasAdd but has multiple outputs. " |
| "Will skip node merge optimization"); |
| } |
| } |
| |
| // 2. Get inputs from both the nodes. |
| // Find the two inputs of Conv2D and the bias input of BiasAdd. |
| // Get operands 0 and 1 of Conv2D. |
| CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs. |
| // Get operand 1 of add_bias |
| // BiasAdd must have 2 inputs: Conv, bias |
| CHECK_EQ(succ->in_edges().size(), 2); |
| |
| // We will use the node name of BiasAdd as the name of new node |
| // Build new node. We use same name as original node, but change the op |
| // name. |
| NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias); |
| nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D |
| // pred_in[1] will be 2nd Tensorflow tensor for Conv2D. |
| nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D |
| // In1 of BiasAdd is same as output of Conv2D. |
| nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd |
| |
| // Copy attributes from Conv2D to Conv2DWithBias. |
| CopyAttrsConv(const_cast<const Node*>(pred), &nb); |
| |
| // Copy the device assigned to old node to new node. |
| nb.Device(succ->def().device()); |
| |
| // Create node. |
| Node* new_node; |
| TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
| CHECK_NOTNULL(new_node); |
| |
| // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' |
| // node are already copied in BuildNode. We handle control edges now. |
| for (const Edge* e : pred->in_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); |
| } |
| } |
| for (const Edge* e : succ->in_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); |
| } |
| } |
| |
| // Incoming edges are fixed, we will fix the outgoing edges now. |
| // First, we will fix outgoing control edges from 'pred' node. |
| for (const Edge* e : pred->out_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); |
| } |
| } |
| |
| // Second, we will fix outgoing control and data edges from 'succ' node. |
| for (const Edge* e : succ->out_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); |
| } else { |
| // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 |
| // output (at slot 0). |
| const int kConv2DWithBiasOutputSlot = 0; |
| CHECK_NOTNULL((*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, e->dst(), |
| e->dst_input())); |
| } |
| } |
| |
| // Copy device assigned to old node to new node. |
| // It's ok to use pred or succ as we have enforced a check that |
| // both have same device assigned. |
| new_node->set_assigned_device_name(pred->assigned_device_name()); |
| |
| VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() |
| << ", and node: " << succ->DebugString() |
| << ", into node:" << new_node->DebugString(); |
| |
| (*g)->RemoveNode(succ); |
| (*g)->RemoveNode(pred); |
| |
| return Status::OK(); |
| } |
| |
| Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( |
| std::unique_ptr<Graph>* g, Node* m, Node* n) { |
| CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && |
| n->type_string() == csinfo_.conv2d_grad_filter)) || |
| ((n->type_string() == csinfo_.bias_add_grad && |
| m->type_string() == csinfo_.conv2d_grad_filter)), |
| true); |
| |
| // If 'm' is BiasAddGrad, then 'n' is BackpropFilter. |
| Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; |
| Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m; |
| |
| // Sanity check for attributes from input nodes. |
| DataType T_b, T_f; |
| string data_format_b, data_format_f; |
| TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b)); |
| TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f)); |
| TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b)); |
| TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f)); |
| if (data_format_b != data_format_f || T_b != T_f || |
| badd->assigned_device_name() != fltr->assigned_device_name() || |
| badd->def().device() != fltr->def().device()) { |
| return Status(error::Code::INVALID_ARGUMENT, |
| "data_format or T attribute or devices of " |
| "Conv2DBackpropFilter and BiasAddGrad do not match. " |
| "Will skip node merge optimization"); |
| } |
| |
| // We will use the node name of Conv2DBackpropFilter as the name of new node. |
| // This is because BackpropFilterWithBias is going to emit bias output also. |
| NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias); |
| // Since Conv2DBackpropFilterWithBias has same number of inputs as |
| // Conv2DBackpropFilter, we can just copy input edges directly. We don't need |
| // to copy any data input of BiasAddGrad because that input also goes to |
| // Conv2DBackpropFilter. |
| const int fltr_ins = fltr->num_inputs(); |
| gtl::InlinedVector<Node*, 4> fltr_control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins); |
| FillInputs(fltr, &fltr_control_edges, &fltr_in_edges); |
| for (int idx = 0; idx < fltr_ins; idx++) { |
| nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second); |
| } |
| |
| // Copy attributes from Conv2DBackpropFilter. |
| CopyAttrsConv(const_cast<const Node*>(fltr), &nb); |
| |
| // Copy the device assigned to old node to new node. |
| nb.Device(fltr->def().device()); |
| |
| // Create node. |
| Node* new_node; |
| TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
| CHECK_NOTNULL(new_node); |
| |
| // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to |
| // new 'new_node' node are already copied in BuildNode. We handle control |
| // edges now. |
| for (const Edge* e : badd->in_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); |
| } |
| } |
| for (const Edge* e : fltr->in_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); |
| } |
| } |
| |
| // Incoming edges are fixed, we will fix the outgoing edges now. |
| // First, we will fix outgoing control edges from 'badd' node. |
| // Conv2DBackpropFilter has 1 output -- filter_grad. |
| // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and |
| // bias_grad. But filter_grad is at same slot number (0) in both the |
| // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while |
| // it is at slot number 0 in BiasAddGrad. |
| const int kMergedNodeFilterGradOutputIdx = 0; |
| const int kMergedNodeBiasGradOutputIdx = 1; |
| |
| for (const Edge* e : badd->out_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicate while adding control edge as it would fail (return |
| // NULL) if we try to add duplicate edge. |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); |
| } else { |
| CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx, |
| e->dst(), e->dst_input())); |
| } |
| } |
| |
| // Second, we will fix outgoing control and data edges from 'fltr' node. |
| for (const Edge* e : fltr->out_edges()) { |
| if (e->IsControlEdge()) { |
| // We allow a duplicate edge here because we may have already added a |
| // control edge from new_node to the same destination while fixing the |
| // out edges of 'badd' above. If we did not allow duplicates, this call |
| // would fail. |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); |
| } else { |
| CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx, |
| e->dst(), e->dst_input())); |
| } |
| } |
| |
| // Copy device assigned to old node to new node. |
| // It's ok to use badd or fltr as we have enforced a check that |
| // both have same device assigned. |
| new_node->set_assigned_device_name(badd->assigned_device_name()); |
| |
| VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString() |
| << ", and node: " << fltr->DebugString() |
| << ", into node:" << new_node->DebugString(); |
| |
| (*g)->RemoveNode(badd); |
| (*g)->RemoveNode(fltr); |
| |
| return Status::OK(); |
| } |
| |
| Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m, |
| Node* n) { |
| CHECK_NOTNULL(m); |
| CHECK_NOTNULL(n); |
| |
| if (((m->type_string() == csinfo_.bias_add && |
| n->type_string() == csinfo_.conv2d)) || |
| ((n->type_string() == csinfo_.bias_add && |
| m->type_string() == csinfo_.conv2d))) { |
| return this->MergeConv2DWithBiasAdd(g, m, n); |
| } |
| |
| if (((m->type_string() == csinfo_.bias_add_grad && |
| n->type_string() == csinfo_.conv2d_grad_filter)) || |
| ((n->type_string() == csinfo_.bias_add_grad && |
| m->type_string() == csinfo_.conv2d_grad_filter))) { |
| return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n); |
| } |
| |
| return Status(error::Code::UNIMPLEMENTED, |
| "Unimplemented case for node merge optimization."); |
| } |
| |
| ////////////////////////////////////////////////////////////////////////// |
| // Helper functions for node rewrite |
| ////////////////////////////////////////////////////////////////////////// |
| |
| Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, |
| Node* orig_node, |
| const RewriteInfo* ri) { |
| CHECK_NOTNULL(ri); |
| CHECK_NOTNULL(orig_node); |
| |
| VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString(); |
| |
| // Get all inputs. |
| int num_inputs = orig_node->in_edges().size(); |
| |
| // Control edges are not counted as inputs. |
| for (const Edge* e : orig_node->in_edges()) { |
| if (e->IsControlEdge()) { |
| num_inputs--; |
| } |
| } |
| |
| gtl::InlinedVector<Node*, 4> control_edges; |
| gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); |
| FillInputs(orig_node, &control_edges, &inputs); |
| |
| // Build the new node. We use the same name as the original node but change |
| // the op type. |
| NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); |
| // Copy the user-specified device of the original node to the new node. |
| nb.Device(orig_node->def().device()); |
| // Set up new inputs to the rewritten node. |
| Status s = SetUpInputs(g, inputs, &nb, orig_node); |
| if (s != Status::OK()) { |
| return s; |
| } |
| |
| ri->copy_attrs(const_cast<const Node*>(orig_node), &nb); |
| // Set the Mkl layer label for this op. |
| nb.Attr("_kernel", mkl_op_registry::kMklOpLabel); |
| |
| // Finalize graph and get new node. |
| Node* new_node = nullptr; |
| TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
| CHECK_NOTNULL(new_node); |
| |
| // Incoming data edges of 'orig_node' are already wired to 'new_node' via |
| // SetUpInputs and Finalize above; we only need to handle control edges here. |
| for (const Edge* e : orig_node->in_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicates while adding the control edge; without this flag, |
| // AddControlEdge would return nullptr for an edge that already exists and |
| // the CHECK below would fail. |
| CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); |
| } |
| } |
| |
| // Copy outgoing edges from 'orig_node' to 'new_node'. The Mkl node keeps |
| // the same ordering among TensorFlow tensors and Mkl metadata tensors, so |
| // we only need to remap the source output slot of each data edge: with |
| // interleaved ordering, the nth output of the original node becomes the |
| // 2*nth output of the Mkl node; with contiguous ordering it stays at |
| // slot n. GetTensorDataIndex provides this mapping. |
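| // For example, an original node with 2 outputs maps to an Mkl node with 4 |
| // outputs: under contiguous ordering the data outputs stay at slots {0, 1} |
| // and the metadata outputs occupy slots {2, 3}; under interleaved ordering |
| // the data outputs move to slots {0, 2} and the metadata outputs to |
| // slots {1, 3}. |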
| for (const Edge* e : orig_node->out_edges()) { |
| if (e->IsControlEdge()) { |
| // Allow duplicates while adding the control edge; without this flag, |
| // AddControlEdge would return nullptr for an edge that already exists and |
| // the CHECK below would fail. |
| CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); |
| } else { |
| CHECK_NOTNULL((*g)->AddEdge( |
| new_node, |
| GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), |
| e->dst(), e->dst_input())); |
| } |
| } |
| |
| // Copy the runtime device assignment from the original node to the new node. |
| new_node->set_assigned_device_name(orig_node->assigned_device_name()); |
| |
| // Remove the original node from the graph. |
| (*g)->RemoveNode(orig_node); |
| |
| VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); |
| return Status::OK(); |
| } |
| |
| const MklLayoutRewritePass::RewriteInfo* |
| MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { |
| CHECK_NOTNULL(n); |
| |
| // First check whether this node's op, together with its data type, is |
| // supported by the MKL layer. We do not want to rewrite an op into an Mkl |
| // op if the type is unsupported. For example, MklRelu does not support |
| // INT32, so we cannot rewrite Relu to MklRelu when T is INT32. |
| DataType T; |
| if (!GetNodeAttr(n->def(), "T", &T).ok()) { |
| return nullptr; |
| } |
| |
| // We make an exception for __MklDummyConv2DWithBias and |
| // __MklConv2DBackpropFilterWithBias since their names do not match Mkl node |
| // names. |
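| // (GetMklOpName is expected to derive the Mkl op name by prepending the |
| // Mkl prefix to the original op name, e.g. "Conv2D" -> "_MklConv2D"; the |
| // two dummy ops above are registered under names that do not follow this |
| // pattern, hence the exception.) |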
| if (n->type_string() != csinfo_.conv2d_with_bias && |
| n->type_string() != csinfo_.conv2d_grad_filter_with_bias && |
| !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), |
| T)) { |
| return nullptr; |
| } |
| |
| // For elementwise nodes, we reuse the Eigen implementation and pass the MKL |
| // metadata tensor through so we can avoid conversions. However, if all |
| // incoming edges are in TF format, we do not need this overhead, so we |
| // replace an elementwise node only if at least one of its parents is an |
| // MKL node. |
| // |
| // Identity nodes can likewise skip replacement if they are not fed by any |
| // MKL node. |
| // |
| // TODO(vrane): Add implementation for element-wise ops that doesn't reuse |
| // eigen code to reduce cross-library dependency. |
| VLOG(1) << "ELEMENTWISE: checking op: " << n->type_string(); |
| if (mkl_op_registry::IsMklElementWiseOp( |
| mkl_op_registry::GetMklOpName(n->type_string()), T) || |
| n->type_string().find("Identity") != string::npos) { |
| VLOG(1) << "ELEMENTWISE: op is elementwise: " << n->type_string(); |
| bool incoming_mkl_edge = false; |
| int num_parent = 0; |
| for (auto parent : n->in_edges()) { |
| if (mkl_op_registry::IsMklOp(parent->src()->type_string(), T)) { |
| VLOG(1) << "ELEMENTWISE: parent " << num_parent++ |
| << " is MKL op: " << parent->src()->type_string(); |
| incoming_mkl_edge = true; |
| break; |
| } else { |
| VLOG(1) << "ELEMENTWISE: parent " << num_parent++ |
| << " is NON-MKL op: " << parent->src()->type_string(); |
| } |
| } |
| if (incoming_mkl_edge == false) { |
| VLOG(1) << "ELEMENTWISE: Skipping replacement of elementwise node which " |
| "has no MKL " |
| "parents."; |
| return nullptr; |
| } else { |
| VLOG(1) << "ELEMENTWISE: Replacing elementwise node " << n->type_string() |
| << " which has MKL parents"; |
| } |
| } |
| |
| // Now check whether a rewrite rule applies to this op: find the matching |
| // RewriteInfo and verify that its rewrite rule passes for this node. |
| for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { |
| if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { |
| return &*ri; |
| } |
| } |
| |
| // No matching rewrite rule was found. |
| return nullptr; |
| } |
| |
| /////////////////////////////////////////////////////////////////////////////// |
| // Post-rewrite Mkl metadata fixup pass |
| /////////////////////////////////////////////////////////////////////////////// |
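| // This fixup handles the case where an Mkl op X feeds another Mkl op Y, but |
| // Y's Mkl metadata input still comes from a dummy constant node instead of |
| // from X's metadata output. FixMklMetaDataEdgeIfNeeded rewires such a |
| // metadata edge so that it originates from the metadata output slot of X |
| // that corresponds to the data edge X -> Y. |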
| bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, |
| const Edge* e_data, const Edge* e_metadata) { |
| if (g == nullptr || e_data == nullptr || e_metadata == nullptr) { |
| return false; |
| } |
| |
| Node* n_data = e_data->src(); |
| int n_data_op_slot = e_data->src_output(); |
| int n_metadata_op_slot = GetTensorMetaDataIndex(n_data_op_slot, |
| n_data->num_outputs()); |
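| // For example, if 'n_data' has 4 outputs (2 data + 2 metadata) and the data |
| // edge leaves slot 0, the matching metadata output is slot 2 under |
| // contiguous ordering (or slot 1 under interleaved ordering). |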
| |
| // If the source of the metadata edge is a constant node (producing a dummy |
| // Mkl metadata tensor), then the edge needs to be fixed: rewire it to the |
| // metadata output of the node that produces the data tensor. |
| if (IsConstant(e_metadata->src())) { |
| Node* e_metadata_dst = e_metadata->dst(); |
| int e_metadata_in_slot = e_metadata->dst_input(); |
| CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot, |
| e_metadata_dst, e_metadata_in_slot)); |
| |
| (*g)->RemoveEdge(e_metadata); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g, |
| Node* n) { |
| bool result = false; |
| |
| // If the node is not an Mkl node, there is nothing to fix. |
| DataType T = DT_INVALID; |
| if (!GetNodeAttr(n->def(), "T", &T).ok() || |
| !mkl_op_registry::IsMklOp(n->type_string(), T)) { |
| return result; |
| } |
| |
| // If it is an Mkl node, check that the input edges to this node that carry |
| // Mkl metadata are correctly linked up with their source nodes. |
| |
| // Mkl nodes have twice the number of input tensors (n Mkl data tensors + |
| // n Mkl metadata tensors), and we only need to verify the connection of |
| // the n metadata tensors. |
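| // For example, an Mkl node with 2 data inputs has 4 inputs in total; under |
| // contiguous ordering the data inputs occupy slots {0, 1} and the metadata |
| // inputs slots {2, 3}. GetTensorMetaDataIndex (used below) maps a data |
| // input slot to its matching metadata input slot. |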
| int num_data_inputs = n->num_inputs() / 2; |
| for (int idx = 0; idx < num_data_inputs; idx++) { |
| // Get the edge connecting input slot with index (idx). |
| const Edge* e = nullptr; |
| TF_CHECK_OK(n->input_edge(idx, &e)); |
| |
| // If e is control edge, then skip. |
| if (e->IsControlEdge()) { |
| continue; |
| } |
| |
| // Check whether the source node of edge 'e' is an Mkl node. If it is not, |
| // there is nothing to do for this edge. |
| Node* e_src = e->src(); |
| if (GetNodeAttr(e_src->def(), "T", &T).ok() && |
| mkl_op_registry::IsMklOp(e_src->type_string(), T)) { |
| // The source node of edge 'e' is an Mkl node. Its destination node and |
| // destination input slot are 'n' and 'idx', respectively. |
| CHECK_EQ(e->dst(), n); |
| CHECK_EQ(e->dst_input(), idx); |
| |
| // Get the edge that carries the Mkl metadata corresponding to Mkl data |
| // edge 'e'. First find the input slot of 'n' that the metadata edge feeds. |
| int e_meta_in_slot = GetTensorMetaDataIndex(e->dst_input(), |
| n->num_inputs()); |
| const Edge* e_meta = nullptr; |
| TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta)); |
| |
| // Fix this metadata edge if needed. |
| if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) { |
| result = true; |
| } |
| } |
| } |
| |
| return result; |
| } |
| |
| /////////////////////////////////////////////////////////////////////////////// |
| // Run function for the pass |
| /////////////////////////////////////////////////////////////////////////////// |
| |
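| // RunPass makes three passes over a topologically sorted graph: |
| //   (1) merge candidate node pairs (e.g., Conv2D + BiasAdd), |
| //   (2) rewrite individual nodes into their Mkl counterparts, and |
| //   (3) fix up Mkl metadata edges that still point at dummy constants. |
| // It returns true if any of the three passes modified the graph. |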
| bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { |
| bool result = false; |
| CHECK_NOTNULL(g); |
| |
| DumpGraph("Before running MklLayoutRewritePass", &**g); |
| |
| std::vector<Node*> order; |
| GetReversePostOrder(**g, &order); // This will give us topological sort. |
| for (Node* n : order) { |
| // If node is not an op or it cannot run on CPU device, then skip. |
| if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
| continue; |
| } |
| |
| // Check whether node 'n' can be merged with another node. If it can, |
| // CheckForNodeMerge returns that node in 'm'. |
| Node* m = nullptr; |
| if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) { |
| string n1_name = n->name(); |
| string n2_name = m->name(); |
| |
| VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and " |
| << n2_name << " for merging"; |
| |
| if (MergeNode(g, n, m) == Status::OK()) { |
| VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and " |
| << n2_name; |
| result = true; |
| } |
| } |
| } |
| |
| DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g); |
| |
| order.clear(); |
| GetReversePostOrder(**g, &order); // This will give us topological sort. |
| for (Node* n : order) { |
| // If node is not an op or it cannot run on CPU device, then skip. |
| if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
| continue; |
| } |
| |
| const RewriteInfo* ri = nullptr; |
| // Check whether this node should be rewritten to an Mkl op. |
| if ((ri = CheckForNodeRewrite(n)) != nullptr) { |
| string node_name = n->name(); |
| string op_name = n->type_string(); |
| |
| VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name |
| << " with op " << op_name << " for rewrite using" |
| << " layout optimization."; |
| |
| if (RewriteNode(g, n, ri) == Status::OK()) { |
| VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name |
| << " with op " << op_name << " for Mkl layout optimization."; |
| result = true; |
| } |
| } |
| } |
| |
| DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g); |
| |
| order.clear(); |
| GetReversePostOrder(**g, &order); // This will give us topological sort. |
| for (Node* n : order) { |
| // If node is not an op or it cannot run on CPU device, then skip. |
| if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
| continue; |
| } |
| if (FixMklMetaDataEdges(g, n)) { |
| string node_name = n->name(); |
| string op_name = n->type_string(); |
| |
| VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node " |
| << node_name << " with op " << op_name; |
| result = true; |
| } |
| } |
| DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)", |
| &**g); |
| |
| return result; |
| } |
| |
| bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { |
| return MklLayoutRewritePass().RunPass(g); |
| } |
| |
| Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { |
| if (options.graph == nullptr && options.partition_graphs == nullptr) { |
| return Status::OK(); |
| } |
| |
| auto process_graph = [&](std::unique_ptr<Graph>* g) { |
| // RunPass modifies the graph in place; no ownership transfer is needed. |
| RunPass(g); |
| }; |
| |
| if (kMklLayoutRewritePassGroup != |
| OptimizationPassRegistry::POST_PARTITIONING) { |
| // For any pre-partitioning phase, a graph is stored in options.graph. |
| process_graph(options.graph); |
| } else { |
| // For post partitioning phase, graphs are stored in |
| // options.partition_graphs. |
| for (auto& pg : *options.partition_graphs) { |
| process_graph(&pg.second); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| #endif // INTEL_MKL_ML_ONLY |
| } // namespace tensorflow |
| |
| #endif  // INTEL_MKL |