/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef CAFFE2_OPT_FUSION_H_
#define CAFFE2_OPT_FUSION_H_

|  | #include "caffe2/core/workspace.h" | 
|  | #include "nomnigraph/Representations/NeuralNet.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  | namespace opt { | 
|  |  | 
|  | using namespace nom; | 
|  |  | 
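// Fuses Conv + BatchNorm (SpatialBN) pairs in `nn`, folding the batch-norm
// parameters held in `ws` into the convolution's weight and bias blobs.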
TORCH_API void fuseConvBN(repr::NNModule* nn, caffe2::Workspace* ws);

// Generic activation fusion helper.
//
// \tparam OperationT The operator to be fused.
// \tparam ActivationT The activation to be fused.
// \param nn Neural network module to be modified in place.
// \param should_fuse Predicate that, given an operator (e.g. a conv), decides
// whether it should be fused with the subsequent activation (e.g. a relu).
// \param postprocess Functor to postprocess the fused operator node, attaching
// additional attributes if necessary.
template <typename OperationT, typename ActivationT>
C10_EXPORT void fuseActivation(
    repr::NNModule* nn,
    std::function<bool(const OperationT& conv)> should_fuse,
    std::function<void(repr::NNGraph::NodeRef conv_node)> postprocess) {
  for (auto node_pair : repr::nn::dataIterator<OperationT>(nn->dataFlow)) {
    repr::NNGraph::NodeRef conv_node;
    OperationT* conv;
    std::tie(conv, conv_node) = node_pair;

    // Check topological feasibility
    auto conv_outputs = repr::nn::getOutputs(conv_node);
    if (conv_outputs.size() != 1) {
      continue;
    }
    auto conv_output = conv_outputs.front();

    auto consumers = repr::nn::getConsumers(conv_output);
    if (consumers.size() != 1) {
      continue;
    }
    if (!repr::nn::is<ActivationT>(consumers.front())) {
      continue;
    }
    auto relu_node = consumers.front();

    auto relu_outputs = repr::nn::getOutputs(relu_node);
    if (relu_outputs.size() != 1) {
      continue;
    }

    // Check feasibility with application specific logic
    if (!should_fuse(*conv)) {
      continue;
    }

    // Ready to fuse
    auto relu_output = relu_outputs.front();
    auto output_tensor = repr::nn::get<repr::Tensor>(relu_output);
    auto output_node = relu_output;
    auto input_tensor =
        repr::nn::get<repr::Tensor>(repr::nn::getInputs(conv_node).front());

    // A conv cannot run in-place. Adopt the activation's output tensor as the
    // fused op's output only if doing so would not make the op's output name
    // equal to its input name; otherwise keep the op's own output tensor and
    // drop the activation's.
    if (output_tensor->getName() != input_tensor->getName()) {
      nn->dataFlow.replaceNode(conv_output, relu_output);
      nn->dataFlow.deleteNode(relu_node);
      nn->dataFlow.deleteNode(conv_output);
    } else {
      nn->dataFlow.replaceNode(relu_output, conv_output);
      output_tensor = repr::nn::get<repr::Tensor>(conv_output);
      output_node = conv_output;
      nn->dataFlow.deleteNode(relu_node);
      nn->dataFlow.deleteNode(relu_output);
    }

    // We may have accidentally made the next op in-place.
    // In future iterations of transformations this won't be an issue,
    // but current caffe2 predictor usage requires things like
    // external_input and output to be unchanged.
    bool rectify_inplace = false;
    for (auto& consumer : repr::nn::getConsumers(output_node)) {
      for (auto& consumer_output : repr::nn::getOutputs(consumer)) {
        auto co_name =
            repr::nn::get<repr::Tensor>(consumer_output)->getName();
        if (co_name == output_tensor->getName()) {
          rectify_inplace = true;
        }
      }
    }
    if (rectify_inplace) {
      auto new_output = nn->dataFlow.createNode(
          make_unique<repr::Tensor>(output_tensor->getName() + "_fusion_fix"));
      nn->dataFlow.replaceNode(output_node, new_output);
    }

    // Application specific logic for postprocessing the conv node
    postprocess(conv_node);
  }
}
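
// Example usage (illustrative sketch): assuming repr::Conv and repr::Relu
// from nomnigraph's NeuralNet representation as the operator and activation
// types, a Conv + Relu fusion pass could be invoked as:
//
//   fuseActivation<repr::Conv, repr::Relu>(
//       nn,
//       [](const repr::Conv& /* conv */) { return true; },  // fuse every pair
//       [](repr::NNGraph::NodeRef /* conv_node */) {});     // no postprocessing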

} // namespace opt
} // namespace caffe2

#endif // CAFFE2_OPT_FUSION_H_