| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/conditional_code_motion.h" |
| |
| #include <iterator> |
| #include <stack> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/algorithm/container.h" |
| #include "absl/strings/numbers.h" |
| #include "absl/strings/str_cat.h" |
| #include "tensorflow/compiler/xla/debug_options_flags.h" |
| #include "tensorflow/compiler/xla/literal.h" |
| #include "tensorflow/compiler/xla/map_util.h" |
| #include "tensorflow/compiler/xla/service/call_graph.h" |
| #include "tensorflow/compiler/xla/service/call_inliner.h" |
| #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" |
| #include "tensorflow/compiler/xla/service/hlo_computation.h" |
| #include "tensorflow/compiler/xla/service/hlo_cse.h" |
| #include "tensorflow/compiler/xla/service/hlo_dce.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_instructions.h" |
| #include "tensorflow/compiler/xla/service/hlo_opcode.h" |
| #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" |
| #include "tensorflow/compiler/xla/service/hlo_verifier.h" |
| #include "tensorflow/compiler/xla/service/tuple_simplifier.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/platform/errors.h" |
| |
| namespace xla { |
| |
| namespace conditional_opt { |
| |
| class BoundaryVisitor { |
| public: |
| // Start with an existing conditional computation. |
| explicit BoundaryVisitor(HloInstruction* conditional) { |
| Boundary b(Boundary::Position::kInsideBranch); |
| b.mutable_operands().push_back(conditional); |
| worklist_.push_back(b); |
| } |
| // Start with an empty work list. |
| BoundaryVisitor() {} |
| // Get next boundary to visit. |
| Boundary PopNextBoundary() { |
| CHECK(!worklist_.empty()); |
| Boundary b = worklist_.front(); |
| worklist_.pop_front(); |
| // If b has already been visited, it must have multiple users and is already |
| // in new boundaries. Skip it. |
| while (!worklist_.empty() && ContainsKey(visited_, b)) { |
| b = worklist_.front(); |
| worklist_.pop_front(); |
| } |
| visited_.insert(b); |
| return b; |
| } |
| void AddToWorkList(const Boundary& b) { |
| CHECK(!b.operands().empty()); |
| worklist_.push_back(b); |
| } |
| |
| bool HasNextBoundary() { |
| while (!worklist_.empty()) { |
| Boundary b = worklist_.front(); |
| if (!ContainsKey(visited_, b)) { |
| break; |
| } |
| worklist_.pop_front(); |
| } |
| return !worklist_.empty(); |
| } |
| |
| private: |
| // The worklist contains the boundaries that remain to be visited. |
| std::deque<Boundary> worklist_; |
| absl::flat_hash_set<Boundary> visited_; |
| }; |
| |
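| // Returns the number of distinct non-constant instructions in `ops`. |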
| template <class OpCollection> |
| int64_t CountNonLeafOps(const OpCollection& ops) { |
| absl::flat_hash_set<HloInstruction*> op_set; |
| for (auto op : ops) { |
| if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) { |
| op_set.insert(op); |
| } |
| } |
| return op_set.size(); |
| } |
| |
| // Returns an estimate of the potential reuses carried by a given pair of |
| // instructions. Different integers classify different levels of reuse. |
| // This is used as a placeholder only, assuming all instructions can be |
| // fused to enable data reuse. |
| int64_t ReusesCarriedBy(HloOpcode op, HloOpcode user) { |
| // Reuses in some way work like forces that pull instructions |
| // towards each other. We use a number 0-10 to classify how strong the force |
| // is between a pair of operations. Given a group of instructions that can be |
| // moved together, if the forces inside a conditional are stronger, the group |
| // will be moved into or remain inside the conditional; otherwise, it will |
| // be moved to or remain outside of the conditional. |
| switch (user) { |
| case HloOpcode::kGetTupleElement: |
| return 0; |
| case HloOpcode::kConvert: |
| // Because a convert that follows a Dot or Convolution is treated as not |
| // movable, if op is a dot or convolution here, they must already be |
| // separated by a conditional boundary. We do not try to pull the convert |
| // inside the conditional to be together with the dot or convolution. |
| switch (op) { |
| case HloOpcode::kConvolution: |
| case HloOpcode::kDot: |
| return 0; |
| default: |
| break; |
| } |
| break; |
| default: |
| break; |
| } |
| switch (op) { |
| // These instructions do not carry weight of reuse themselves. |
| case HloOpcode::kParameter: |
| case HloOpcode::kConstant: |
| case HloOpcode::kGetTupleElement: |
| return 0; |
| case HloOpcode::kConditional: |
| return 10; |
| default: |
| return -10; |
| } |
| } |
| |
| // Returns true if `op` is worth hoisting. |
| bool WorthHoisting(HloOpcode op, HloOpcode child_op) { |
| // TODO(b/169182921): The following cost model is rather incomplete. It will |
| // need to be extended to cover most element-wise ops. |
| switch (op) { |
| case HloOpcode::kConvert: |
| // If Convert follows an AllReduce, it is worth moving the AllReduce out |
| // of the conditional for AR/CRS combining. If Convert follows other |
| // ops such as Dot or Convolution, it is better to keep the convert |
| // within the conditional so that it can be fused with the Dot or |
| // Convolution. |
| switch (child_op) { |
| case HloOpcode::kAllReduce: |
| case HloOpcode::kReshape: |
| case HloOpcode::kGetTupleElement: |
| return true; |
| default: |
| return false; |
| } |
| case HloOpcode::kGetTupleElement: |
| switch (child_op) { |
| // Do not move a GTE whose operand is a parameter. |
| case HloOpcode::kParameter: |
| return false; |
| default: |
| return true; |
| } |
| case HloOpcode::kAllReduce: |
| case HloOpcode::kReduceScatter: |
| case HloOpcode::kAbs: |
| case HloOpcode::kReduce: |
| case HloOpcode::kAdd: |
| case HloOpcode::kPower: |
| case HloOpcode::kCopy: |
| case HloOpcode::kConstant: |
| case HloOpcode::kSubtract: |
| case HloOpcode::kMultiply: |
| case HloOpcode::kDivide: |
| case HloOpcode::kTuple: |
| case HloOpcode::kSqrt: |
| case HloOpcode::kRsqrt: |
| case HloOpcode::kReshape: |
| case HloOpcode::kMinimum: |
| case HloOpcode::kMaximum: |
| return true; |
| default: |
| return false; |
| } |
| } |
| |
| // Returns true if the instructions to be visited in each branch are identical. |
| bool InstructionWithinBranchIdentical( |
| const std::vector<HloInstruction*>& instructions, |
| bool is_layout_sensitive) { |
| // Identical here also requires that the shapes of corresponding operands match. |
| auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) { |
| bool eq_operands = is_layout_sensitive |
| ? ShapeUtil::Equal(a->shape(), b->shape()) |
| : ShapeUtil::Compatible(a->shape(), b->shape()); |
| return eq_operands; |
| }; |
| |
| auto eq_computations = [](const HloComputation* a, const HloComputation* b) { |
| return *a == *b; |
| }; |
| |
| if (instructions.empty()) { |
| return false; |
| } |
| |
| if (instructions[0]->IsCrossModuleAllReduce()) { |
| return std::all_of( |
| instructions.begin(), instructions.end(), |
| [&](HloInstruction* instruction) { |
| if (!instruction->IsCrossModuleAllReduce()) { |
| return false; |
| } |
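| // Temporarily give this instruction the same channel id as instructions[0] |
| // so that Identical() does not fail merely because the channel ids differ; |
| // the original id is restored afterwards. |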
| auto old_channel_id = instruction->channel_id(); |
| instruction->set_channel_id(instructions[0]->channel_id()); |
| bool eq_instructions = instructions[0]->Identical( |
| *instruction, eq_operand, eq_computations, is_layout_sensitive); |
| instruction->set_channel_id(old_channel_id); |
| return eq_instructions; |
| }); |
| } |
| |
| return std::all_of(instructions.begin(), instructions.end(), |
| [&](HloInstruction* instruction) { |
| return instructions[0]->Identical( |
| *instruction, eq_operand, eq_computations, |
| is_layout_sensitive); |
| }); |
| } |
| |
| // Copy the boundary out of the conditional and update hoisted_boundaries. |
| void CopyOutOfConditional( |
| Boundary& boundary, HloInstruction* conditional, |
| absl::flat_hash_map<Boundary, Boundary>& hoisted_boundaries) { |
| CHECK(boundary.IsInsideBranch()); |
| absl::InlinedVector<HloInstruction*, 4> new_operands; |
| // All of the branch operands should have the same opcode and shape, so just |
| // use branch 0. |
| const HloInstruction* branch0_inst = boundary.operands()[0]; |
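| // Every operand of the boundary being copied should already have a hoisted |
| // counterpart in hoisted_boundaries, since boundaries are copied out in |
| // operand-before-user order (see MoveInstructionOut). |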
| for (int i = 0; i < branch0_inst->operands().size(); ++i) { |
| Boundary operand_boundary(boundary.GetPosition()); |
| for (HloInstruction* operand : boundary.operands()) { |
| operand_boundary.mutable_operands().push_back(operand->operands()[i]); |
| } |
| VLOG(2) << "Looking for: " << operand_boundary.ToString(); |
| auto hoisted_boundaries_it = hoisted_boundaries.find(operand_boundary); |
| CHECK(hoisted_boundaries_it != hoisted_boundaries.end()); |
| Boundary hoisted_boundary = hoisted_boundaries_it->second; |
| CHECK(hoisted_boundary.IsOutsideBranch()); |
| CHECK_EQ(hoisted_boundary.operands().size(), 1); |
| new_operands.push_back(hoisted_boundary.operands()[0]); |
| } |
| HloInstruction* new_instruction = conditional->parent()->AddInstruction( |
| branch0_inst->CloneWithNewOperands(branch0_inst->shape(), new_operands)); |
| VLOG(2) << "new instruction:" << new_instruction->ToString(); |
| // Map the boundary inside the conditional to the newly created |
| // instruction outside of the conditional. |
| Boundary hoisted_boundary(Boundary::Position::kOutsideBranch); |
| hoisted_boundary.mutable_operands().push_back(new_instruction); |
| hoisted_boundaries[boundary] = hoisted_boundary; |
| } |
| |
| // Copy the boundary into the conditional and update hoisted_boundaries. |
| void CopyIntoConditional( |
| Boundary& boundary, HloInstruction* conditional, |
| absl::flat_hash_map<Boundary, Boundary>& hoisted_boundaries) { |
| CHECK(boundary.IsOutsideBranch()); |
| CHECK_EQ(boundary.operands().size(), 1); |
| int num_branches = conditional->branch_count(); |
| std::vector<absl::InlinedVector<HloInstruction*, 4>> new_operands( |
| num_branches); |
| HloInstruction* op = boundary.operands()[0]; |
| for (HloInstruction* operand : op->operands()) { |
| Boundary operand_boundary(boundary.GetPosition()); |
| operand_boundary.mutable_operands().push_back(operand); |
| VLOG(2) << "Looking for: " << operand_boundary.ToString(); |
| auto hoisted_boundaries_it = hoisted_boundaries.find(operand_boundary); |
| if (hoisted_boundaries_it != hoisted_boundaries.end()) { |
| Boundary hoisted_boundary = hoisted_boundaries_it->second; |
| CHECK(hoisted_boundary.IsInsideBranch()); |
| CHECK_EQ(hoisted_boundary.operands().size(), num_branches); |
| for (int j = 0; j < num_branches; ++j) { |
| new_operands[j].push_back(hoisted_boundary.operands()[j]); |
| } |
| } else { |
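| // This operand has not been hoisted into the branches: it must be either |
| // a constant, which is cloned into each branch, or a get-tuple-element of |
| // the conditional result, which maps to the corresponding operand of each |
| // branch's root tuple. |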
| for (int j = 0; j < num_branches; ++j) { |
| switch (operand->opcode()) { |
| case HloOpcode::kConstant: { |
| auto new_operand = |
| conditional->branch_computation(j)->AddInstruction( |
| operand->Clone()); |
| VLOG(2) << "new instruction:" << new_operand->ToString(); |
| new_operands[j].push_back(new_operand); |
| break; |
| } |
| case HloOpcode::kGetTupleElement: { |
| auto gte = Cast<HloGetTupleElementInstruction>(operand); |
| int64_t index = gte->tuple_index(); |
| HloInstruction* root = |
| conditional->branch_computation(j)->root_instruction(); |
| CHECK(root->opcode() == HloOpcode::kTuple && |
| index < root->operand_count()) |
| << root->ToString() << " " << gte->ToString(); |
| auto new_operand = root->mutable_operand(index); |
| VLOG(2) << "new instruction:" << new_operand->ToString(); |
| new_operands[j].push_back(new_operand); |
| break; |
| } |
| default: |
| LOG(FATAL) << "Unexpected out-of-boundary instruction:" |
| << operand->ToString() << "\n"; |
| } |
| } |
| } |
| } |
| |
| Boundary hoisted_boundary(Boundary::Position::kInsideBranch); |
| for (int j = 0; j < num_branches; ++j) { |
| HloInstruction* new_instruction = |
| conditional->branch_computation(j)->AddInstruction( |
| op->CloneWithNewOperands(op->shape(), new_operands[j])); |
| VLOG(2) << "new instruction:" << new_instruction->ToString(); |
| hoisted_boundary.mutable_operands().push_back(new_instruction); |
| } |
| hoisted_boundaries[boundary] = hoisted_boundary; |
| } |
| |
| // Identify converts to be hoisted/rematerialized out of the branch |
| // computations. |
| absl::flat_hash_set<int64_t> FindSpecialConverts(HloInstruction* old_root, |
| int branch_count, |
| HloInstruction* conditional, |
| bool is_layout_sensitive) { |
| absl::flat_hash_set<int64_t> special_convert; |
| |
| // TODO(b/216487727): Allow hoisting converts that feed or are fed by other |
| // converts by addressing possible duplicates left behind in the tuple output. |
| // The conditional code motion pass should handle these duplicates, so |
| // merging these snippets of code would be one alternative. |
| auto convert_invalid = |
| [](const HloInstruction* convert_set_candidate) -> bool { |
| bool invalid_user = absl::c_any_of( |
| convert_set_candidate->users(), [](const HloInstruction* user) -> bool { |
| return (user->opcode() == HloOpcode::kConvert); |
| }); |
| bool invalid_producer = |
| absl::c_any_of(convert_set_candidate->operands(), |
| [](const HloInstruction* operand) -> bool { |
| return (operand->opcode() == HloOpcode::kConvert); |
| }); |
| return (invalid_user || invalid_producer); |
| }; |
| |
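| // A root operand index qualifies as a special convert when every branch |
| // root has a convert at that index, none of those converts feeds or is fed |
| // by another convert, the convert is not repeated within its root tuple, |
| // and the shapes match across branches. |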
| for (int64_t operand_num = 0; operand_num < old_root->operand_count(); |
| ++operand_num) { |
| if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) { |
| continue; |
| } |
| bool replica = true; |
| HloInstruction* special_convert_candidate = |
| old_root->mutable_operand(operand_num); |
| // TODO(b/216487727): Remove duplicates in tuple outputs while hoisting. |
| auto repeated = |
| absl::c_count_if(old_root->operands(), |
| [&](const HloInstruction* operand) -> bool { |
| return (special_convert_candidate == operand); |
| }) > 1; |
| if (convert_invalid(special_convert_candidate) || repeated) { |
| continue; |
| } |
| // Check whether an identical candidate appears in other branches |
| for (int others = 1; others < branch_count; ++others) { |
| HloInstruction* others_root = |
| conditional->branch_computation(others)->root_instruction(); |
| const HloInstruction* other_convert = others_root->operand(operand_num); |
| if (other_convert->opcode() != HloOpcode::kConvert || |
| convert_invalid(other_convert)) { |
| replica = false; |
| break; |
| } |
| // Do not move converts if their operands have different shapes in |
| // different branches. |
| bool eq_shape = |
| is_layout_sensitive |
| ? ShapeUtil::Equal(other_convert->shape(), |
| special_convert_candidate->shape()) && |
| ShapeUtil::Equal( |
| other_convert->operand(0)->shape(), |
| special_convert_candidate->operand(0)->shape()) |
| : ShapeUtil::Compatible(other_convert->shape(), |
| special_convert_candidate->shape()) && |
| ShapeUtil::Compatible( |
| other_convert->operand(0)->shape(), |
| special_convert_candidate->operand(0)->shape()); |
| if (!eq_shape) { |
| replica = false; |
| break; |
| } |
| auto repeated = |
| absl::c_count_if(others_root->operands(), |
| [&](const HloInstruction* operand) -> bool { |
| return (special_convert_candidate == operand); |
| }) > 1; |
| if (repeated) { |
| replica = false; |
| break; |
| } |
| } |
| if (replica) { |
| special_convert.insert(operand_num); |
| } |
| } |
| return special_convert; |
| } |
| |
| // Restructures the conditional instruction as follows: |
| // %result = conditional() becomes |
| // x = conditional() |
| // y.{0..n} = gte(x, {0..n}) |
| // z = tuple(y.0, y.1, ...y.n) |
| // Doing so ensures that we can accommodate the possible shape change of the |
| // conditional when the instructions are hoisted. |
| Status RestructureConditionalInstruction(HloComputation* computation, |
| HloInstruction* conditional) { |
| HloInstruction* old_root = computation->root_instruction(); |
| std::vector<HloInstruction*> new_operands; |
| int cur_index = 0; |
| for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape()); |
| ++cur_index) { |
| new_operands.push_back( |
| computation->AddInstruction(HloInstruction::CreateGetTupleElement( |
| ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index), |
| conditional, cur_index))); |
| } |
| HloInstruction* new_tuple = |
| computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); |
| if (old_root == conditional) { |
| computation->set_root_instruction(new_tuple); |
| } else { |
| std::vector<HloInstruction*> new_tuple_users; |
| for (auto conditional_user : conditional->users()) { |
| auto is_new_gte = absl::c_find_if( |
| new_operands, |
| [&](HloInstruction* instr) { return instr == conditional_user; }); |
| if (is_new_gte == new_operands.end()) { |
| new_tuple_users.push_back(conditional_user); |
| } |
| } |
| for (auto new_tuple_user : new_tuple_users) { |
| TF_RETURN_IF_ERROR( |
| conditional->ReplaceUseWith(new_tuple_user, new_tuple)); |
| } |
| } |
| VLOG(2) << "computation after root restructure:\n" << computation->ToString(); |
| return Status::OK(); |
| } |
| |
| StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional, |
| bool is_layout_sensitive) { |
| int branch_count = conditional->branch_count(); |
| if (branch_count <= 0) { |
| return false; |
| } |
| |
| // Determine whether all branch roots are tuples. |
| for (int branch_num = 0; branch_num < branch_count; ++branch_num) { |
| HloInstruction* branch_root = |
| conditional->branch_computation(branch_num)->root_instruction(); |
| if (branch_root->opcode() != HloOpcode::kTuple) { |
| return false; |
| } |
| } |
| |
| HloInstruction* old_root = |
| conditional->branch_computation(0)->root_instruction(); |
| VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); |
| // Identify the gte user of the conditional result that extracts `index`. |
| auto find_gte = [](const HloInstruction* conditional_result, |
| int64_t index) -> HloInstruction* { |
| for (HloInstruction* instr : conditional_result->users()) { |
| if (instr->opcode() != HloOpcode::kGetTupleElement) { |
| return nullptr; |
| } |
| if (instr->tuple_index() == index) { |
| return instr; |
| } |
| } |
| return nullptr; |
| }; |
| |
| // Captures tuple indices referring to converts to be rematerialized/hoisted. |
| absl::flat_hash_set<int64_t> special_convert = FindSpecialConverts( |
| old_root, branch_count, conditional, is_layout_sensitive); |
| |
| // Exit if we cannot find any converts to be hoisted. |
| if (special_convert.empty()) { |
| return false; |
| } |
| |
| TF_RETURN_IF_ERROR( |
| RestructureConditionalInstruction(conditional->parent(), conditional)); |
| |
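| // Rewrite each branch root so that the operands of the special converts, |
| // rather than the convert results, are returned through the root tuple. |
| // While processing the first branch, also rebuild the conditional with the |
| // new shape and re-apply the converts in the parent computation, replacing |
| // the get-tuple-elements that used to consume the convert results. |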
| for (int branch = 0; branch < branch_count; branch++) { |
| old_root = conditional->branch_computation(branch)->root_instruction(); |
| absl::flat_hash_map<HloInstruction*, int64_t> map_inst_to_tuple_index; |
| std::vector<HloInstruction*> new_operands(old_root->operand_count()); |
| absl::flat_hash_set<HloInstruction*> to_hoist_set; |
| |
| for (int64_t operand_num = 0; operand_num < old_root->operand_count(); |
| ++operand_num) { |
| map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = |
| operand_num; |
| } |
| for (int64_t operand_num = 0; operand_num < old_root->operand_count(); |
| ++operand_num) { |
| HloInstruction* hoist = old_root->mutable_operand(operand_num); |
| if (!special_convert.contains(operand_num)) { |
| new_operands[operand_num] = old_root->mutable_operand(operand_num); |
| continue; |
| } |
| |
| to_hoist_set.insert(hoist); |
| int64_t new_tuple_count = old_root->operand_count(); |
| |
| // Replace the hoisted instr in the tuple with its operand/operands. |
| // At least one operand of the hoisted instr takes over the instr's place |
| // in the tuple; the rest are appended at the end. |
| bool inplace = true; |
| CHECK(!hoist->operands().empty()); |
| for (HloInstruction* prod : hoist->operands()) { |
| if (inplace) { |
| map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; |
| new_operands[map_inst_to_tuple_index[hoist]] = prod; |
| inplace = false; |
| } else { |
| map_inst_to_tuple_index[prod] = new_tuple_count++; |
| new_operands.push_back(prod); |
| } |
| } |
| } |
| |
| // Create the new root instruction. |
| HloComputation* cur_branch = conditional->branch_computation(branch); |
| HloInstruction* new_branch_root = |
| cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); |
| // The shape can vary since the operands to convert are now |
| // being returned through the branches' root. |
| cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); |
| TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); |
| |
| // Only the first branch iteration needs to update conditional->parent(). |
| if (branch != 0) { |
| continue; |
| } |
| HloComputation* conditional_parent = conditional->parent(); |
| HloInstruction* newconditional = |
| conditional_parent->AddInstruction(HloInstruction::CreateConditional( |
| cur_branch->root_instruction()->shape(), |
| conditional->mutable_operand(0), |
| absl::MakeSpan(conditional->branch_computations()), |
| absl::MakeSpan(conditional->operands()).subspan(1))); |
| // Ensure that all the users of conditional refer to the new one. |
| TF_RETURN_IF_ERROR( |
| conditional->ReplaceAllUsesWithDifferentShape(newconditional)); |
| TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); |
| conditional = newconditional; |
| // Add the hoisted instructions in the parent. |
| for (HloInstruction* hoist : to_hoist_set) { |
| VLOG(2) << "Hoisting instruction:" << hoist->ToString(); |
| int64_t hoist_index = map_inst_to_tuple_index[hoist]; |
| // Find out the gte that captured the hoisted instr result. |
| HloInstruction* gte_hoist = find_gte(conditional, hoist_index); |
| CHECK(gte_hoist != nullptr); |
| std::vector<HloInstruction*> new_operands; |
| for (HloInstruction* op : hoist->operands()) { |
| HloInstruction* gte = conditional_parent->AddInstruction( |
| HloInstruction::CreateGetTupleElement(op->shape(), conditional, |
| map_inst_to_tuple_index[op])); |
| new_operands.push_back(gte); |
| } |
| HloInstruction* hoisted = conditional_parent->AddInstruction( |
| hoist->CloneWithNewOperands(hoist->shape(), new_operands)); |
| VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); |
| TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); |
| TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); |
| } |
| // No need to explicitly delete a hoisted instruction: if it is dead, |
| // the subsequent DCE pass will remove it. |
| } |
| VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString(); |
| return true; |
| } |
| |
| // Hoist identical ops out of the conditional. Ops are considered identical |
| // if the shapes of their operands and their other properties are identical. |
| // Start from the root instruction of each branch and collect the identical |
| // ops to hoist. |
| StatusOr<bool> ConditionalCodeMotion::MoveInstructionOut( |
| HloInstruction* conditional, std::vector<Boundary>& to_move_out, |
| std::vector<Boundary>& new_boundaries) { |
| if (to_move_out.empty()) { |
| return false; |
| } |
| VLOG(1) << "Modifying code--number of boundaries to move out:" |
| << to_move_out.size() << "\n"; |
| HloComputation* conditional_parent = conditional->parent(); |
| // Save the old users before adding new conditional user instructions. |
| std::vector<HloInstruction*> old_conditional_users = conditional->users(); |
| // Maps boundaries in the conditional body to boundaries hoisted outside |
| // the conditional that compute the same value. |
| absl::flat_hash_map<Boundary, Boundary> hoisted_boundaries; |
| // Insert GetTupleElement before the instructions whose operands might still |
| // be within the conditional. |
| VLOG(1) << "before opt:" |
| << conditional_parent->ToString(HloPrintOptions::Fingerprint()) |
| << "\n"; |
| int64_t op_index = 0; |
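| // Each new boundary becomes one element of the rebuilt branch roots (see |
| // below); create a get-tuple-element outside the conditional at the |
| // matching tuple index to stand in for it. |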
| for (const Boundary& b : new_boundaries) { |
| HloInstruction* op = b.operands()[0]; |
| CHECK(op != nullptr); |
| VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; |
| HloInstruction* gtr = conditional_parent->AddInstruction( |
| HloInstruction::CreateGetTupleElement(op->shape(), conditional, |
| op_index++)); |
| Boundary b2(Boundary::Position::kOutsideBranch); |
| b2.mutable_operands().push_back(gtr); |
| hoisted_boundaries[b] = b2; |
| } |
| // Copy boundary instructions out of the conditional. |
| // Visit operands before their users and copy them, so that each copied |
| // user points to the correct operand. |
| for (int64_t i = to_move_out.size() - 1; i >= 0; i--) { |
| CopyOutOfConditional(to_move_out[i], conditional, hoisted_boundaries); |
| } |
| VLOG(2) << "Done copy branch instructions out\n" |
| << conditional_parent->ToString(HloPrintOptions::Fingerprint()) |
| << "\n"; |
| // Change original users of the conditional to use the correct operands. |
| for (auto user_instr : old_conditional_users) { |
| VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n"; |
| CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement); |
| auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(user_instr); |
| int64_t index = tuple_opd->tuple_index(); |
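| // Reconstruct the boundary formed by the index-th root operand of every |
| // branch; its hoisted counterpart provides the replacement for this |
| // get-tuple-element user. |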
| Boundary old_user_boundary(Boundary::Position::kInsideBranch); |
| for (const HloComputation* called_computation : |
| conditional->called_computations()) { |
| HloInstruction* root = called_computation->root_instruction(); |
| CHECK(root->operands().size() > index); |
| old_user_boundary.mutable_operands().push_back(root->operands()[index]); |
| } |
| CHECK(ContainsKey(hoisted_boundaries, old_user_boundary)); |
| HloInstruction* new_opd = |
| hoisted_boundaries[old_user_boundary].operands()[0]; |
| CHECK(new_opd != nullptr); |
| VLOG(2) << "Try replace all uses of :" << old_user_boundary.ToString() |
| << "\n"; |
| TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd)); |
| TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr)); |
| } |
| VLOG(2) << "Done changing conditional users\n" |
| << conditional_parent->ToString() << "\n"; |
| // Create a tuple within each branch and set it as the branch root. |
| int64_t branch_count = conditional->branch_count(); |
| for (int i = 0; i < branch_count; i++) { |
| auto computation = conditional->branch_computation(i); |
| std::vector<HloInstruction*> elements; |
| for (const auto& b1 : new_boundaries) { |
| HloInstruction* op = b1.operands()[i]; |
| CHECK(op != nullptr); |
| VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n"; |
| elements.push_back(op); |
| } |
| HloInstruction* tuple = |
| computation->AddInstruction(HloInstruction::CreateTuple(elements)); |
| computation->set_root_instruction(tuple, true); |
| VLOG(2) << "computation is :" << computation->ToString() << "\n"; |
| // Remove hoisted instructions from the branches. |
| for (const auto& b2 : to_move_out) { |
| auto instr_to_remove = b2.operands()[i]; |
| // Double check to make sure it is safe to delete the instruction. |
| // Complications may arise due to some operations in the alternative |
| // branches (branches 1..n) being placed into the boundaries multiple |
| // times. |
| if (!computation->IsMarkedAsDead(instr_to_remove) && |
| instr_to_remove->user_count() == 0) { |
| VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; |
| TF_RETURN_IF_ERROR(computation->RemoveInstruction(instr_to_remove)); |
| } |
| } |
| } |
| // Change conditional instruction shape to the shape of the new root. |
| HloInstruction* new_root = |
| conditional->branch_computation(0)->root_instruction(); |
| *conditional->mutable_shape() = new_root->shape(); |
| // Keep conditional instruction sharding consistent with the branches. Note |
| // that this sharding could be lost after this pass. |
| conditional->set_sharding(new_root->sharding_ptr()); |
| VLOG(1) << "done moving instructions out of branches\n" |
| << conditional_parent->ToString(HloPrintOptions::Fingerprint()) |
| << "\n"; |
| return true; |
| } |
| |
| // Hoist ops from outside of the conditional to inside the branches. |
| StatusOr<bool> ConditionalCodeMotion::MoveInstructionIn( |
| HloInstruction* conditional, std::vector<Boundary>& to_move_in, |
| std::vector<Boundary>& new_boundaries) { |
| if (to_move_in.empty()) { |
| return false; |
| } |
| VLOG(1) << "Modifying code---number of boundaries to move in:" |
| << to_move_in.size() << "\n"; |
| VLOG(1) << "before opt:" |
| << conditional->parent()->ToString(HloPrintOptions::Fingerprint()) |
| << "\n"; |
| // Mapping boundaries to be moved to their new representations. |
| absl::flat_hash_map<Boundary, Boundary> hoisted_boundaries; |
| int64_t to_move_in_size = to_move_in.size(); |
| int64_t branch_count = conditional->branch_count(); |
| HloGetTupleElementInstruction* tuple_use = |
| DynCast<HloGetTupleElementInstruction>(to_move_in[0].operands()[0]); |
| // If use_index is -1, the old conditional root entry used by to_move_in |
| // instructions still needs to be included as an entry of the modified |
| // conditional root, and the new result of the to_move_in instructions |
| // needs to be added as an extra entry of the modified root; otherwise, the |
| // old root entry will be replaced with the new result in the modified root. |
| // The entry replacement is allowed only if tuple_use has at most one user. |
| int64_t use_index = (tuple_use != nullptr && tuple_use->user_count() == 1) |
| ? tuple_use->tuple_index() |
| : -1; |
| VLOG(2) << "Tuple use index = " << use_index << "\n"; |
| // Number of old conditional entries still to be used outside. |
| // If the conditional shape is not a tuple, a tuple will be created and |
| // subscript 0 used to save the old operand being used. |
| int64_t op_index = |
| conditional->shape().IsTuple() |
| ? ((use_index >= 0) ? conditional->shape().tuple_shapes_size() - 1 |
| : conditional->shape().tuple_shapes_size()) |
| : 0; |
| // Used to map the tuple_use instruction to its operand. |
| Boundary b_opd_use(Boundary::Position::kInsideBranch); |
| Boundary b_old_root(Boundary::Position::kInsideBranch); |
| // Create a new root instruction in each branch. |
| for (int i = 0; i < branch_count; i++) { |
| auto computation = conditional->branch_computation(i); |
| auto old_root = computation->root_instruction(); |
| b_old_root.mutable_operands().push_back(old_root); |
| std::vector<HloInstruction*> operands; |
| if (old_root->opcode() == HloOpcode::kTuple) { |
| // Use operands of old_root directly, so old_root can be removed later. |
| for (int i = 0; i < old_root->operand_count(); ++i) { |
| if (i != use_index) { |
| operands.push_back(old_root->operands()[i]); |
| } else { // Map conditional use to the tuple operand. |
| b_opd_use.mutable_operands().push_back(old_root->operands()[i]); |
| } |
| } |
| } else if (old_root->shape().IsTuple()) { |
| // If old_root is not a kTuple but has tuple shape, elements within the |
| // tuple must be extracted first to be used by the new instructions. |
| const Shape& old_shape = old_root->shape(); |
| for (int i = 0; i < old_shape.tuple_shapes_size(); ++i) { |
| auto element = |
| computation->AddInstruction(HloInstruction::CreateGetTupleElement( |
| old_shape.tuple_shapes(i), old_root, i)); |
| if (i != use_index) { |
| operands.push_back(element); |
| } else { |
| b_opd_use.mutable_operands().push_back(element); |
| } |
| } |
| } else { |
| // If old_root is not a tuple and does not have tuple shape, use it |
| // to replace the conditional directly in the new computation. |
| b_opd_use.mutable_operands().push_back(conditional); |
| } |
| |
| HloInstruction* new_root = |
| computation->AddInstruction(HloInstruction::CreateTuple(operands)); |
| VLOG(2) << "setting new root: " << new_root->ToString() << "\n"; |
| computation->set_root_instruction(new_root, |
| /*accept_different_shape*/ true); |
| if (old_root->opcode() == HloOpcode::kTuple) { |
| TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root)); |
| } |
| VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; |
| } |
| // Update the tuple indices of the conditional's get-tuple-element users, |
| // since the entry at use_index has been removed from the root. |
| if (use_index != -1) { |
| for (auto* user : conditional->users()) { |
| if (user->opcode() == HloOpcode::kGetTupleElement && |
| user->tuple_index() > use_index) { |
| user->set_tuple_index(user->tuple_index() - 1); |
| } |
| } |
| } |
| Boundary conditional_boundary(Boundary::Position::kOutsideBranch); |
| conditional_boundary.mutable_operands().push_back(conditional); |
| hoisted_boundaries[conditional_boundary] = b_old_root; |
| int64_t cp_start = 0; |
| if (use_index >= 0) { |
| VLOG(2) << "Mapping GTE: " << tuple_use->ToString() << "\n"; |
| Boundary tuple_use_boundary(Boundary::Position::kOutsideBranch); |
| tuple_use_boundary.mutable_operands().push_back(tuple_use); |
| hoisted_boundaries[tuple_use_boundary] = b_opd_use; |
| } |
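| // The leading get-tuple-element boundary (if any) is not cloned into the |
| // branches; copying starts after it. |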
| cp_start = (tuple_use != nullptr) ? 1 : 0; |
| for (int64_t to_move_index = cp_start; to_move_index < to_move_in_size; |
| to_move_index++) { |
| Boundary b_to_move = to_move_in[to_move_index]; |
| HloInstruction* op = b_to_move.operands()[0]; |
| CHECK(op != nullptr); |
| bool to_be_used_outside = true; |
| VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; |
| if (to_move_index < to_move_in_size - 1 && op->user_count() == 1 && |
| op->users()[0] == to_move_in[to_move_index + 1].operands()[0]) { |
| to_be_used_outside = false; |
| VLOG(2) << "Instruction is not to be used outside the branch\n"; |
| } |
| Boundary b(Boundary::Position::kInsideBranch); |
| CopyIntoConditional(b_to_move, conditional, hoisted_boundaries); |
| if (to_be_used_outside) { |
| for (int i = 0; i < branch_count; ++i) { |
| auto computation = conditional->branch_computation(i); |
| auto new_op = hoisted_boundaries[b_to_move].operands()[i]; |
| auto new_root = computation->root_instruction(); |
| new_root->AppendOperand(new_op); |
| *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape(); |
| VLOG(2) << "Extending conditional root " << i << " : " |
| << new_root->ToString() << "\n"; |
| } |
| // Redirect uses of the moved instruction outside of the conditional. |
| HloInstruction* gtr = conditional->parent()->AddInstruction( |
| HloInstruction::CreateGetTupleElement(op->shape(), conditional, |
| op_index++)); |
| TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr)); |
| if (conditional->parent()->root_instruction() == op) { |
| conditional->parent()->set_root_instruction(gtr); |
| } |
| } |
| } |
| VLOG(2) << "Done copying instructions inside branch: " |
| << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n"; |
| // Change conditional instruction shape to the shape of the new root. |
| HloInstruction* new_root = |
| conditional->branch_computation(0)->root_instruction(); |
| *conditional->mutable_shape() = new_root->shape(); |
| // Keep conditional instruction sharding consistent with the branches. Note |
| // that this sharding could be lost after this pass. |
| conditional->set_sharding(new_root->sharding_ptr()); |
| VLOG(2) << "Before removing instructions:" |
| << conditional->parent()->ToString() << "\n"; |
| // Remove hoisted instructions from the branches. |
| for (int64_t i = to_move_in_size - 1; i >= 0; i--) { |
| Boundary boundary_to_move_in = to_move_in[i]; |
| HloInstruction* op = boundary_to_move_in.operands()[0]; |
| if (op->user_count() == 0) { |
| VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n"; |
| TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op)); |
| VLOG(2) << "Done removing boundary.\n"; |
| } |
| } |
| |
| // Reset shapes of user gtes to the new shape. |
| if (use_index != -1) { |
| for (auto* user : conditional->users()) { |
| if (user->opcode() == HloOpcode::kGetTupleElement) { |
| VLOG(2) << "Resetting shape of user: " << user->ToString() << "\n"; |
| *user->mutable_shape() = |
| conditional->shape().tuple_shapes(user->tuple_index()); |
| } |
| } |
| } |
| VLOG(1) << "Done moving instructions inside branches\n" |
| << conditional->parent()->ToString(HloPrintOptions::Fingerprint()) |
| << "\n"; |
| return true; |
| } |
| |
| // Group single chains of operands or uses of boundaries into new boundaries |
| class GroupConnectedBoundaries { |
| private: |
| std::vector<Boundary> connected_boundaries_, new_boundaries_; |
| HloInstruction* conditional_; |
| HloComputation* conditional_parent_; |
| bool is_layout_sensitive_; |
| // Instructions that have been visited but are not going to be moved. |
| absl::flat_hash_map<HloInstruction*, int>& visited_count_; |
| // The following four lines are configurations of the cost model, which will |
| // be used to determine whether to move an instruction (move_config_) and how |
| // strongly preferred it is to keep a pair of ops together (reuse_config_). |
| // The search_config_ is used to control how to navigate the search space of |
| // the cost model in the context of auto/manual tuning. The flipped array is |
| // used to save which entries in the configuration have been changed in the |
| // search/tuning process. |
| std::vector<std::vector<int64_t>>& move_config_; |
| std::vector<std::vector<int64_t>>& reuse_config_; |
| std::vector<int64_t>& search_config_vec_; |
| int64_t* search_config_; |
| int64_t search_subscript_; |
| absl::flat_hash_map<const int64_t*, int64_t> flipped_; |
| |
| // The FlipMutation function implements the search of alternative |
| // cost models by deciding whether to flip a given configuration, saved in |
| // the loc parameter. The non_zero parameter provides the new value to use |
| // to flip a zero. The msg parameter is only used for debugging purposes. |
| int64_t FlipMutation(int64_t* loc, const int64_t non_zero, |
| const std::string& msg) { |
| if (search_config_ == 0 || ContainsKey(flipped_, loc)) { |
| VLOG(2) << "Configured not to search or loc is already flipped."; |
| return *loc; |
| } |
| // The last 8 digits control when to start the first flip. |
| int c = ConditionalCodeMotion::flip_start(*search_config_); |
| VLOG(2) << "flip start index = " << c << "\n"; |
| // Only flip the decision if c reaches 0. |
| if (c > 0) { |
| (*search_config_)--; |
| return *loc; |
| } |
| // The 8-16 digits control the maximum number of times to flip a config. |
| auto flip_count = ConditionalCodeMotion::DecrementMaxFlip(search_config_); |
| VLOG(2) << "max flip count = " << flip_count << "\n"; |
| VLOG(2) << "Updating max Flipping configuration = " << *search_config_ |
| << "\n"; |
| if (flip_count == 0) { |
| VLOG(2) << "Maximum flip count has reached. "; |
| if (search_subscript_ + 1 < search_config_vec_.size()) { |
| VLOG(2) << "search_subscript_ = " << search_subscript_; |
| VLOG(2) << "search config vec size = " << search_config_vec_.size(); |
| search_config_ = &search_config_vec_[++search_subscript_]; |
| } else { |
| return *loc; |
| } |
| } |
| // Reload the 16-23 digits of the configuration, which controls how |
| // frequently a configuration should be flipped. |
| auto flip_stride = ConditionalCodeMotion::flip_stride(*search_config_); |
| *search_config_ += flip_stride; |
| VLOG(2) << "flip stride = " << flip_stride << "\n"; |
| VLOG(2) << "Updating Flipping Stride = " << *search_config_ << "\n"; |
| |
| flipped_[loc] = *loc; |
| // Flip the decision stored at loc between 0 and non_zero. |
| switch (*loc) { |
| case 0: |
| *loc = non_zero; |
| break; |
| default: |
| *loc = 0; |
| break; |
| } |
| VLOG(2) << "Flipping decision for: " << msg << ": from " << flipped_[loc] |
| << " to " << *loc << "\n"; |
| return *loc; |
| } |
| |
| public: |
| explicit GroupConnectedBoundaries( |
| HloInstruction* conditional, bool is_layout_sensitive, |
| absl::flat_hash_map<HloInstruction*, int>& visited_count, |
| std::vector<std::vector<int64_t>>* move_config, |
| std::vector<std::vector<int64_t>>* reuse_config, |
| std::vector<int64_t>* search_config) |
| : conditional_(conditional), |
| conditional_parent_(conditional->parent()), |
| is_layout_sensitive_(is_layout_sensitive), |
| visited_count_(visited_count), |
| move_config_(*move_config), |
| reuse_config_(*reuse_config), |
| search_config_vec_(*search_config), |
| search_subscript_(0) { |
| VLOG(2) << "Initializing Group Connected Boundaries\n"; |
| CHECK_NE(search_config, nullptr); |
| if (search_config_vec_.empty()) { |
| search_config_vec_.push_back(0); |
| } |
| search_config_ = &search_config_vec_[0]; |
| } |
| // Returns an estimate of the potential reuses carried by a given pair of |
| // instructions. Different integers classify different levels of reuse. |
| // Assume all instructions can be fused to enable data reuse. |
| int64_t ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { |
| std::vector<int64_t>& curconfig = |
| reuse_config_[static_cast<uint32_t>(op->opcode())]; |
| // Flip the reuse configuration if tuning the cost model. |
| // When flipping, use -10 if flipping to the default reuse model. Other |
| // values can be specified if needed to fine-control the decision making. |
| int64_t config = |
| ((*search_config_) < 0) |
| ? FlipMutation(&curconfig[static_cast<uint32_t>(user->opcode())], |
| -10, |
| HloOpcodeString(op->opcode()) + "->" + |
| HloOpcodeString(user->opcode())) |
| : curconfig[static_cast<uint32_t>(user->opcode())]; |
| VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: " |
| << op->ToString() << "=>" << user->ToString() << " : " << config |
| << "\n"; |
| if (config < 0) { |
| // Assume the reuse decreases with increasing user count. |
| int count1 = CountNonLeafOps(op->users()); |
| int count2 = CountNonLeafOps(user->operands()); |
| return (-config) / count1 / count2; |
| } |
| return config; |
| } |
| void clear_recently_visited() { |
| for (const auto& boundary : new_boundaries_) { |
| visited_count_.erase(boundary.operands()[0]); |
| } |
| } |
| // Returns true if `instruction` is worth hoisting. |
| bool WorthHoisting(HloInstruction* instruction, bool is_inside_branch) { |
| // This is needed for the "moving-in" transformation, to prevent the root |
| // of the parent computation (which contains the conditional) from being |
| // moved inside the conditional. |
| HloOpcode opcode = instruction->opcode(); |
| if (opcode == HloOpcode::kTuple && |
| instruction == conditional_parent_->root_instruction()) { |
| return false; |
| } |
| // It is not safe to move collective ops from outside to inside |
| // conditional branches, as it may cause synchronization problems |
| // when different layouts are assigned to different branches. |
| if (DynCast<HloCollectiveInstruction>(instruction) && !is_inside_branch) { |
| return false; |
| } |
| |
| // It is not legal to move the parameter instructions. |
| if (opcode == HloOpcode::kParameter) { |
| return false; |
| } |
| |
| // Use configuration given from outside (e.g., by autotuner). |
| std::vector<int64_t>& curconfig = |
| move_config_[static_cast<uint32_t>(opcode)]; |
| auto col = (curconfig.size() == 1) ? 0 |
| : (instruction->operand_count() > 0) |
| ? static_cast<uint32_t>(instruction->operand(0)->opcode()) |
| : 0; |
| VLOG(2) << "column = " << col << "\n"; |
| VLOG(2) << "config size = " << curconfig.size() << "\n"; |
| VLOG(2) << "search_config = " << *search_config_ << "\n"; |
| CHECK(col < curconfig.size()); |
| uint32_t config = ((*search_config_) > 0) |
| ? FlipMutation(&curconfig[col], 1, |
| "Move-" + HloOpcodeString(opcode)) |
| : curconfig[col]; |
| VLOG(2) << "Checking instruction is worth moving: " << config << "\n"; |
| VLOG(2) << "after checking search_config = " << *search_config_ << "\n"; |
| return (config != 0); |
| } |
| |
| int64_t ReusesBeforeBoundary(HloInstruction* user) { |
| int64_t reuses = 0; |
| for (auto op : user->operands()) { |
| // The operand must be an instruction that is not going to be moved (if |
| // user is inside the conditional); otherwise it must be the conditional |
| // itself and its user must be outside of the conditional. |
| if (!ContainsKey(visited_count_, op) && op != conditional_) { |
| continue; |
| } |
| if (auto tuple_gte = DynCast<HloGetTupleElementInstruction>(user)) { |
| if (op->opcode() == HloOpcode::kConditional) { |
| auto tuple = op->branch_computation(0)->root_instruction(); |
| if (tuple->opcode() == HloOpcode::kTuple) { |
| auto index = tuple_gte->tuple_index(); |
| CHECK(index < tuple->operand_count()); |
| op = tuple->mutable_operand(index); |
| } |
| } |
| reuses += ReusesCarriedBy(op, user->users()[0]); |
| } else { |
| reuses += ReusesCarriedBy(op, user); |
| } |
| } |
| VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses |
| << "\n"; |
| return reuses; |
| } |
| |
| int64_t ReusesAfterBoundary(HloInstruction* user) { |
| CHECK(user != nullptr); |
| auto all_users = user->users(); |
| // For now, assume that if an instruction has multiple consumers, it |
| // will not be reused, as the reuse may require duplication in |
| // fusion and so is expensive. If the situation changes in the future, |
| // some aspects of the overall algorithm will need to be redesigned to |
| // accommodate the change. |
| if (all_users.size() > 1) { |
| VLOG(2) << "Having multiple users from: " << user->ToString() << "\n"; |
| return 0; |
| } |
| if (!all_users.empty()) { |
| auto op = all_users[0]; |
| int64_t reuses = 0; |
| // Only count reuses that run through the conditional root. |
| if (op == conditional_->branch_computation(0)->root_instruction()) { |
| int64_t index = op->operand_index(user); |
| for (auto op2 : conditional_->users()) { |
| // If the use is not a get-tuple-element, do not consider it for now. |
| if (op2->opcode() == HloOpcode::kGetTupleElement) { |
| auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2); |
| if (index == tuple_opd->tuple_index()) { |
| all_users = op2->users(); |
| if (!all_users.empty()) { |
| reuses += ReusesCarriedBy(user, all_users[0]); |
| break; |
| } |
| } |
| } |
| } |
| } else if (ContainsKey(visited_count_, op)) { |
| reuses += ReusesCarriedBy(user, op); |
| } |
| VLOG(2) << "reuses after instruction " << user->ToString() << ":" |
| << reuses << "\n"; |
| return reuses; |
| } |
| return 0; |
| } |
| |
| int64_t BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries, |
| bool perform_reuse_analysis = true) { |
| int64_t reuses_before = 0, reuses_after = 0; |
| if (boundaries.size() == 1) { |
| if (boundaries[0].IsOutsideBranch() && |
| boundaries[0].operands()[0]->opcode() == |
| HloOpcode::kGetTupleElement) { |
| // The only boundary of moving-in is the get_tuple_element op. |
| return -1; |
| } |
| if (boundaries[0].IsInsideBranch() && |
| boundaries[0].operands()[0]->opcode() == HloOpcode::kTuple) { |
| // The only boundary of moving-out is the tuple op inside branches. |
| return -1; |
| } |
| } |
| // If trying alternative moving configurations, turn off reuse analysis. |
| if (!perform_reuse_analysis) { |
| return 1; |
| } |
| // For cases like : |
| // branch0 { |
| // ROOT copy |
| // } |
| // branch1 { |
| // ... |
| // } |
| // cond = conditional(branch0, branch1) |
| // copy = copy(cond) |
| // |
| // We can fold the two copies thus reducing computation. |
| auto get_copy_folding_benefit = [&](HloInstruction* hlo) -> int64_t { |
| if (hlo->opcode() != HloOpcode::kCopy) { |
| return 0; |
| } |
| const HloGetTupleElementInstruction* gte = |
| DynCast<HloGetTupleElementInstruction>(hlo->operand(0)); |
| if (gte == nullptr) { |
| return 0; |
| } |
| const HloInstruction* conditional = gte->operand(0); |
| if (conditional != conditional_) { |
| return 0; |
| } |
| int64_t benefit = 0; |
| for (auto* branch : conditional->called_computations()) { |
| HloInstruction* root = branch->root_instruction(); |
| if (root->opcode() == HloOpcode::kTuple) { |
| const auto* tuple_operand = root->operand(gte->tuple_index()); |
| if (tuple_operand->opcode() == HloOpcode::kCopy) { |
| if (Shape::Equal()(tuple_operand->operand(0)->shape(), |
| hlo->shape())) { |
| benefit += 10; |
| } |
| } |
| } |
| } |
| return benefit; |
| }; |
| for (const Boundary& b : boundaries) { |
| auto op = b.operands()[0]; |
| if (op == conditional_->branch_computation(0)->root_instruction()) { |
| continue; |
| } |
| VLOG(2) << "Benefit for " << op->ToString(); |
| reuses_before += ReusesBeforeBoundary(op); |
| VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n"; |
| reuses_after += ReusesAfterBoundary(op); |
| VLOG(2) << "Reuese after boundary so far : " << reuses_after << "\n"; |
| } |
| |
| int64_t copy_folding_benefit = 0; |
| if (boundaries[0].IsOutsideBranch()) { |
| for (const Boundary& b : boundaries) { |
| auto op = b.operands()[0]; |
| copy_folding_benefit += get_copy_folding_benefit(op); |
| } |
| } |
| VLOG(2) << "Copy folding benefit: " << copy_folding_benefit; |
| |
| if (reuses_after == 0 && reuses_before == 0 && copy_folding_benefit == 0) { |
| return -1; |
| } else if (boundaries[0].IsInsideBranch()) { |
| return reuses_after - reuses_before; |
| } else { |
| return reuses_before - reuses_after - 1 + copy_folding_benefit; |
| } |
| } |
| |
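| // Returns the boundary reached by following the op_index-th operand (when |
| // inside a branch) or the op_index-th user (when outside) of each |
| // instruction in `b`. |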
| Boundary GetNextBoundary(const Boundary& b, int64_t op_index) { |
| Boundary b2(b.GetPosition()); |
| for (int j = 0; j < b.operands().size(); ++j) { |
| HloInstruction* inst = b.operands()[j]; |
| CHECK(inst != nullptr); |
| HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index] |
| : inst->users()[op_index]; |
| CHECK(op != nullptr); |
| b2.mutable_operands().push_back(op); |
| } |
| return b2; |
| } |
| |
| // Checks whether it is safe to move a boundary when it is visited through |
| // a dependent already considered for moving. |
| bool IsSafeToMoveBoundary(const Boundary& next_boundary) { |
| int64_t next_boundary_count = |
| (next_boundary.IsInsideBranch()) |
| ? next_boundary.operands()[0]->user_count() |
| : CountNonLeafOps(next_boundary.operands()[0]->operands()); |
| if (next_boundary_count <= 1) { |
| // If the boundary has at most one dependent, it is safe to move. |
| return true; |
| } else { |
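| // The boundary has multiple dependents. Park it as a new (non-moved) |
| // boundary the first time it is encountered; once all of its dependents |
| // have been visited, recover it below so it can move together with them. |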
| if (!ContainsKey(visited_count_, next_boundary.operands()[0])) { |
| VLOG(2) << "Skip next boundary " << next_boundary.ToString() << "\n" |
| << " because it has multiple dependents: " |
| << next_boundary_count << "\n"; |
| visited_count_[next_boundary.operands()[0]] = 1; |
| new_boundaries_.push_back(next_boundary); |
| } else { |
| auto pos = std::find(new_boundaries_.begin(), new_boundaries_.end(), |
| next_boundary); |
| if (pos != new_boundaries_.end() || |
| next_boundary.operands().size() == 1) { |
| int count = ++visited_count_[next_boundary.operands()[0]]; |
| if (count == next_boundary_count) { |
| VLOG(2) << "Recovering next boundary " << next_boundary.ToString() |
| << "\n" |
| << " because all of its dependents have been visited: " |
| << next_boundary_count << "\n"; |
| visited_count_.erase(next_boundary.operands()[0]); |
| if (pos != new_boundaries_.end()) { |
| new_boundaries_.erase(pos); |
| } |
| return true; |
| } |
| } else { |
| VLOG(2) << "Skip incompatible multi-dependent boundary: " |
| << next_boundary.ToString() << ":" << next_boundary_count |
| << "\n"; |
| } |
| } |
| } |
| return false; |
| } |
| // This function is reused both for moving boundaries out of and into a |
| // conditional. As a result, the readability is somewhat compromised. |
| // It might be nice to refactor this function so that the outside-vs-inside |
| // considerations are factored into separate function-pointer parameters to |
| // improve readability. |
| void AddBoundaries(const Boundary& boundary) { |
| BoundaryVisitor visitor; |
| visitor.AddToWorkList(boundary); |
| while (visitor.HasNextBoundary()) { |
| Boundary b = visitor.PopNextBoundary(); |
| VLOG(2) << "visiting boundary " << b.ToString() << "\n"; |
| if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical( |
| b.operands(), is_layout_sensitive_)) && |
| IsSafeToMoveBoundary(b) && |
| WorthHoisting(b.operands()[0], b.IsInsideBranch())) { |
| connected_boundaries_.push_back(b); |
| VLOG(2) << "boundary can be moved\n"; |
| int64_t operand_count = (b.IsInsideBranch()) |
| ? b.operands()[0]->operand_count() |
| : b.operands()[0]->users().size(); |
| for (int i = 0; i < operand_count; i++) { |
| Boundary next_boundary = GetNextBoundary(b, i); |
| VLOG(2) << "Add operand/user " << i << " to visit later\n"; |
| visitor.AddToWorkList(next_boundary); |
| } |
| } else { |
| VLOG(2) << "boundary cannot be moved\n"; |
| visited_count_[b.operands()[0]] = 1; |
| new_boundaries_.push_back(b); |
| } |
| } |
| } |
| std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional, |
| const Boundary& b) { |
| // At the beginning of optimization, a conditional itself is added to a |
| // worklist. Here the conditional is expanded into two sets of boundaries: |
| // the first set contains a single boundary inside the branches, holding |
| // the roots of all branches; the second set contains one boundary for |
| // each user of the conditional. |
| HloInstruction* inst = b.operands()[0]; |
| if (inst == conditional) { |
| int branch_count = inst->branch_count(); |
| // Add conditional roots as a new boundary to visit. |
| Boundary boundary_in(Boundary::Position::kInsideBranch); |
| for (int i = 0; i < branch_count; i++) { |
| HloComputation* branch_computation = inst->branch_computation(i); |
| HloInstruction* root_inst = branch_computation->root_instruction(); |
| CHECK(root_inst != nullptr); |
| boundary_in.mutable_operands().push_back(root_inst); |
| } |
| new_boundaries_.push_back(boundary_in); |
| // Add conditional users as new boundaries to visit. |
| for (auto u : inst->users()) { |
| Boundary boundary_in(Boundary::Position::kOutsideBranch); |
| boundary_in.mutable_operands().push_back(u); |
| new_boundaries_.push_back(boundary_in); |
| } |
| } else { |
| AddBoundaries(b); |
| } |
| return connected_boundaries_; |
| } |
| void AddNewBoundaries(std::vector<Boundary>& b) { |
| b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end()); |
| } |
| }; |
| |
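| // Considers whether the group of boundaries connected to cur_boundary |
| // should be moved out of or into the conditional, returning the chosen |
| // direction together with its estimated benefit. Boundaries that are |
| // visited but not moved are returned through new_boundaries. |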
| ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( |
| HloInstruction* conditional, const Boundary& cur_boundary, |
| std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries, |
| absl::flat_hash_map<HloInstruction*, int>& visited_count) { |
| GroupConnectedBoundaries connect(conditional, is_layout_sensitive_, |
| visited_count, &move_config_, &reuse_config_, |
| &search_config_); |
| auto move_in_or_out = |
| connect.BoundariesToMoveInOrOut(conditional, cur_boundary); |
| if (!move_in_or_out.empty()) { |
| auto benefit = connect.BenefitForMovingBoundaries( |
| move_in_or_out, search_config_map_.empty()); |
| VLOG(2) << "benefit of moving in or out " |
| << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n"; |
| if (benefit >= 0) { |
| new_boundaries.clear(); |
| connect.AddNewBoundaries(new_boundaries); |
| // The whole sequence in move_in_or_out is either all moving into a |
| // conditional or all moving out of a conditional, so looking only |
| // at the first entry of the sequence is sufficient to know in which |
| // direction the move is intended. |
| to_move = move_in_or_out; |
| return Decision(to_move[0].IsInsideBranch() |
| ? Decision::Direction::kMoveOutOfBranch |
| : Decision::Direction::kMoveIntoBranch, |
| benefit); |
| } else { |
| connect.clear_recently_visited(); |
| } |
| } else { |
| connect.AddNewBoundaries(new_boundaries); |
| } |
| return Decision(Decision::Direction::kNoChange, 0); |
| } |
| |
| StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) { |
| VLOG(2) << "Begin a new pass of conditional code motion optimization.\n"; |
| // Used to support debugging of the optimization, by disabling it after it |
| // has been applied a predetermined number of times (to isolate the impact |
| // of transformations). |
| if (!ConsumeFuel("conditional_code_motion", [&] { |
| return "Skipping conditional opt after allowed limit reaching 0.\n"; |
| })) { |
| return false; |
| } |
| bool changed = false; |
| bool cleanup_changed = false; |
| { |
| HloPassPipeline subpipeline("before_conditional_code_motion"); |
| subpipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/is_layout_sensitive_); |
| subpipeline.AddPass<HloDCE>(); |
| TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module)); |
| cleanup_changed |= cleanup_changed_now; |
| } |
| // Gather all the conditional ops in the module ahead of time, to avoid |
| // potential complications from modifying the code while traversing it. |
| std::vector<HloInstruction*> conditional_ops; |
| // Track how many extra times each branch computation is used beyond its |
| // first use; a positive count means the computation is shared. |
| absl::flat_hash_map<HloComputation*, int> conditional_computations; |
| for (auto* comp : module->MakeComputationPostOrder()) { |
| for (auto* instr : comp->MakeInstructionPostOrder()) { |
| if (instr->opcode() == HloOpcode::kConditional) { |
| int branch_count = instr->branch_count(); |
| for (int i = 0; i < branch_count; ++i) { |
| HloComputation* branch_i = instr->branch_computation(i); |
| if (ContainsKey(conditional_computations, branch_i)) { |
| conditional_computations[branch_i]++; |
| } else { |
| conditional_computations[branch_i] = 0; |
| } |
| } |
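| // A tuple-shaped conditional is a candidate only if all of its users |
| // are get-tuple-element ops, since code motion may change the shape of |
| // the conditional's result tuple. |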
| if (instr->shape().IsTuple()) { |
| bool can_change_tuple_shape = true; |
| for (auto user : instr->users()) { |
| VLOG(2) << "user is : " << user->ToString() << "\n"; |
| if (user->opcode() != HloOpcode::kGetTupleElement) { |
| can_change_tuple_shape = false; |
| } |
| } |
| if (can_change_tuple_shape) { |
| conditional_ops.push_back(instr); |
| } |
| } else { |
| conditional_ops.push_back(instr); |
| } |
| } |
| } |
| } |
| |
| int64_t conditional_index = 0; |
| // Used to collect the mappings from original to cloned instructions. |
| HloCloneContext clone_context(module); |
| for (HloInstruction* conditional : conditional_ops) { |
| if (conditional_index == 0 || !search_config_map_.empty()) { |
| auto config_entry = search_config_map_.find(conditional_index); |
| if (config_entry != search_config_map_.end()) { |
| search_config_ = (*config_entry).second; |
| VLOG(2) << "config entry value extracted:" << search_config_.size(); |
| search_config_index_ = 0; |
| } |
| VLOG(2) << "Obtaining default configuration for conditional " |
| << conditional_index << "\n"; |
| SetDefaultMoveConfig(); |
| VLOG(2) << "Done obtaining default configuration\n"; |
| conditional_index++; |
| } |
| int branch_count = conditional->branch_count(); |
| // Check whether any branch computation is shared with another |
| // conditional branch. |
| bool conditional_is_shared = false; |
| for (int i = 0; i < branch_count; ++i) { |
| HloComputation* branch_i = conditional->branch_computation(i); |
| if (conditional_computations[branch_i] > 0) { |
| conditional_is_shared = true; |
| break; |
| } |
| } |
| |
| // Boundaries to move out or to move into the branches. |
| std::vector<std::vector<Boundary>> to_move_out, to_move_in; |
| std::vector<std::vector<Boundary>> new_boundaries_for_moveout; |
| std::vector<std::vector<Boundary>> new_boundaries_for_movein; |
| // Number of times each instruction has been visited for moving. |
| absl::flat_hash_map<HloInstruction*, int> visited_count; |
| int benefit_move_out = 0, benefit_move_in = 0; |
| Decision::Direction final_d = Decision::Direction::kNoChange; |
| // The conditional is moved into a worklist as the seed (starting point). |
| // The conditional will be expanded into multiple seeds (starting points), |
| // its roots and its users, when it is visited by GroupConnectedBoundaries. |
| // A NO_CHANGE decision will always be returned for the conditional itself, |
| // so that the other seeding boundaries can be visited in turn. |
| BoundaryVisitor visitor(conditional); |
| VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n"; |
| // Visit all the boundaries, collect the analysis results, and save all |
| // the beneficial non-conflicting decisions. If two decisions conflict |
| // with each other, save the more beneficial one. The benefits of moving |
| // out and of moving in are accumulated separately, and the direction |
| // with the larger accumulated benefit determines the final decision. |
| while (visitor.HasNextBoundary()) { |
| std::vector<Boundary> to_move, next_boundary; |
| Boundary boundary = visitor.PopNextBoundary(); |
| VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n"; |
| auto d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary, |
| visited_count); |
| switch (d.GetDirection()) { |
| case Decision::Direction::kMoveOutOfBranch: |
| VLOG(2) << "Local Decision is move out of branch\n"; |
| to_move_out.push_back(to_move); |
| new_boundaries_for_moveout.push_back(next_boundary); |
| benefit_move_out += d.GetBenefit(); |
| if (benefit_move_out >= benefit_move_in) { |
| final_d = Decision::Direction::kMoveOutOfBranch; |
| VLOG(2) << "Current Decision is move out of branch (" |
| << to_move_out.size() << ")\n"; |
| } else { |
| VLOG(2) << "Current Decision remains move into branch\n"; |
| } |
| break; |
| case Decision::Direction::kMoveIntoBranch: |
| VLOG(2) << "Decision is move into branch\n"; |
| to_move_in.push_back(to_move); |
| new_boundaries_for_movein.push_back(next_boundary); |
| benefit_move_in += d.GetBenefit(); |
| if (benefit_move_out >= benefit_move_in) { |
| VLOG(2) << "Current Decision remains move out of branch\n"; |
| } else { |
| final_d = Decision::Direction::kMoveIntoBranch; |
| VLOG(2) << "Current Decision is move into branch (" |
| << to_move_in.size() << ")\n"; |
| } |
| break; |
| case Decision::Direction::kNoChange: |
| VLOG(2) << "Decision is no change\n"; |
| for (const Boundary& b : next_boundary) { |
| visitor.AddToWorkList(b); |
| VLOG(2) << "Adding new boundary to worklist:" << b.ToString() |
| << "\n"; |
| } |
| break; |
| } |
| } |
| // If a modification is to be made, the shared branches need to be |
| // cloned first. |
| if (final_d != Decision::Direction::kNoChange && conditional_is_shared) { |
| for (int i = 0; i < branch_count; ++i) { |
| HloComputation* branch_i = conditional->branch_computation(i); |
| if (conditional_computations[branch_i] > 0) { |
| // Cloning is absolutely needed if the computation is shared by |
| // different branches, but the cloning can be potentially avoided |
| // if the sharing is only among branches of the same conditional. |
| // If cloning these branches causes a problem due to space issues, |
| // a fix can pass a vector of unique branches to the actual |
| // transformations, as an alternative representation of the |
| // conditional branches to be modified. Right now we assume the |
| // overhead of cloning is minimal since later stages of the compiler |
| // inline all the computations anyway. |
| HloComputation* clone_i = |
| conditional->parent()->parent()->AddEmbeddedComputation( |
| branch_i->Clone("clone", &clone_context)); |
| conditional->set_branch_computation(i, clone_i); |
| conditional_computations[branch_i]--; |
| // Translate the analysis results so that they refer to the cloned |
| // instructions. |
| auto update_boundary = [&](Boundary& boundary) { |
| auto cloned_instr = |
| clone_context.FindInstruction(boundary.operands()[i]); |
| CHECK(cloned_instr != nullptr); |
| VLOG(2) << "boundary before cloning:" << boundary.operands()[i] |
| << "\n"; |
| boundary.mutable_operands()[i] = cloned_instr; |
| VLOG(2) << "boundary after cloning:" << boundary.operands()[i] |
| << "\n"; |
| }; |
| // Only boundaries to move out need to be updated. |
| if (final_d == Decision::Direction::kMoveOutOfBranch) { |
| for (int i = 0; i < to_move_out.size(); ++i) { |
| std::vector<Boundary>& m = to_move_out[i]; |
| std::for_each(m.begin(), m.end(), update_boundary); |
| } |
| for (int i = 0; i < new_boundaries_for_moveout.size(); ++i) { |
| std::vector<Boundary>& m = new_boundaries_for_moveout[i]; |
| std::for_each(m.begin(), m.end(), update_boundary); |
| } |
| } |
| } |
| } |
| VLOG(2) << "Cloned branches as needed: " << conditional->ToString() |
| << "\n"; |
| } |
| // At most one of to_move_out or to_move_in can be non-empty, since there is |
| // only one optimization decision. |
| if (final_d == Decision::Direction::kMoveOutOfBranch) { |
| CHECK(to_move_out.size() == new_boundaries_for_moveout.size()); |
| for (int i = 0; i < to_move_out.size(); ++i) { |
| TF_ASSIGN_OR_RETURN(bool result, |
| MoveInstructionOut(conditional, to_move_out[i], |
| new_boundaries_for_moveout[i])); |
| changed |= result; |
| } |
| VLOG(2) << "Done moving out of branches " << to_move_out.size() |
| << " times. \n"; |
| if (!ConsumeFuel("conditional_code_motion", [&] { |
| return "Skipping conditional opt after allowed limit reaching 0.\n"; |
| })) { |
| break; |
| } |
| } else if (final_d == Decision::Direction::kMoveIntoBranch) { |
| CHECK(to_move_in.size() == new_boundaries_for_movein.size()); |
| for (int i = 0; i < to_move_in.size(); ++i) { |
| TF_ASSIGN_OR_RETURN(bool result, |
| MoveInstructionIn(conditional, to_move_in[i], |
| new_boundaries_for_movein[i])); |
| changed |= result; |
| } |
| VLOG(2) << "Done moving into branches " << to_move_in.size() |
| << " times. \n"; |
| if (!ConsumeFuel("conditional_code_motion", [&] { |
| return "Skipping conditional opt after allowed limit reaching 0.\n"; |
| })) { |
| break; |
| } |
| } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) { |
| // Invoke special handling for convert rematerialization/hoisting. |
| // We need to make sure no sharing is present in the branches because no |
| // cloning has been done by the earlier analysis. |
| // TODO(b/165848866): extend the solution to handle cloning for the |
| // special move. |
| TF_ASSIGN_OR_RETURN( |
| bool convert_result, |
| ConvertSpecialMove(conditional, is_layout_sensitive_)); |
| // Record the change before potentially breaking out on exhausted fuel, |
| // so that the cleanup passes below still run. |
| changed |= convert_result; |
| if (convert_result) { |
| VLOG(2) << "Done special moving of convert\n"; |
| if (!ConsumeFuel("conditional_code_motion", [&] { |
| return "Skipping conditional opt after the allowed limit reached " |
| "0.\n"; |
| })) { |
| break; |
| } |
| } |
| } |
| } |
| if (changed) { |
| HloPassPipeline subpipeline( |
| "after_conditional_code_motion_after_convert_hoisting"); |
| VLOG(2) << "starting after motion passes: DCE\n"; |
| subpipeline.AddPass<HloDCE>(); |
| subpipeline.AddPass<TupleSimplifier>(); |
| subpipeline.AddPass<HloDCE>(); |
| TF_ASSIGN_OR_RETURN(auto cleanup_changed_now, subpipeline.Run(module)); |
| cleanup_changed |= cleanup_changed_now; |
| } |
| if (cleanup_changed) { |
| VLOG(2) << "subpipeline cleanup have modified code\n"; |
| } |
| return changed; |
| } |
| |
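| // Initializes the default reuse and move configuration tables, which are |
| // indexed by opcode. Depending on the active search configuration, either |
| // the per-opcode transformation decisions or the reuse model may then be |
| // tuned during the search. |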
| void ConditionalCodeMotion::SetDefaultMoveConfig() { |
| VLOG(2) << "search_config_index = " << search_config_index_ << "\n"; |
| VLOG(2) << "search_config_ size = " << search_config_.size() << "\n"; |
| int64_t cur_search_config = (search_config_index_ < 0 || |
| search_config_index_ >= search_config_.size()) |
| ? 0 |
| : search_config_[search_config_index_]; |
| enum class TuningOption { |
| kDoNotTune = 0, |
| kTuneTransformationDecision = 1, |
| kTuneReuseModel = 2, |
| }; |
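| // The sign of cur_search_config selects the tuning mode: zero disables |
| // tuning, a positive value tunes the per-opcode transformation decisions, |
| // and a negative value tunes the reuse model. |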
| TuningOption tuning_option = |
| (cur_search_config == 0) ? TuningOption::kDoNotTune |
| : (cur_search_config > 0) ? TuningOption::kTuneTransformationDecision |
| : TuningOption::kTuneReuseModel; |
| |
| auto row = HloOpcodeCount(); |
| auto col = row; |
| VLOG(2) << "Start setting default configuration\n"; |
| reuse_config_.clear(); |
| move_config_.clear(); |
| reuse_config_.reserve(row); |
| move_config_.reserve(row); |
| for (int64_t opcode = 0; opcode < row; ++opcode) { |
| // Default reuse estimates for this opcode, one entry per user opcode. |
| std::vector<int64_t> reuse_vec(col, 0); |
| for (uint32_t j = 0; j < col; ++j) { |
| reuse_vec[j] = ReusesCarriedBy(static_cast<HloOpcode>(opcode), |
| static_cast<HloOpcode>(j)); |
| } |
| reuse_config_.push_back(reuse_vec); |
| std::vector<int64_t> move_vec; |
| switch (tuning_option) { |
| case TuningOption::kTuneTransformationDecision: |
| // Tuning transformation decision --- start with all yes. |
| // Only a single entry is needed if we don't consider operands of an op |
| // when searching/tuning transformation decisions. |
| move_vec.push_back(1); |
| break; |
| // kTuneReuseModel: tune the ReusesCarriedBy results only; the move |
| // decisions keep their defaults. |
| case TuningOption::kTuneReuseModel: |
| // kDoNotTune: no tuning --- use the default configuration. |
| case TuningOption::kDoNotTune: |
| // Use the opcode of the first operand to configure the default. |
| move_vec.reserve(col); |
| for (uint32_t j = 0; j < col; ++j) { |
| move_vec.push_back(WorthHoisting(static_cast<HloOpcode>(opcode), |
| static_cast<HloOpcode>(j))); |
| } |
| break; |
| } |
| move_config_.push_back(move_vec); |
| } |
| } |
| |
| // The search configuration is specified using a string in the format of |
| // 'config1;config2; ...;config_n', where each config_i is in the format of |
| // 'index,start,max,stride' (four integers separated by commas), which specify |
| // the index number of the conditional being configured, the index of the first |
| // transformation decision to flip for the conditional, the max number of |
| // decisions to flip, and how many decisions to skip in between the flips. |
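| // For example (illustrative values only), the string "0,1,2,1;3,0,4,2" |
| // configures conditional 0 to flip at most 2 decisions starting from |
| // decision 1 with 1 decision skipped between flips, and conditional 3 to |
| // flip at most 4 decisions starting from decision 0 with 2 decisions |
| // skipped between flips. |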
| void ConditionalCodeMotion::ParseSearchConfiguration( |
| const std::string& search_config) { |
| if (search_config.empty()) { |
| return; |
| } |
| search_config_index_ = 0; |
| std::vector<std::string> configs = absl::StrSplit(search_config, ';'); |
| for (const std::string& config : configs) { |
| std::vector<std::string> specs = absl::StrSplit(config, ','); |
| CHECK_EQ(specs.size(), 4); |
| int64_t condition_index; |
| CHECK(absl::SimpleAtoi(specs[0], &condition_index)); |
| auto& cur_config_entry = search_config_map_[condition_index]; |
| int64_t flip_start, max_flip, flip_stride; |
| CHECK(absl::SimpleAtoi(specs[1], &flip_start)); |
| CHECK(absl::SimpleAtoi(specs[2], &max_flip)); |
| CHECK(absl::SimpleAtoi(specs[3], &flip_stride)); |
| int64_t cur_config = MakeSearchConfig(flip_start, max_flip, flip_stride); |
| cur_config_entry.push_back(cur_config); |
| VLOG(2) << "Setting search config " << condition_index << "->" << cur_config |
| << "\n"; |
| } |
| } |
| |
| } // namespace conditional_opt |
| |
| } // namespace xla |