| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <algorithm> |
| #include <memory> |
| #include <queue> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/container/flat_hash_set.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
| #include "mlir/IR/TypeUtilities.h" // from @llvm-project |
| #include "mlir/IR/Value.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Pass/PassRegistry.h" // from @llvm-project |
| #include "mlir/Transforms/RegionUtils.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/passes.h" |
| |
| namespace mlir { |
| namespace TFDevice { |
| |
| namespace { |
| |
| // For the 2d vector type aliases defined below, the first dimension represents |
| // the class of the IfRegion group and the second dimension represents the |
| // segments of the IfRegion group. |
| // For example, if we want to merge the following six IfRegions |
| // which share the same if_cond (regionA) |
| // ````````````` |
| // IfRegionA(1) |
| // IfRegionA(2) |
| // IfRegionA(3) |
| // IfRegionA(4) |
| // IfRegionA(5) |
| // IfRegionA(6) |
| // `````````````` |
| // After the analysis, we consider IfRegionA(1), IfRegionA(2) and IfRegionA(3) |
| // can be merged, IfRegionA(4) is standalone, IfRegionA(5) and IfRegionA(6) |
| // can be merged. Then the defined 2D vector is |
| // [[IfRegionA(1), IfRegionA(2), IfRegionA(3)], |
| // [IfRegionA(4)], |
| // [IfRegionA(5), IfRegionA(6)]] |
| using RegionVec2D = llvm::SmallVector<llvm::SmallVector<TF::IfRegionOp, 8>, 8>; |
| using OperationVec2D = llvm::SmallVector<llvm::SmallVector<Operation*, 8>, 8>; |
| using MapToRegionVec2D = llvm::SmallDenseMap<Value, RegionVec2D>; |
| using MapToOperationVec2D = llvm::SmallDenseMap<Value, OperationVec2D>; |
| using IfOpIterConst = |
| llvm::SmallVectorTemplateCommon<mlir::TF::IfRegionOp>::const_iterator; |
| |
| struct MergeControlFlowPass |
| : public TF::MergeControlFlowPassBase<MergeControlFlowPass> { |
| void runOnOperation() override; |
| }; |
| |
| // Gets the IfRegion op and all of ops in the then and else branches. |
| llvm::SmallSetVector<Operation*, 4> GetAllOpsFromIf(TF::IfRegionOp if_op) { |
| llvm::SmallSetVector<Operation*, 4> all_ops; |
| all_ops.insert(if_op); |
| for (Operation& op : if_op.then_branch().front()) { |
| all_ops.insert(&op); |
| } |
| for (Operation& op : if_op.else_branch().front()) { |
| all_ops.insert(&op); |
| } |
| return all_ops; |
| } |
| |
| // Returns whether it is safe to merge `second_if` IfRegion into `first_if` |
| // IfRegion. `second if` must come after `first_if`. |
| // Note that `downstream_if_ops` means the ops in IfRegions except`first_if`. |
| bool SafeToMerge(TF::IfRegionOp first_if, TF::IfRegionOp second_if, |
| llvm::SmallSetVector<Operation*, 4>& downstream_if_ops, |
| const TF::SideEffectAnalysis::Info& side_effect_analysis) { |
| // IfRegion ops must be in the same block. |
| if (second_if.getOperation()->getBlock() != |
| first_if.getOperation()->getBlock()) { |
| return false; |
| } |
| assert(first_if.getOperation()->isBeforeInBlock(second_if.getOperation())); |
| |
| llvm::SmallSetVector<Operation*, 4> destination_ops = |
| GetAllOpsFromIf(first_if); |
| |
| // If there is an intermediate data or side effect dependency between the |
| // ops in first_if and the ops in second_if, it's not safe to merge |
| // them. |
| std::vector<Operation*> dependencies; |
| for (auto* user : first_if.getOperation()->getUsers()) { |
| if (!downstream_if_ops.contains(user)) { |
| dependencies.push_back(user); |
| } |
| } |
| for (auto* successor : |
| side_effect_analysis.DirectControlSuccessors(first_if.getOperation())) { |
| if (!downstream_if_ops.contains(successor)) { |
| dependencies.push_back(successor); |
| } |
| } |
| for (Operation& op : first_if.then_branch().front()) { |
| for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) { |
| if (!downstream_if_ops.contains(successor) && |
| !destination_ops.contains(successor)) |
| dependencies.push_back(successor); |
| } |
| } |
| for (Operation& op : first_if.else_branch().front()) { |
| for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) { |
| if (!downstream_if_ops.contains(successor) && |
| !destination_ops.contains(successor)) |
| dependencies.push_back(successor); |
| } |
| } |
| |
| bool safe_to_merge = true; |
| |
| llvm::SmallPtrSet<Operation*, 4> visited; |
| while (!dependencies.empty()) { |
| Operation* dependency = dependencies.back(); |
| dependencies.pop_back(); |
| if (visited.count(dependency)) continue; |
| visited.insert(dependency); |
| for (auto* user : dependency->getUsers()) { |
| if (downstream_if_ops.contains(user)) { |
| safe_to_merge = false; |
| break; |
| } else { |
| dependencies.push_back(user); |
| } |
| } |
| for (auto* successor : |
| side_effect_analysis.DirectControlSuccessors(dependency)) { |
| if (downstream_if_ops.contains(successor)) { |
| safe_to_merge = false; |
| break; |
| } else { |
| dependencies.push_back(successor); |
| } |
| } |
| // If the op is nested, then also consider the users and successors of the |
| // parent op. |
| if (dependency->getBlock() != first_if.getOperation()->getBlock()) |
| dependencies.push_back(dependency->getParentOp()); |
| if (!safe_to_merge) break; |
| } |
| return safe_to_merge; |
| } |
| |
| // Move the body excluding the terminators of else and then regions from |
| // 'second_if' to 'first_if'. |
| void MoveBranches(TF::IfRegionOp first_if, TF::IfRegionOp second_if) { |
| Block& first_if_then_block = first_if.then_branch().front(); |
| auto& second_if_then_body = second_if.then_branch().front().getOperations(); |
| first_if_then_block.getOperations().splice( |
| first_if_then_block.without_terminator().end(), second_if_then_body, |
| second_if_then_body.begin(), std::prev(second_if_then_body.end())); |
| |
| Block& first_if_else_block = first_if.else_branch().front(); |
| auto& second_if_else_body = second_if.else_branch().front().getOperations(); |
| first_if_else_block.getOperations().splice( |
| first_if_else_block.without_terminator().end(), second_if_else_body, |
| second_if_else_body.begin(), std::prev(second_if_else_body.end())); |
| } |
| |
| // Check if the `last` IfRegion can be added to the segment of |
| // IfRegion start with `first` IfRegion. |
| bool CanAddToIfSegment( |
| IfOpIterConst first, IfOpIterConst last, |
| const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops, |
| const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) { |
| if (last == if_ops.end()) { |
| return false; |
| } |
| // downstream_if_ops contain ops in those IfRegions between first IfRegion |
| // and last IfRegion plus the ops in the last IfRegion. |
| llvm::SmallSetVector<Operation*, 4> downstream_if_ops; |
| |
| TF::IfRegionOp second_if_op = *last; |
| |
| for (auto iter = std::prev(last); std::next(iter) != first; iter--) { |
| TF::IfRegionOp first_if_op = *iter; |
| func::FuncOp func = first_if_op->getParentOfType<func::FuncOp>(); |
| const TF::SideEffectAnalysis::Info& analysis = |
| side_effect_analysis->GetAnalysisForFunc(func); |
| auto all_ops = GetAllOpsFromIf(*(std::next(iter))); |
| downstream_if_ops.insert(all_ops.begin(), all_ops.end()); |
| if (!SafeToMerge(first_if_op, second_if_op, downstream_if_ops, analysis)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| // Return the iterator of the IfRegion Op. This is the last IfRegion |
| // in the segment. |
| // For example, we have the following sequence of IfRegions |
| // ````` |
| // 1 2 3 4 5 |
| // IfRegionA, IfRegionA, IfRegionA, IfRegionA, IfRegionA |
| // ````` |
| // The first three IfRegionA are in one group and the last two are in another |
| // group. Then when we call FindLastIfInSegment for the first segment, it |
| // will return iterator of the 3rd IfRegionA. |
| // In the same way, when we call it for the second segment, it will return |
| // iterator of the 5th IfRegionA. |
| IfOpIterConst FindLastIfInSegment( |
| IfOpIterConst first_if, |
| const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops, |
| const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) { |
| IfOpIterConst last_if = first_if; |
| for (; CanAddToIfSegment(first_if, last_if, if_ops, side_effect_analysis); |
| last_if = std::next(last_if)) { |
| } |
| return std::prev(last_if); |
| } |
| |
| // Returns a set of ops to be moved after merged IfRegion between two IfRegions. |
| absl::flat_hash_set<Operation*> GetMoveOpsBetweenTwoIfRegions( |
| Operation* result_op, Operation* after_op, |
| llvm::SmallSetVector<Operation*, 4> middle_if_ops, |
| const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) { |
| Block* block = after_op->getBlock(); |
| std::queue<Operation*> queue; |
| absl::flat_hash_set<Operation*> visited; |
| absl::flat_hash_set<Operation*> moved_ops; |
| |
| func::FuncOp func = result_op->getParentOfType<func::FuncOp>(); |
| const TF::SideEffectAnalysis::Info& analysis = |
| side_effect_analysis->GetAnalysisForFunc(func); |
| |
| // Enqueue dependencies of source_op into queue. |
| auto enqueue_deps = [&](Operation* source_op) { |
| for (Operation* user : source_op->getUsers()) { |
| if (!visited.count(user) && !middle_if_ops.count(user)) { |
| visited.insert(user); |
| queue.push(user); |
| } |
| } |
| source_op->walk([&](Operation* walked_op) { |
| for (Operation* successor : analysis.DirectControlSuccessors(walked_op)) { |
| if (!source_op->isProperAncestor(successor)) { |
| if (!visited.count(successor) && !middle_if_ops.count(successor)) { |
| visited.insert(successor); |
| queue.push(successor); |
| } |
| } |
| } |
| }); |
| }; |
| enqueue_deps(result_op); |
| |
| while (!queue.empty()) { |
| auto* op = queue.front(); |
| queue.pop(); |
| while (op->getBlock() != block) op = op->getParentOp(); |
| if (op->isBeforeInBlock(after_op)) { |
| moved_ops.insert(op); |
| enqueue_deps(op); |
| } |
| } |
| return moved_ops; |
| } |
| |
| // Returns a vector that contains the ops to be moved after merged IfRegion. |
| // `sub_if_group` refers to a segment of IfRegions. |
| // The returned vector preserves op order. |
| llvm::SmallVector<Operation*, 8> GetMoveOpList( |
| llvm::SmallVector<TF::IfRegionOp, 8>& sub_if_group, |
| const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis) { |
| absl::flat_hash_set<Operation*> all_moved_ops; |
| Operation* last_if_op = sub_if_group.back().getOperation(); |
| llvm::SmallSetVector<Operation*, 4> middle_if_ops; |
| |
| // reversely calculate the all ops need to be moved because in this way, |
| // ops in the middle IfRegions can be easily obtained by simply adding to the |
| // current set. |
| for (auto it = std::prev(std::prev(sub_if_group.end())); |
| std::next(it) != sub_if_group.begin(); --it) { |
| auto op_list = GetMoveOpsBetweenTwoIfRegions( |
| it->getOperation(), last_if_op, middle_if_ops, side_effect_analysis); |
| all_moved_ops.insert(op_list.begin(), op_list.end()); |
| auto first_if_ops = GetAllOpsFromIf(*it); |
| middle_if_ops.insert(first_if_ops.begin(), first_if_ops.end()); |
| } |
| |
| llvm::SmallVector<Operation*, 8> moved_ops_ordered; |
| moved_ops_ordered.reserve(all_moved_ops.size()); |
| for (Operation& op : *last_if_op->getBlock()) { |
| if (all_moved_ops.count(&op)) { |
| moved_ops_ordered.push_back(&op); |
| } |
| } |
| |
| return moved_ops_ordered; |
| } |
| |
| // Generate the segments for each IfRegion groups. Each element in the segments |
| // are supposed to can be merged into one new IfRegion.`if_cond` refers to the |
| // if condition of the segment of IfRegions. `if_ops` refers to the segment of |
| // IfRegions. `merged_groups` refers to all segments of IfRegions. |
| // `moved_ops_groups` refers to the ops need to be moved after new merged |
| // IfRegions associated with each segment of IfRegions. |
| void GenerateSegmentsPerIfGroups( |
| const mlir::Value& if_cond, |
| const llvm::SmallVector<mlir::TF::IfRegionOp, 8>& if_ops, |
| const std::unique_ptr<TF::SideEffectAnalysis>& side_effect_analysis, |
| MapToRegionVec2D& merged_groups, MapToOperationVec2D& moved_ops_groups) { |
| auto it_merged = merged_groups.try_emplace(if_cond); |
| auto it_moved = moved_ops_groups.try_emplace(if_cond); |
| llvm::SmallVector<TF::IfRegionOp, 8> sub_merged_groups; |
| auto begin_if_op_iter = if_ops.begin(); |
| |
| while (begin_if_op_iter != if_ops.end()) { |
| auto current_last_if_op_iter = |
| FindLastIfInSegment(begin_if_op_iter, if_ops, side_effect_analysis); |
| assert(current_last_if_op_iter != if_ops.end()); |
| llvm::SmallVector<TF::IfRegionOp, 8> sub_if_group; |
| for (auto it = begin_if_op_iter; it != std::next(current_last_if_op_iter); |
| ++it) { |
| sub_if_group.push_back(*it); |
| } |
| it_merged.first->getSecond().push_back(sub_if_group); |
| it_moved.first->getSecond().push_back( |
| GetMoveOpList(sub_if_group, side_effect_analysis)); |
| begin_if_op_iter = std::next(current_last_if_op_iter); |
| } |
| } |
| |
| // Checks whether a return index should be kept for `current_if_op` by checking |
| // for results in `if_op_segment`. |
| llvm::SmallVector<int, 4> GetReturnIndicesToKeep( |
| TF::IfRegionOp current_if_op, |
| const llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) { |
| llvm::SmallVector<int, 4> return_indices_to_keep; |
| auto is_op_inside_IfRegions = [&](Operation* op) { |
| for (auto& if_op : if_op_segment) { |
| if (if_op == current_if_op) { |
| continue; |
| } |
| if (if_op->isProperAncestor(op)) { |
| return true; |
| } |
| } |
| return false; |
| }; |
| for (auto& index_and_value : llvm::enumerate(current_if_op.getResults())) { |
| if (!llvm::all_of(index_and_value.value().getUsers(), |
| is_op_inside_IfRegions)) { |
| return_indices_to_keep.push_back(index_and_value.index()); |
| } |
| } |
| return return_indices_to_keep; |
| } |
| |
| // Return a vector of the return indices. |
| llvm::SmallVector<llvm::SmallVector<int, 4>> GetReturnIndicesVec( |
| const llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) { |
| llvm::SmallVector<llvm::SmallVector<int, 4>> return_indices_vec; |
| for (auto it = if_op_segment.begin(); it != if_op_segment.end(); ++it) { |
| llvm::SmallVector<int, 4> indices_to_keep_vec = |
| GetReturnIndicesToKeep(*it, if_op_segment); |
| return_indices_vec.push_back(indices_to_keep_vec); |
| } |
| return return_indices_vec; |
| } |
| |
| // Replace the internal usage in each pair of IfRegions from top to bottom for |
| // both then branch and else branch. |
| void ReplaceInternalUsage(llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment) { |
| for (auto it = if_op_segment.begin(); it != if_op_segment.end(); ++it) { |
| for (auto it2 = std::next(it); it2 != if_op_segment.end(); ++it2) { |
| for (OpResult result : it->getResults()) { |
| replaceAllUsesInRegionWith( |
| result, |
| it->then_branch().front().getTerminator()->getOperand( |
| result.getResultNumber()), |
| it2->then_branch()); |
| replaceAllUsesInRegionWith( |
| result, |
| it->else_branch().front().getTerminator()->getOperand( |
| result.getResultNumber()), |
| it2->else_branch()); |
| } |
| } |
| } |
| } |
| |
| // Move ops in the `moved_ops_ordered` after `last_op`. |
| void MoveOpsAfter(Operation* last_op, |
| llvm::SmallVector<Operation*, 8>& moved_ops_ordered) { |
| auto block = last_op->getBlock(); |
| absl::flat_hash_set<Operation*> all_moved_ops(moved_ops_ordered.begin(), |
| moved_ops_ordered.end()); |
| moved_ops_ordered.clear(); |
| for (Operation& op : *block) { |
| // There are no mutations in the loop. So each call of `isBeforeInBlock` |
| // is O(1). |
| if (all_moved_ops.count(&op) && op.isBeforeInBlock(last_op)) { |
| moved_ops_ordered.push_back(&op); |
| } |
| } |
| // Move ops in order. |
| for (Operation* op : moved_ops_ordered) { |
| op->moveAfter(last_op); |
| last_op = op; |
| } |
| } |
| |
| // Replace all external usage for each IfRegion in the segment of IfRegions. |
| // `if_op_segment` refers to the segment of IfRegions, `new_if_op` refers to the |
| // new merged IfRegion, `return_indices` refers to the indices to be kept in new |
| // merged IfRegion. |
| void ReplaceExternalUsage( |
| llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment, |
| TF::IfRegionOp new_if_op, |
| llvm::SmallVector<llvm::SmallVector<int, 4>>& return_indices) { |
| int new_return_index = 0; |
| for (const auto& index_and_value : llvm::enumerate(if_op_segment)) { |
| auto old_if_op = index_and_value.value(); |
| for (int i : return_indices[index_and_value.index()]) { |
| old_if_op.getResult(i).replaceAllUsesWith( |
| new_if_op.getResult(new_return_index++)); |
| } |
| } |
| } |
| |
| // Update the moved op list to remove old IfRegions from the list and add new |
| // merged IfRegions. `old_to_new_IfRegions_map` refers to a map from old |
| // IfRegion to new merged IfRegion. `moved_ops_list` refers to the list of ops |
| // to be moved after new merged IfRegion. |
| void UpdateMovedOpList( |
| llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map, |
| llvm::SmallVector<Operation*, 8>& moved_ops_list) { |
| llvm::SmallDenseSet<TF::IfRegionOp> new_if_ops; |
| bool need_add_new_if_op = false; |
| for (auto iter = moved_ops_list.begin(); iter != moved_ops_list.end(); |
| iter++) { |
| if (old_to_new_IfRegion_map.count(*iter)) { |
| need_add_new_if_op = true; |
| auto new_if_op = old_to_new_IfRegion_map[*iter]; |
| new_if_ops.insert(new_if_op); |
| moved_ops_list.erase(iter--); |
| } |
| } |
| if (need_add_new_if_op) { |
| for (auto& new_if_op : new_if_ops) { |
| moved_ops_list.push_back(new_if_op.getOperation()); |
| } |
| } |
| } |
| |
| // Create the Yield ops for both branches with merged results. |
| // `builder` is the OpBuilder. |
| // `if_op_segment` refers to the segment of IfRegions to be merged. |
| // `return_indices` refers to the return indices to be kept in merged IfRegion |
| // `new_if_op` refers to the created new IfRegion |
| void CreateYieldOps( |
| OpBuilder& builder, llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment, |
| llvm::SmallVector<llvm::SmallVector<int, 4>>& return_indices, |
| TF::IfRegionOp new_if_op, TF::IfRegionOp first_if) { |
| llvm::SmallVector<Value, 4> merged_then_yield_values; |
| for (const auto& index_and_value : llvm::enumerate(if_op_segment)) { |
| auto if_op = index_and_value.value(); |
| for (auto i : return_indices[index_and_value.index()]) { |
| merged_then_yield_values.push_back( |
| if_op.then_branch().front().getTerminator()->getOperand(i)); |
| } |
| } |
| builder.setInsertionPointToEnd(&new_if_op.then_branch().front()); |
| builder.create<TF::YieldOp>( |
| first_if.then_branch().front().getTerminator()->getLoc(), |
| /*operands=*/merged_then_yield_values); |
| |
| llvm::SmallVector<Value, 4> merged_else_yield_values; |
| for (const auto& index_and_value : llvm::enumerate(if_op_segment)) { |
| auto if_op = index_and_value.value(); |
| for (auto i : return_indices[index_and_value.index()]) { |
| merged_else_yield_values.push_back( |
| if_op.else_branch().front().getTerminator()->getOperand(i)); |
| } |
| } |
| builder.setInsertionPointToEnd(&new_if_op.else_branch().front()); |
| builder.create<TF::YieldOp>( |
| first_if.else_branch().front().getTerminator()->getLoc(), |
| /*operands=*/merged_else_yield_values); |
| } |
| |
| // Merge the IfRegions in each segment. In the meantime, the old IfRegions in |
| // the segment will be added to `regions_to_remove`. They will be erased in the |
| // end. |
| // `if_op_segment` refers to segments of IfRegions. `moved_op_list` refers to |
| // the ops to be moved after new merged IfRegion. `regions_to_remove` refers to |
| // the regions to be removed from the `moved_ops_list`. |
| // `old_to_new_IfRegion_map` refers to a map from old IfRegion to new merged |
| // IfRegion. |
| void MergeIfPerSegment( |
| llvm::SmallVector<TF::IfRegionOp, 8>& if_op_segment, |
| llvm::SmallVector<Operation*, 8>& moved_ops_list, |
| llvm::SmallSetVector<TF::IfRegionOp, 8>& regions_to_remove, |
| llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map) { |
| TF::IfRegionOp first_if = if_op_segment[0]; |
| llvm::SmallVector<Type, 4> merged_return_types; |
| llvm::SmallVector<TF::IfRegionOp, 8> sources_if_ops( |
| std::next(if_op_segment.begin()), if_op_segment.end()); |
| |
| // Create new IfRegion's merged results. |
| auto return_indices = GetReturnIndicesVec(if_op_segment); |
| for (const auto& index_and_value : llvm::enumerate(return_indices)) { |
| TF::IfRegionOp if_op = if_op_segment[index_and_value.index()]; |
| for (auto i : index_and_value.value()) { |
| merged_return_types.push_back(if_op.getResult(i).getType()); |
| } |
| } |
| |
| // Create new IfRegion for merged all IfRegions in if_op_segmemt. |
| OpBuilder builder(first_if); |
| builder.setInsertionPoint(if_op_segment.back().getOperation()); |
| |
| auto new_if_op = builder.create<TF::IfRegionOp>( |
| first_if.getLoc(), merged_return_types, first_if.cond(), |
| llvm::all_of(if_op_segment, |
| [&](TF::IfRegionOp op) { return op.is_stateless(); }), |
| first_if._then_func_nameAttr(), first_if._else_func_nameAttr()); |
| new_if_op.then_branch().push_back(new Block); |
| new_if_op.else_branch().push_back(new Block); |
| |
| // Replace internal usages of merged if ops. |
| ReplaceInternalUsage(if_op_segment); |
| |
| // Replace external usages of merged if ops. |
| ReplaceExternalUsage(if_op_segment, new_if_op, return_indices); |
| |
| // Move ops after the new merged If region. |
| MoveOpsAfter(new_if_op.getOperation(), moved_ops_list); |
| |
| // Create the Yield ops for both branches with merged results. |
| CreateYieldOps(builder, if_op_segment, return_indices, new_if_op, first_if); |
| |
| for (auto& old_if_op : if_op_segment) { |
| MoveBranches(/*first_if=*/new_if_op, /*second_if=*/old_if_op); |
| } |
| |
| for (auto& old_if_op : if_op_segment) { |
| old_to_new_IfRegion_map[old_if_op.getOperation()] = new_if_op; |
| regions_to_remove.insert(old_if_op); |
| } |
| } |
| |
| // Merge IfRegions for each IfRegion group. Each IfRegion group contains |
| // several segments of IfRegions and each segment of IfRegions can be merged |
| // into one IfRegion. |
| // `if_cond` refers to the if condition of the segments of IfRegions. |
| // `planned_merged_groups` refers to the groups of IfRegions to be merged |
| // `moved_ops_groups` refers to the ops need to be moved after new merged |
| // IfRegions associated with each segment of IfRegions. |
| // `regions_to_remove` refers to the regions to be removed |
| // `old_to_new_IfRegion_map` refers to a map from old IfRegion to new merged |
| // IfRegion. |
| void MergeIfPerIfGroups( |
| const Value& if_cond, MapToRegionVec2D& planned_merged_groups, |
| MapToOperationVec2D& moved_ops_groups, |
| llvm::SmallSetVector<TF::IfRegionOp, 8>& regions_to_remove, |
| llvm::SmallDenseMap<Operation*, TF::IfRegionOp>& old_to_new_IfRegion_map) { |
| OperationVec2D& moved_ops_group = moved_ops_groups[if_cond]; |
| RegionVec2D& segments = planned_merged_groups[if_cond]; |
| |
| for (auto i = 0; i < segments.size(); ++i) { |
| if (segments[i].size() >= 2) { |
| UpdateMovedOpList(old_to_new_IfRegion_map, moved_ops_group[i]); |
| MergeIfPerSegment(segments[i], moved_ops_group[i], regions_to_remove, |
| old_to_new_IfRegion_map); |
| } |
| } |
| } |
| |
| // Groups IfRegions by common predicate and attemps to merge them. |
| void OptimizeIfRegions(Block* block, ModuleOp module) { |
| // Do side effect analysis only one time in the beginning |
| auto side_effect_analysis = std::make_unique<TF::SideEffectAnalysis>(module); |
| |
| // Determine IfRegions with the same predicate. |
| llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8> |
| grouped_if_ops; |
| llvm::SmallVector<Value, 4> if_cond_order; |
| block->walk([&](TF::IfRegionOp if_op) { |
| auto it = grouped_if_ops.try_emplace(if_op.cond()); |
| if (it.second) { |
| if_cond_order.push_back(if_op.cond()); |
| } |
| it.first->getSecond().push_back(if_op); |
| }); |
| |
| MapToRegionVec2D planned_merged_groups; |
| MapToOperationVec2D moved_ops_groups; |
| llvm::SmallSetVector<TF::IfRegionOp, 8> regions_to_remove; |
| llvm::SmallDenseMap<Operation*, TF::IfRegionOp> old_to_new_IfRegion_map; |
| |
| // For each if group, determine the segments of each if groups |
| // that can be merged and their related ops to be moved after |
| // the new generated IfRegions |
| // We cache the infomation into two maps: |
| // planned_merged_groups and moved_ops_groups |
| for (const auto& if_cond : if_cond_order) { |
| GenerateSegmentsPerIfGroups(if_cond, grouped_if_ops[if_cond], |
| side_effect_analysis, planned_merged_groups, |
| moved_ops_groups); |
| } |
| |
| // Merge IfRegions for each IfRegion groups. |
| for (const auto& if_cond : if_cond_order) { |
| MergeIfPerIfGroups(if_cond, planned_merged_groups, moved_ops_groups, |
| regions_to_remove, old_to_new_IfRegion_map); |
| } |
| |
| // Remove all old IfRegions that already been merged. |
| for (auto old_if_region : regions_to_remove) { |
| old_if_region.erase(); |
| } |
| } |
| |
| void MergeControlFlowPass::runOnOperation() { |
| ModuleOp module = getOperation(); |
| auto result = module.walk([&](tf_device::ClusterOp cluster) { |
| OptimizeIfRegions(&cluster.GetBody(), module); |
| return WalkResult::advance(); |
| }); |
| if (result.wasInterrupted()) return signalPassFailure(); |
| } |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() { |
| return std::make_unique<MergeControlFlowPass>(); |
| } |
| |
| } // namespace TFDevice |
| } // namespace mlir |