| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/memory_space_assignment.h" |
| |
| #include "tensorflow/compiler/xla/debug_options_flags.h" |
| #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h" |
| #include "tensorflow/core/lib/math/math_util.h" |
| namespace xla { |
| |
namespace {
// Define a dummy chunk for chunks that will be allocated in the default memory
// space and for keeping track of number of asynchronous copies.
// NOTE(review): the {-1, -1} fields presumably mark the chunk as having no
// real offset/size — confirm against HeapSimulator::Chunk.
const HeapSimulator::Chunk kDummyChunk{-1, -1};
// This variable is used by the cost analysis in estimating how many times each
// while loop will execute. Nested loops will be assumed to have executed
// pow(kWhileExecutionCount, nesting_level) times.
const int kWhileExecutionCount = 5;

}  // namespace
| |
// Factory for MemorySpaceAssignmentCostAnalysis. Builds the auxiliary
// analyses (alias analysis, live ranges, and the call graph) that the cost
// analysis queries; any failure of those analyses is propagated to the
// caller through the StatusOr.
/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
MemorySpaceAssignmentCostAnalysis::Create(
    const HloCostAnalysis& cost_analysis,
    float async_copy_bandwidth_bytes_per_second,
    float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
  TF_ASSIGN_OR_RETURN(auto hlo_live_range,
                      HloLiveRange::Run(module.schedule(), *alias_analysis,
                                        module.entry_computation()));
  auto call_graph = CallGraph::Build(&module);
  // absl::WrapUnique is used instead of make_unique — presumably because the
  // constructor is not public.
  return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
      cost_analysis, async_copy_bandwidth_bytes_per_second,
      alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
      std::move(hlo_live_range), std::move(call_graph)));
}
| |
| float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( |
| const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, |
| MemorySpaceAssignmentCostAnalysis::Cache* cache) const { |
| float elapsed_time_due_to_compute = |
| GetInstructionElapsedDueToCompute(instruction); |
| float elapsed_time_due_to_memory = |
| GetInstructionElapsedDueToMemory(instruction); |
| if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { |
| // Memory bound, return how much alternate memory is better. |
| float while_nest_multiplier; |
| if (cache) { |
| // If there is a cache provided, memoize the while nest multiplier. |
| auto it = cache->while_nest_multiplier.find(&instruction); |
| if (it != cache->while_nest_multiplier.end()) { |
| while_nest_multiplier = it->second; |
| } else { |
| while_nest_multiplier = tensorflow::MathUtil::IPow<float>( |
| kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); |
| cache->while_nest_multiplier[&instruction] = while_nest_multiplier; |
| } |
| } else { |
| while_nest_multiplier = tensorflow::MathUtil::IPow<float>( |
| kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); |
| } |
| return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * |
| while_nest_multiplier; |
| } else { |
| // Compute bound, return how far off are we to memory boundedness. |
| return elapsed_time_due_to_memory - elapsed_time_due_to_compute; |
| } |
| } |
| |
// Returns a heuristic priority score for placing this buffer interval in
// alternate memory: the alternate-memory benefit of the defining instruction
// combined with the benefits of all aliased uses, minus the slowdown this
// buffer's prefetch is estimated to cause others, normalized by buffer size
// (so smaller buffers with equal benefit rank higher).
float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
  const HloInstruction& defining_instruction =
      *interval.buffer->defining_instruction();
  // Benefit of producing the output of the defining instruction directly into
  // alternate memory.
  float alternate_mem_benefit = GetAlternateMemoryBenefit(
      defining_instruction,
      GetInstructionElapsedDueToMemory(defining_instruction,
                                       /*operand_in_alternate_mem=*/{},
                                       /*output_in_alternate_mem=*/true),
      cache);
  // Fold in the benefit at every use reachable through aliasing.
  for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
           interval.buffer->defining_position().instruction,
           interval.buffer->defining_position().index)) {
    for (const HloValue* value : buffer->values()) {
      for (const HloUse& use : value->uses()) {
        // We look inside the called computations of while and conditional, so
        // don't use the benefit of while and conditional directly.
        if (use.instruction->opcode() == HloOpcode::kWhile ||
            use.instruction->opcode() == HloOpcode::kConditional) {
          continue;
        }
        float use_alternate_mem_benefit =
            GetAlternateMemoryBenefit(*use.instruction,
                                      GetInstructionElapsedDueToMemory(
                                          *use.instruction, use.operand_number),
                                      cache);
        // If the benefit is positive (memory bound), add it to this buffer's
        // benefit. If the benefit is negative (compute bound), calculate the
        // maximum.
        if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
          alternate_mem_benefit += use_alternate_mem_benefit;
        } else {
          alternate_mem_benefit =
              std::max(alternate_mem_benefit, use_alternate_mem_benefit);
        }
      }
    }
  }

  // Get performance slowdown in seconds of prefetching current BufferInterval
  // causing to other BufferIntervals.
  float alternate_mem_slowdown =
      GetInstructionElapsedDueToMemorySlowdown(interval.size);

  // Divide by the size of the buffer to prioritize smaller buffers that will
  // give the largest alternate memory benefit.
  return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size;
}
| |
| int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( |
| const HloInstruction* instruction) const { |
| int nest_level = 0; |
| const HloComputation* computation = instruction->parent(); |
| while (!computation->IsEntryComputation()) { |
| auto node = call_graph_->GetNode(computation); |
| auto callsites = node.caller_callsites(); |
| CHECK_EQ(callsites.size(), 1) << "The module is not flattened!"; |
| auto callsite = callsites[0]; |
| if (callsite.instruction()->opcode() == HloOpcode::kWhile) { |
| ++nest_level; |
| } |
| computation = callsite.instruction()->parent(); |
| } |
| return nest_level; |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( |
| const HloInstruction& instruction) const { |
| return std::max( |
| cost_analysis_.flop_count(instruction) / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), |
| cost_analysis_.transcendental_count(instruction) / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey)); |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis:: |
| GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const { |
| return bytes / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( |
| const HloInstruction& instruction, |
| absl::optional<int64> operand_in_alternate_mem, |
| bool output_in_alternate_mem) const { |
| float bytes_accessed = cost_analysis_.bytes_accessed(instruction); |
| float elapsed_due_to_bytes = |
| bytes_accessed / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); |
| if (operand_in_alternate_mem) { |
| // Estimate the elapsed time due to the operand being in the alternate |
| // memory space. |
| float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed( |
| instruction, *operand_in_alternate_mem); |
| float elapsed_due_to_operand_bytes = |
| operand_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_; |
| bytes_accessed -= operand_bytes_accessed; |
| elapsed_due_to_bytes = |
| elapsed_due_to_operand_bytes + |
| bytes_accessed / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); |
| } |
| if (output_in_alternate_mem) { |
| // Estimate the elapsed time due to the output being in the alternate memory |
| // space. |
| float output_bytes_accessed = |
| cost_analysis_.output_bytes_accessed(instruction); |
| float elapsed_due_to_output_bytes = |
| output_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_; |
| bytes_accessed -= output_bytes_accessed; |
| elapsed_due_to_bytes = |
| elapsed_due_to_output_bytes + |
| bytes_accessed / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); |
| } |
| return elapsed_due_to_bytes; |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed( |
| const HloInstruction& instruction) const { |
| return std::max(GetInstructionElapsedDueToCompute(instruction), |
| GetInstructionElapsedDueToMemory(instruction)); |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( |
| const HloInstruction& instruction, |
| absl::optional<int64> operand_in_alternate_mem, |
| bool output_in_alternate_mem) const { |
| return std::max( |
| GetInstructionElapsedDueToCompute(instruction), |
| GetInstructionElapsedDueToMemory(instruction, operand_in_alternate_mem, |
| output_in_alternate_mem)); |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed( |
| const Shape& shape) const { |
| int64 size_in_bytes = cost_analysis_.GetShapeSize(shape); |
| return static_cast<float>(size_in_bytes) / |
| async_copy_bandwidth_bytes_per_second_; |
| } |
| |
// Returns the logical end time of the schedule, as computed by the live
// range analysis.
int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
  return hlo_live_range_->schedule_end_time();
}
| |
// A buffer may live in alternate memory without an asynchronous copy if its
// live range spans at most max_overlap_count_ instructions.
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
    const Shape& shape, int64 start_time, int64 end_time) const {
  // The shape is ignored; this picker reasons purely in instruction counts.
  return end_time - start_time <= max_overlap_count_;
}
| |
// Prefer to end the eviction min_overlap_count_ instructions after the start,
// clamped so it never exceeds latest_end_time.
int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
    const Shape& shape, int64 start_time, int64 latest_end_time) const {
  // The shape is ignored; this picker reasons purely in instruction counts.
  return std::min(start_time + min_overlap_count_, latest_end_time);
}
| |
// Initializes iteration over candidate prefetch start times for the given
// use. Iteration starts at the earliest candidate that keeps the overlap
// within max_overlap_count_ instructions; Next() then advances toward
// end_time until Done().
void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
                                                   int64 start_time,
                                                   int64 end_time) {
  end_time_ = end_time;
  current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_);
}
| |
| int64 InstructionCountPrefetchIntervalPicker::Next() { |
| CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " |
| "Done() is false"; |
| return current_prefetch_time_++; |
| } |
| |
// Iteration stops once the remaining overlap drops to min_overlap_count_
// instructions or fewer.
bool InstructionCountPrefetchIntervalPicker::Done() const {
  return end_time_ - current_prefetch_time_ <= min_overlap_count_;
}
| |
| std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const { |
| return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_); |
| } |
| |
| std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| return absl::StrCat("Overlapped HLOs = ", end_time - start_time); |
| } |
| |
// Precomputes, over the whole schedule: (1) each logical time's while-loop
// nesting level, (2) a cumulative sum of per-instruction elapsed times
// weighted by the estimated while trip count, and (3) for each logical time,
// the index of the closest preceding nest-level change. These let interval
// queries run without rescanning the schedule.
CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
    float min_async_copy_to_overlap_ratio,
    float max_async_copy_to_overlap_ratio,
    float preferred_async_copy_to_overlap_ratio)
    : while_nest_level_(
          cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
      cost_analysis_(cost_analysis),
      min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
      max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio),
      preferred_async_copy_to_overlap_ratio_(
          preferred_async_copy_to_overlap_ratio) {
  instruction_schedule_ =
      &cost_analysis_.hlo_live_range().instruction_schedule();

  // Create a vector of elapsed times and while nesting levels of HLO
  // instructions. The elapsed times are multiplied by pow(kWhileExecutionCount,
  // nest_level) to account for executing the HLOs multiple times in while
  // loops.
  std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
                                               0.0);
  for (const auto& instruction_and_logical_time : *instruction_schedule_) {
    // To avoid double counting, don't include the elapsed time of while and
    // conditional HLOs.
    const HloInstruction* instruction = instruction_and_logical_time.first;
    if (instruction->opcode() == HloOpcode::kWhile ||
        instruction->opcode() == HloOpcode::kConditional) {
      continue;
    }
    float elapsed_time = cost_analysis_.GetInstructionElapsed(
        *instruction_and_logical_time.first);
    int64 logical_time = instruction_and_logical_time.second;
    // Logical times may exceed the schedule's map size; grow both vectors
    // on demand.
    if (logical_time >= instructions_elapsed_time.size()) {
      instructions_elapsed_time.resize(logical_time + 1, 0.0);
      while_nest_level_.resize(logical_time + 1, 0);
    }
    int nest_level = cost_analysis_.CalculateWhileLoopNestLevel(
        instruction_and_logical_time.first);
    while_nest_level_[logical_time] = nest_level;
    instructions_elapsed_time[logical_time] =
        elapsed_time *
        tensorflow::MathUtil::IPow<float>(kWhileExecutionCount, nest_level);
  }
  // As an optimization, create a cumulative sum vector of elapsed time.
  float cumsum = 0.0;
  elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
  for (float elapsed_time : instructions_elapsed_time) {
    cumsum += elapsed_time;
    elapsed_time_cumsum_.push_back(cumsum);
  }
  // To be able to accurately determine the minimum nest level between a start
  // time and an end time efficiently, populate a data structure that stores the
  // closest nest level change index.
  int prev_nest_level = 0;
  int change_idx = -1;
  while_nest_level_change_.reserve(instructions_elapsed_time.size());
  for (int i = 0; i < while_nest_level_.size(); ++i) {
    int nest_level = while_nest_level_[i];
    if (nest_level != prev_nest_level) {
      prev_nest_level = nest_level;
      // Record the index just before the change so GetMinWhileNestLevel can
      // hop backwards from change point to change point.
      change_idx = i - 1;
    }
    while_nest_level_change_.push_back(change_idx);
  }
}
| |
| bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| // Even though this method returns if we allow the buffer in alternate memory |
| // _without_ asynchronous copies, calculate how long it would have taken to |
| // copy it and compare it to the elapsed time in the logical interval. |
| float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); |
| float logical_interval_elapsed = |
| GetLogicalIntervalElapsed(start_time, end_time); |
| return max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ * |
| async_copy_elapsed > |
| logical_interval_elapsed; |
| } |
| |
| int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( |
| const Shape& shape, int64 start_time, int64 latest_end_time) const { |
| float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); |
| int64 end_time; |
| for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) { |
| float logical_interval_elapsed = |
| GetLogicalIntervalElapsed(start_time, end_time); |
| if (logical_interval_elapsed >= |
| min_async_copy_to_overlap_ratio_ * async_copy_elapsed) { |
| break; |
| } |
| } |
| return end_time; |
| } |
| |
// Initializes iteration over candidate prefetch start times for the given
// use: computes the earliest and latest allowed prefetch times from the
// min/max overlap ratios, then positions two iterators (one increasing, one
// decreasing) at the candidate closest to the preferred overlap ratio.
void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
                                               int64 start_time,
                                               int64 end_time) {
  const Shape& shape = ShapeUtil::GetSubshape(
      use.instruction->operand(use.operand_number)->shape(), use.operand_index);
  // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
  async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
  // Estimate the time we would save by having this op in alternate memory.
  float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
  float elapsed_time_in_alternate_mem =
      cost_analysis_.GetInstructionElapsedInAlternateMemory(
          *use.instruction, use.operand_number,
          /*output_in_alternate_mem=*/false);
  inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
  end_logical_time_ = end_time;
  int end_nest_level = while_nest_level_[end_logical_time_];

  // Find the latest time we're allowed to start prefetching. Candidates must
  // be at the same while nesting level as the use, and must leave at least
  // min_interval of overlap (the instruction's elapsed-time reduction counts
  // toward that overlap).
  float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  for (latest_prefetch_time_ = end_logical_time_ - 1;
       latest_prefetch_time_ >= start_time &&
       (while_nest_level_[latest_prefetch_time_] != end_nest_level ||
        min_interval > GetLogicalIntervalElapsed(latest_prefetch_time_,
                                                 end_logical_time_) +
                           inst_elapsed_reduction_);
       --latest_prefetch_time_) {
  }

  // Find the earliest time we're allowed to start prefetching, bounded by
  // max_interval (which is scaled up by the retry multiplier).
  float max_interval = max_async_copy_to_overlap_ratio_ *
                       max_overlap_multiplier_ * async_copy_elapsed_;
  for (earliest_prefetch_time_ = start_time;
       earliest_prefetch_time_ <= end_logical_time_ &&
       (while_nest_level_[earliest_prefetch_time_] != end_nest_level ||
        max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
                                                 end_logical_time_));
       ++earliest_prefetch_time_) {
  }
  if (earliest_prefetch_time_ > latest_prefetch_time_) {
    // There is no available prefetch interval for the given start and end
    // times. Set the iterators accordingly to ensure Done() returns true.
    increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
    decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
    CHECK(Done());
    return;
  }

  // Between the earliest and latest prefetch interval, find the interval
  // closest to the preferred interval and start iterating from there.
  int64 starting_prefetch_time = earliest_prefetch_time_;
  float preferred_interval =
      preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
  float best_interval =
      GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_);
  for (int64 prefetch_time = earliest_prefetch_time_ + 1;
       prefetch_time <= latest_prefetch_time_; ++prefetch_time) {
    float interval =
        GetLogicalIntervalElapsed(prefetch_time, end_logical_time_);
    if (while_nest_level_[prefetch_time] == end_nest_level &&
        std::abs(preferred_interval - interval) <
            std::abs(preferred_interval - best_interval)) {
      best_interval = interval;
      starting_prefetch_time = prefetch_time;
    }
  }
  VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
          << max_interval << " " << preferred_interval
          << " prefetch time earliest/latest/starting = "
          << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
          << starting_prefetch_time;

  increasing_prefetch_time_iterator_ = starting_prefetch_time;
  decreasing_prefetch_time_iterator_ = starting_prefetch_time;
  using_increasing_prefetch_time_iterator_ = true;
  // Since both iterators start at the same position, call Next() once to
  // advance one of the iterators.
  Next();
}
| |
| int64 CostAnalysisPrefetchIntervalPicker::Next() { |
| CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " |
| "Done() is false"; |
| if (using_increasing_prefetch_time_iterator_) { |
| int64 prefetch_time = increasing_prefetch_time_iterator_++; |
| while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && |
| while_nest_level_[increasing_prefetch_time_iterator_] != |
| while_nest_level_[end_logical_time_]) { |
| ++increasing_prefetch_time_iterator_; |
| } |
| if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { |
| using_increasing_prefetch_time_iterator_ = false; |
| } |
| return prefetch_time; |
| } else { |
| int64 prefetch_time = decreasing_prefetch_time_iterator_--; |
| while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && |
| while_nest_level_[decreasing_prefetch_time_iterator_] != |
| while_nest_level_[end_logical_time_]) { |
| --decreasing_prefetch_time_iterator_; |
| } |
| if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { |
| using_increasing_prefetch_time_iterator_ = true; |
| } |
| return prefetch_time; |
| } |
| } |
| |
// Iteration is finished when both iterators are exhausted: the increasing
// one has passed the latest allowed prefetch time and the decreasing one has
// passed the earliest.
bool CostAnalysisPrefetchIntervalPicker::Done() const {
  return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
         decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
}
| |
// Relaxes the maximum-overlap bound for retries of memory space assignment.
void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
  // Use twice as large max overlap limit in each retry.
  // NOTE(review): retry_number >= 31 would overflow this shift — presumably
  // the retry count stays small; confirm at the call site.
  max_overlap_multiplier_ = 1 << retry_number;
}
| |
// Returns the minimum while-loop nesting level over [start_time, end_time].
// Instead of scanning every logical time, hops backwards through the
// nest-level change points precomputed in while_nest_level_change_.
int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
    int64 start_time, int64 end_time) const {
  int min_nest_level =
      std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
  int change_idx = while_nest_level_change_[end_time];
  while (change_idx >= start_time) {
    min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
    change_idx = while_nest_level_change_[change_idx];
  }
  return min_nest_level;
}
| |
// Returns the estimated elapsed time (seconds) of the logical interval
// (start_time, end_time), i.e. the cumulative-sum difference over the
// instructions strictly between the two endpoints' positions.
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
    int64 start_time, int64 end_time) const {
  CHECK_LE(start_time, end_time);
  if (start_time == end_time) {
    return 0.0;
  }
  // Clamp negative start times — presumably callers may pass times before the
  // schedule start; confirm at call sites.
  if (start_time < 0) {
    start_time = 0;
  }
  // Since elapsed_time_cumsum_ is already weighed by the while loop nesting
  // level, normalize the elapsed time by dividing with the nesting factor of
  // the interval (start and end times).
  int interval_nest_level = GetMinWhileNestLevel(start_time, end_time);
  return (elapsed_time_cumsum_[end_time - 1] -
          elapsed_time_cumsum_[start_time]) /
         tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
                                           interval_nest_level);
}
| |
| std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { |
| int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_ |
| ? increasing_prefetch_time_iterator_ |
| : decreasing_prefetch_time_iterator_; |
| float logical_interval_elapsed = GetLogicalIntervalElapsed( |
| current_logical_prefetch_time, end_logical_time_); |
| return absl::StrCat( |
| "Async copy elapsed (s) = ", async_copy_elapsed_, |
| ", inst elapsed reduction (s) = ", inst_elapsed_reduction_, |
| ", logical interval elapsed (s) = ", logical_interval_elapsed, |
| ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_, |
| ")"); |
| } |
| |
| std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); |
| float logical_interval_elapsed = |
| GetLogicalIntervalElapsed(start_time, end_time); |
| return absl::StrCat( |
| "Async copy elapsed (s) = ", async_copy_elapsed, |
| ", logical interval elapsed (s) = ", logical_interval_elapsed); |
| } |
| |
// Returns the memory-boundedness score of the buffer interval, delegating to
// the cost analysis. No memoization cache is passed here.
absl::optional<float>
CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
  return cost_analysis_.GetMemoryBoundedness(interval);
}
| |
| std::string MemorySpaceAssignment::AllocationValue::ToString() const { |
| std::string out = absl::StrCat("computation = ", computation()->name()); |
| absl::StrAppend(&out, "\n position:\n"); |
| absl::StrAppend(&out, " ", defining_position_.ToString(), "\n"); |
| absl::StrAppend(&out, " uses:\n"); |
| for (const Use& use : uses_) { |
| absl::StrAppend(&out, " ", use.hlo_use.ToString(), "\n"); |
| } |
| return out; |
| } |
| |
| std::string MemorySpaceAssignment::AllocationValue::ToShortString() const { |
| return absl::StrCat("computation = ", computation()->name(), |
| ", position = ", defining_position_.ToString(), |
| ", value = ", value_->ToShortString()); |
| } |
| |
// Splits an HloValue into one AllocationValue per non-trivial defining
// position and attaches each of the value's uses to the AllocationValue in
// the same computation whose position is the latest one before the use.
void AlternateMemoryBestFitHeap::CreateAllocationValues(
    const HloValue* value, std::vector<AllocationValue>* allocation_values) {
  VLOG(3) << "Creating AllocationValues for: " << value->ToString();

  // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast)
  // positions. We create an AllocationValue object for each non-trivial
  // position. And for each AllocationValue object, we create an
  // AllocationSequence consisting of one or more Allocation objects. The
  // reason why we exclude the trivial positions from AllocationValue is
  // because Allocation objects have special support for tuples and bitcasts.
  const absl::flat_hash_map<const HloInstruction*, int64>&
      instruction_schedule = hlo_live_range_.instruction_schedule();
  std::vector<HloPosition> positions;
  for (const HloPosition& position : value->positions()) {
    const HloInstruction* instruction = position.instruction;
    if (instruction->opcode() != HloOpcode::kGetTupleElement &&
        instruction->opcode() != HloOpcode::kTuple &&
        instruction->opcode() != HloOpcode::kBitcast) {
      positions.push_back(position);
    }
  }
  // Sort positions by schedule order so "latest position before a use" below
  // is well-defined.
  absl::c_stable_sort(positions,
                      [&](const HloPosition& pos1, const HloPosition& pos2) {
                        return instruction_schedule.at(pos1.instruction) <
                               instruction_schedule.at(pos2.instruction);
                      });

  // Create an AllocationValue for each non-trivial position.
  absl::flat_hash_set<const HloComputation*> computations;
  // Only the AllocationValues appended by this call (from beginning_idx on)
  // belong to this HloValue.
  int beginning_idx = allocation_values->size();
  for (int i = 0; i < positions.size(); ++i) {
    const HloPosition& position = positions.at(i);
    allocation_values->emplace_back(value, position);
  }

  std::vector<HloUse> uses(value->uses());
  absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) {
    return instruction_schedule.at(use1.instruction) <
           instruction_schedule.at(use2.instruction);
  });

  // Associate each use with an AllocationValue. Each AllocationValue contains a
  // position and uses in the same computation. Furthermore, if the original
  // HloValue had multiple non-trivial positions in the same computation, those
  // will get their own AllocationValue as well. We split these HloValues so
  // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they
  // point to the latest position. We then replace the operand of the use with
  // CopyStart/CopyDone with an operand of the latest position.
  for (const HloUse& use : uses) {
    int64 use_time = instruction_schedule.at(use.instruction);
    HloComputation* use_computation = use.instruction->parent();

    // Pick the AllocationValue in the use's computation whose defining
    // position is the latest one scheduled before the use.
    AllocationValue* last_allocation_value = nullptr;
    for (int i = beginning_idx; i < allocation_values->size(); ++i) {
      AllocationValue* allocation_value = &allocation_values->at(i);
      if (allocation_value->computation() == use_computation &&
          instruction_schedule.at(
              allocation_value->defining_position().instruction) < use_time) {
        last_allocation_value = allocation_value;
      }
    }
    CHECK(last_allocation_value != nullptr);
    last_allocation_value->AddUse(use, use_time);
  }

  for (int i = beginning_idx; i < allocation_values->size(); ++i) {
    VLOG(3) << "Created allocation value: "
            << allocation_values->at(i).ToString();
  }
}
| |
// For every use of every AllocationValue, records the positions that must
// alias with that use: the instruction's own output (when the use's
// instruction defines another AllocationValue), parameters of called
// computations, and — for kWhile — the while body's root.
void AlternateMemoryBestFitHeap::FindAliases(
    std::vector<AllocationValue>* allocation_values) const {
  // Index the allocation values by defining instruction for O(1) lookup.
  absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
      values_by_defining_inst;
  for (AllocationValue& value : *allocation_values) {
    // Skip the value if it doesn't have any uses.
    if (value.uses().empty()) {
      continue;
    }
    CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
    values_by_defining_inst[value.defining_instruction()] = &value;
  }
  // If `instruction` defines one of the allocation values, record its
  // defining position as an alias of `use`.
  auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
                                              AllocationValue::Use* use) {
    auto aliased_value_it = values_by_defining_inst.find(instruction);
    if (aliased_value_it != values_by_defining_inst.end()) {
      VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() << " to "
              << aliased_value_it->second->ToShortString();
      use->aliases.push_back(aliased_value_it->second->defining_position());
    }
  };

  for (AllocationValue& value : *allocation_values) {
    for (AllocationValue::Use& use : value.uses()) {
      // Find any aliases with the instruction itself (operand and output must
      // alias).
      maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);

      // Find any aliases with the parameters of called computations.
      for (const HloComputation* called_computation :
           use.hlo_use.instruction->called_computations()) {
        for (const HloInstruction* parameter_instruction :
             called_computation->parameter_instructions()) {
          maybe_add_alias_with_instruction(parameter_instruction, &use);
        }
      }

      // Special case for kWhile: the root of the body computation must alias as
      // well.
      if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
        HloPosition root_alias{
            use.hlo_use.instruction->while_body()->root_instruction(),
            use.hlo_use.operand_index};
        VLOG(3) << "Adding while body root aliasing for use "
                << use.hlo_use.ToString() << " to " << root_alias;
        use.aliases.push_back(root_alias);
      }
    }
  }
}
| |
| std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*> |
| AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( |
| const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { |
| std::vector<const BufferInterval*> colocated_intervals; |
| std::vector<const BufferInterval*> worklist = {&interval}; |
| while (!worklist.empty()) { |
| const BufferInterval* item = worklist.back(); |
| worklist.pop_back(); |
| colocated_intervals.push_back(item); |
| for (const HloValue* buffer_colocated : item->colocations) { |
| worklist.push_back(&buffer_intervals_.at(buffer_colocated)); |
| } |
| } |
| |
| absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x, |
| const BufferInterval* y) { |
| return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end); |
| }); |
| return colocated_intervals; |
| } |
| |
// Decides whether `use` of the given allocation value may be placed in the
// alternate memory space. kWhile and kConditional uses get special treatment
// because the buffer flows into the called computations; all other uses are
// allowed unconditionally.
bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
    const AllocationValue& value, const HloUse& use) const {
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    HloComputation* while_body = use.instruction->while_body();

    // We don't want to allocate this buffer in alternate memory if it will be
    // evicted anyway. Find out if it has an early use or a late definition that
    // would make sense to keep it in the alternate memory.
    HloValue* parameter_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            while_body->parameter_instruction(0), use.operand_index);
    int64 parameter_time =
        instruction_schedule.at(while_body->parameter_instruction(0));
    int64 root_time = instruction_schedule.at(while_body->root_instruction());
    int64 min_use_time = root_time;
    for (const HloUse& parameter_use : parameter_value->uses()) {
      int64 use_time = instruction_schedule.at(parameter_use.instruction);
      // GTE/Tuple/Bitcast uses don't materially consume the buffer, so they
      // don't count as "real" uses when estimating profitability.
      if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
          parameter_use.instruction->opcode() != HloOpcode::kTuple &&
          parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
          use_time > parameter_time) {
        min_use_time = std::min(min_use_time, use_time);
      }
    }
    // If there is no use of this buffer inside the while loop, there is no need
    // to allocate it in the loop.
    if (min_use_time == root_time) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    const Shape& shape = parameter_value->shape();
    // Allow the buffer in alternate memory if the buffer has a short live range
    // either at the beginning or end of the while loop body.
    if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
            shape, parameter_time, min_use_time)) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    // Check if there is a required assignment for the while loop output.
    HloValue* while_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            use.instruction, use.operand_index);
    int64 while_time = instruction_schedule.at(use.instruction);
    auto existing_required_assignment =
        RequiredMemoryAssignmentAt(while_value, while_time);
    if (existing_required_assignment) {
      // TODO(berkin): Failing for now when the output is requested to be in
      // alternate memory, and the buffer is a while loop output.
      CHECK(existing_required_assignment->memory_space == MemorySpace::kDefault)
          << "While loop buffers pinned to alternate memory not "
             "currently supported.";
      VLOG(4) << "While allocation not allowed in alternate memory because "
                 "there is a required default memory assignment.";
      return false;
    }
  } else if (use.instruction->opcode() == HloOpcode::kConditional) {
    // For any use of this conditional (the same value might be passed into
    // multiple called computations), determine if the parameter->first use
    // dependency is short.
    int64 conditional_time = instruction_schedule.at(use.instruction);
    for (const AllocationValue::Use& other_use : value.uses()) {
      if (other_use.hlo_use.instruction != use.instruction) {
        continue;
      }
      // Operand 0 of a conditional is the predicate, so the called computation
      // fed by operand N is at index N - 1.
      HloComputation* called_computation =
          use.instruction->called_computations().at(
              other_use.hlo_use.operand_number - 1);
      const HloInstruction* parameter_instruction =
          called_computation->parameter_instruction(0);
      HloValue* parameter_value =
          &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
              parameter_instruction, other_use.hlo_use.operand_index);
      int64 parameter_time = instruction_schedule.at(parameter_instruction);
      int64 min_use_time = conditional_time;
      for (const HloUse& parameter_use : parameter_value->uses()) {
        // Only uses within this called computation matter; GTE/Tuple/Bitcast
        // are again excluded as non-material uses.
        if (parameter_use.instruction->parent() == called_computation &&
            parameter_use.instruction->opcode() !=
                HloOpcode::kGetTupleElement &&
            parameter_use.instruction->opcode() != HloOpcode::kTuple &&
            parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
          min_use_time = std::min(
              min_use_time, instruction_schedule.at(parameter_use.instruction));
        }
      }
      if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
              parameter_value->shape(), parameter_time, min_use_time)) {
        VLOG(4) << "Conditional allocation allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
        return true;
      } else {
        VLOG(4) << "Conditional allocation not allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
      }
    }
    return false;
  }

  // All other uses are allowed in the alternate memory.
  return true;
}
| |
| void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( |
| const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, |
| std::string* debug_str) const { |
| // Columns in buffer information: |
| // buffer_id: int. This value can be used to match the allocation in |
| // allocation information. |
| // buffer_name: string. |
| // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis |
| // thought it would be beneficial to put this in the alternate memory. The |
| // higher the value, the more it is memory bound. |
| // size: int. In bytes. |
| // definition_time: int. Logical time this value was defined in the schedule. |
| // use_times: string. This is a semicolon-separated list of integers for all |
| // the use times. |
| // use_names: string. This is a semicolon-separated list of string |
| // representation of uses. |
| if (debug_str->empty()) { |
| // Append the column names. |
| absl::StrAppend(debug_str, |
| "buffer_id,buffer_name,alt_mem_benefit,size," |
| "definition_time,use_times,use_names\n"); |
| } |
| const HloBuffer& buffer = |
| alias_analysis_.GetBufferContainingValue(*interval.buffer); |
| const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); |
| int64 definition_time = |
| instruction_schedule.at(interval.buffer->defining_position().instruction); |
| std::vector<std::pair<int64, std::string>> uses; |
| for (const HloValue* value : buffer.values()) { |
| for (const HloUse& use : value->uses()) { |
| uses.push_back( |
| {instruction_schedule.at(use.instruction), use.ToString()}); |
| } |
| } |
| absl::c_sort(uses); |
| std::vector<int64> use_times; |
| std::vector<std::string> use_names; |
| use_times.reserve(uses.size()); |
| use_names.reserve(uses.size()); |
| for (const auto& use : uses) { |
| use_times.push_back(use.first); |
| use_names.push_back(use.second); |
| } |
| |
| absl::StrAppend(debug_str, buffer.id(), ","); |
| absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\","); |
| auto alternate_memory_benefit = |
| options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit( |
| interval); |
| absl::StrAppend( |
| debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ","); |
| absl::StrAppend(debug_str, interval.size, ","); |
| absl::StrAppend(debug_str, definition_time, ","); |
| absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\","); |
| absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\""); |
| absl::StrAppend(debug_str, "\n"); |
| } |
| |
| void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString( |
| const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, |
| const MemorySpaceAssignment::Allocation& allocation, |
| std::string* debug_str) const { |
| // Columns in allocation information: |
| // buffer_id: int. This value can be used the match with buffer info. |
| // size: int. In bytes. |
| // offset: int. In bytes. |
| // start_time: int. Logical start time of the allocation. |
| // end_time: int. Logical end time of the allocation. |
| if (debug_str->empty()) { |
| // Append the column names. |
| absl::StrAppend(debug_str, "buffer_id,size,offset,start_time,end_time\n"); |
| } |
| if (allocation.memory_space() == MemorySpace::kAlternate) { |
| const HloBuffer& buffer = |
| alias_analysis_.GetBufferContainingValue(*interval.buffer); |
| absl::StrAppend(debug_str, buffer.id(), ","); |
| absl::StrAppend(debug_str, interval.size, ","); |
| absl::StrAppend(debug_str, allocation.chunk().offset, ","); |
| absl::StrAppend(debug_str, allocation.start_time(), ","); |
| absl::StrAppend(debug_str, allocation.end_time(), "\n"); |
| } |
| } |
| |
| void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const { |
| if (!options_.dump_fn) { |
| return; |
| } |
| options_.dump_fn("bufferinfo", buffer_info_str_); |
| options_.dump_fn("allocinfo", allocation_info_str_); |
| } |
| |
// Main driver of the alternate-memory assignment: walks all buffer intervals
// in sorted order and tries to allocate each colocation group into the
// alternate memory, retrying with adjusted limits on failure.
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
          << options_.max_size_in_bytes;

  // Pin entry parameters and outputs to the memory spaces dictated by their
  // layouts before processing any intervals.
  AddInputAndOutputRequiredAssignments();

  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Flattened instruction sequence:";
    const auto& instruction_sequence =
        hlo_live_range_.flattened_instruction_sequence().instructions();
    for (int i = 0; i < instruction_sequence.size(); ++i) {
      VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name()
              << " " << instruction_sequence[i]->name();
    }
  }

  for (auto& interval : sorted_buffer_intervals) {
    // Skip intervals that are colocated with an already-processed interval.
    if (!interval.need_allocation) {
      continue;
    }

    if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
            interval)) {
      continue;
    }

    HloInstruction* inst = interval.buffer->instruction();
    HloModule* module = inst->GetModule();

    // Don't intra-program prefetch a cross program prefetch
    if (inst->opcode() == HloOpcode::kParameter &&
        absl::c_count(module->CrossProgramPrefetches(),
                      std::make_pair(inst->parameter_number(),
                                     interval.buffer->index())) > 0) {
      VLOG(3) << "Skip " << interval.buffer->ToShortString()
              << " because it is cross-program prefetched.";
      continue;
    }

    auto colocated_intervals = GetSortedColocatedIntervals(interval);

    if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
      VLOG(3) << "Interval " << interval.buffer->ToShortString()
              << " is reserved in the alternate memory. Total reserved bytes = "
              << reserved_in_bytes_;
      for (const BufferInterval* colocated_interval : colocated_intervals) {
        const HloValue* value = colocated_interval->buffer;
        // Color all of the aliased reserved buffers here because reserved
        // alternate memory allocations will not have an entry in preset
        // allocations that is normally used for coloring.
        for (auto& position : value->positions()) {
          VLOG(4) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          shape->mutable_layout()->set_memory_space(
              options_.alternate_memory_space);
        }
      }
      // Increment the reserved part of alternate memory so that it is not
      // available for other buffers. Since all colocated intervals should have
      // the same size, just use the first one.
      reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer);
      continue;
    }

    // Aliased intervals (e.g. across while-loop boundaries) can only be
    // handled when cross-sequential-call allocation is enabled.
    if (colocated_intervals.size() > 1 &&
        !options_.allocate_across_sequential_calls) {
      VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
              << " because it aliases with another interval and "
              << " allocate_across_sequential_calls is false.";
      continue;
    }

    // Compiler-fuel hook for bisecting miscompiles: stop allocating once the
    // fuel runs out.
    if (!ConsumeFuel("memory_space_assignment", [&] {
          return absl::StrCat("Ran out of fuel at buffer: ",
                              colocated_intervals[0]->buffer->ToShortString());
        })) {
      continue;
    }

    AppendBufferInfoDebugString(interval, &buffer_info_str_);

    // Retry allocating this value with larger limits if allocation fails.
    for (int retry_number = 0; retry_number < options_.max_retries;
         retry_number++) {
      final_retry_ = (retry_number == options_.max_retries - 1);
      options_.prefetch_interval_picker->SetRetryNumber(retry_number);
      bool success = AllocateColocatedIntervals(colocated_intervals);
      if (success) {
        break;
      }
      VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
    }
  }

  VLOG(3) << "Debug buffer info: ";
  VLOG(3) << buffer_info_str_;
  VLOG(3) << "Debug allocation info: ";
  VLOG(3) << allocation_info_str_;
  DumpDebugStringsIfEnabled();

  return result_;
}
| |
// Attempts to allocate a whole group of colocated buffer intervals, walking
// each AllocationValue's uses in schedule order and allocating one segment
// per (definition/previous use, use) pair. On any failure the entire group is
// rolled back (UncommitPendingChunks) and false is returned so the caller can
// retry with different limits.
bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
    const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>&
        colocated_intervals) {
  // TODO(berkin): For now, place the phi values due to conditionals in
  // default memory.
  for (const BufferInterval* colocated_interval : colocated_intervals) {
    const HloValue* value = colocated_interval->buffer;
    for (const auto& position : value->positions()) {
      if (position.instruction->opcode() == HloOpcode::kConditional) {
        VLOG(3) << "Adding required assignment for condition output: "
                << value->ToShortString();
        AddRequiredAssignment(position.instruction, position.index,
                              MemorySpace::kDefault);
        // The roots of the called computations alias the conditional output,
        // so they must be pinned to default memory as well.
        for (const HloComputation* called_computation :
             position.instruction->called_computations()) {
          AddRequiredAssignment(called_computation->root_instruction(),
                                position.index, MemorySpace::kDefault);
        }
      }
    }
  }

  // Create AllocationValues for all the colocated intervals.
  std::vector<AllocationValue> allocation_values;
  for (const auto& colocated_interval : colocated_intervals) {
    CreateAllocationValues(colocated_interval->buffer, &allocation_values);
  }
  FindAliases(&allocation_values);
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();

  // Data structure to contain the preferred offset for a given computation.
  // We ensure that the same offset will be allocated outside the while loop
  // as well as inside the while loop.
  absl::flat_hash_map<const HloComputation*, int64>
      preferred_offset_for_computation;

  bool allocation_success = true;
  for (auto& allocation_value : allocation_values) {
    int64 definition_time =
        instruction_schedule.at(allocation_value.defining_instruction());

    absl::optional<int64> preferred_offset;
    auto preferred_offset_it =
        preferred_offset_for_computation.find(allocation_value.computation());
    if (preferred_offset_it != preferred_offset_for_computation.end()) {
      preferred_offset = preferred_offset_it->second;
    }

    // Iterate over the uses.
    for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
      const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
      const HloUse hlo_use = use.hlo_use;
      int64 use_time = instruction_schedule.at(hlo_use.instruction);
      int64 latest_prefetch_time = use_time;
      bool allow_no_copy_alternate_mem_allocation = true;
      absl::optional<int64> earliest_prefetch_time = absl::nullopt;

      // Sequential calls include kWhile, kCall, and kConditional opcodes.
      bool is_sequential_call =
          (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
           CallContext::kSequential);
      if (is_sequential_call) {
        // A prefetch feeding a sequential call must finish before any of the
        // called computations start executing.
        for (const HloComputation* called_computation :
             hlo_use.instruction->called_computations()) {
          const HloLiveRange::TimeBound& computation_span =
              hlo_live_range_.computation_span_times().at(called_computation);
          latest_prefetch_time =
              std::min(computation_span.start - 1, latest_prefetch_time);
        }
        if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
          // Given an example while loop and flattened schedule (logical times
          // shown on the left):
          //
          // 0:  a = ...
          // 1:  ...
          //     cond {
          // 2:   p = param(0)
          // 3:   ...
          //     }
          //     body {
          // 4:   p = param(0)
          // 5:   ...
          // 6:   ROOT ...
          //     }
          // 7:  w = while(a), body=body, cond=cond
          //
          // When processing "a" (time 0) and its while use (time 7), we update
          // the interval to time 0-4. This is so that the remaining interval
          // (5-6) can be allocated separately and this buffer doesn't waste
          // alternate memory space within the while loop body.
          HloComputation* while_body = hlo_use.instruction->while_body();
          // We require while body ROOTs to be the last in the schedule.
          CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
                   instruction_schedule.at(hlo_use.instruction))
              << "While body ROOTs need to be the last in the schedule! "
                 "Please run RootInstructionSinker.";
          // Replace the use time with the parameter time so that we can decide
          // on alternate memory allocations within the while loop body when we
          // look at uses within the while loop body.
          use_time =
              instruction_schedule.at(while_body->parameter_instruction(0));
        } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
          // Replace the use time with the earliest parameter of called
          // computations.
          for (const HloComputation* called_computation :
               hlo_use.instruction->called_computations()) {
            use_time = std::min(
                use_time, instruction_schedule.at(
                              called_computation->parameter_instruction(0)));
          }
        }
      }

      // Add a required assignment in default memory if the use not allowed in
      // alternate memory.
      if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
        AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
                              MemorySpace::kDefault, use_time);
      } else if (use_idx > 0) {
        // We allow buffers in alternate memory that are passed into
        // conditionals to give up their alternate memory allocation inside the
        // called computation. This means that if a conditional operator has an
        // alternate memory allocation, subsequent uses cannot use the same
        // alternate memory allocation in order not to clobber data. So we force
        // default memory allocation for these subsequent uses.
        const AllocationValue::Use& previous_use =
            allocation_value.uses().at(use_idx - 1);
        if (previous_use.hlo_use.instruction->opcode() ==
                HloOpcode::kConditional &&
            previous_use.hlo_use.instruction != hlo_use.instruction) {
          allow_no_copy_alternate_mem_allocation = false;
          earliest_prefetch_time =
              instruction_schedule.at(previous_use.hlo_use.instruction);
          VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
                  << ") of use (" << hlo_use.ToString()
                  << ") is a conditional, so this use will need to evict. "
                  << "Earliest prefetch time = " << *earliest_prefetch_time;
        }
      }

      // Bitcasts don't define buffers and don't directly consume buffers. Skip
      // allocating buffers for bitcast uses. The uses that feed from bitcasts
      // will be handled specially.
      if (hlo_use.instruction->opcode() != HloOpcode::kBitcast) {
        AllocationRequest request;
        // Rarely, (e.g., when conditional true and false parameters are the
        // same), definition time can be the time of the conditional and use
        // time is the parameter use, which is less.
        request.start_time = std::min(definition_time, use_time);
        request.end_time = use_time;
        request.latest_prefetch_time = latest_prefetch_time;
        request.size = colocated_intervals[0]->size;
        request.allow_no_copy_alternate_mem_allocation =
            allow_no_copy_alternate_mem_allocation;
        request.earliest_prefetch_time = earliest_prefetch_time;
        request.preferred_offset = preferred_offset;
        request.use = &use;
        request.allocation_value = &allocation_value;
        if (!AllocateSegment(request)) {
          // If the allocation finding failed (e.g., due to running out of
          // asynchronous copies), then fall back to allocating the buffer
          // entirely in the default memory.
          UncommitPendingChunks();
          allocation_success = false;
          break;
        }

        // If there are multiple uses, they can try using the memory allocation
        // already at the alternate memory.
        definition_time = instruction_schedule.at(hlo_use.instruction);
      }

      // Propagate the allocation to any aliases this use might have had.
      MemorySpaceAssignment::Allocation* aliased_allocation =
          GetLiveAllocationAt(*allocation_value.allocation_sequence(),
                              use_time);
      for (const HloPosition& aliased_position : use.aliases) {
        AddAliasedRequiredAssignment(aliased_position.instruction,
                                     aliased_position.index,
                                     aliased_allocation);
      }

      // Special case for while loops since the root offset must agree with
      // other offsets: remember the preferred offset for the while loop body.
      if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
          aliased_allocation->memory_space() == MemorySpace::kAlternate) {
        preferred_offset_for_computation[hlo_use.instruction->while_body()] =
            aliased_allocation->chunk().offset;
      }
    }
    if (!allocation_success) {
      break;
    }
  }
  // Only commit the allocations to the output sequence if every segment of
  // every value in the group succeeded.
  if (allocation_success) {
    for (AllocationValue& allocation_value : allocation_values) {
      for (auto& allocation : *allocation_value.allocation_sequence()) {
        AppendAllocationInfoDebugString(*colocated_intervals[0], *allocation,
                                        &allocation_info_str_);
        allocations_->push_back(std::move(allocation));
      }
    }
  }

  // On success the pending state has been absorbed; on failure it was already
  // rolled back by UncommitPendingChunks. Either way, reset the pending lists.
  pending_chunks_.clear();
  pending_async_copies_.clear();
  pending_required_assignments_.clear();
  return allocation_success;
}
| |
| bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) { |
| return (a.start_time < b.start_time && a.end_time <= b.end_time) || |
| (a.start_time <= b.start_time && a.end_time < b.end_time); |
| } |
| |
| void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) { |
| auto it_and_inserted = ranges_.insert(copy); |
| CHECK(it_and_inserted.second || |
| it_and_inserted.first->start_time == copy.start_time); |
| } |
| |
| void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) { |
| auto copy_it = ranges_.find(copy); |
| CHECK(copy_it != ranges_.end()); |
| ranges_.erase(copy_it); |
| } |
| |
| bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time, |
| int64 end_time) const { |
| // We allow identical start and end times. It is enough to check for just the |
| // start time in case we find a match in ranges_ because the found value will |
| // either be identical to {start_time, end_time} (and this doesn't violate) or |
| // its start_time will be smaller and end_time will be larger (this violates). |
| auto copy_it = ranges_.find( |
| {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate}); |
| return copy_it != ranges_.end() && copy_it->start_time != start_time; |
| } |
| |
| /*static*/ MemorySpaceAssignment::Allocation* |
| AlternateMemoryBestFitHeap::GetLiveAllocationAt( |
| const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) { |
| for (auto allocation_it = allocations.rbegin(); |
| allocation_it != allocations.rend(); ++allocation_it) { |
| if ((*allocation_it)->start_time() <= time && |
| (*allocation_it)->end_time() >= time) { |
| return allocation_it->get(); |
| } |
| } |
| return nullptr; |
| } |
| |
// Reserves the cross-program prefetch candidate (if any) at offset 0 of the
// alternate memory and records an async copy that prefetches it before its
// first use. Note that the pending lists are cleared (not uncommitted) at the
// end, so the reservation stays committed.
void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
    HloModule* module, absl::optional<BufferInterval> prefetch_candidate) {
  if (!prefetch_candidate) {
    return;
  }

  ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate);
  // The cross-program prefetch buffer must land at offset 0 and fit within
  // the heap budget; otherwise give up with a warning.
  if (chunk_candidate.chunk.offset != 0 ||
      chunk_candidate.heap_size > available_heap_size()) {
    LOG(WARNING)
        << "Could not allocate preferred memory for cross program prefetch";
    return;
  }
  AddToPendingChunks(*prefetch_candidate, chunk_candidate);

  const HloValue* buffer = prefetch_candidate->buffer;
  int64 parameter = buffer->instruction()->parameter_number();
  module->AddCrossProgramPrefetch(parameter, buffer->index());

  // Start with a default-memory allocation covering the candidate interval;
  // the async copy added below provides the alternate-memory version.
  MemorySpaceAssignment::AllocationSequence allocations;
  allocations.push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
      buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
      prefetch_candidate->start, prefetch_candidate->end));

  // Find the earliest use.
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  auto uses = buffer->uses();
  auto first_use =
      absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) {
        return instruction_schedule.at(lhs.instruction) <
               instruction_schedule.at(rhs.instruction);
      });
  // The prefetch must complete before the earliest use executes.
  int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction);

  AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
               chunk_candidate.chunk, prefetch_candidate->start,
               prefetch_candidate->end, latest_prefetch_time, &allocations);
  // All uses of the buffer are served out of the prefetched copy.
  absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
  for (auto& allocation : allocations) {
    allocations_->push_back(std::move(allocation));
  }

  pending_chunks_.clear();
  pending_async_copies_.clear();
  pending_required_assignments_.clear();
}
| |
| absl::optional<RequiredMemoryAssignment> |
| AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer, |
| int64 time) const { |
| auto required_assignment_it = required_assignments_.find(buffer); |
| absl::optional<RequiredMemoryAssignment> required_assignment_at_time; |
| if (required_assignment_it != required_assignments_.end()) { |
| for (const RequiredMemoryAssignment& required_assignment : |
| required_assignment_it->second) { |
| if (required_assignment.time == time) { |
| // Sanity check that there is only one required at time. |
| CHECK(!required_assignment_at_time); |
| required_assignment_at_time = required_assignment; |
| } |
| } |
| } |
| return required_assignment_at_time; |
| } |
| |
| absl::optional<RequiredMemoryAssignment> |
| AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse( |
| const AllocationValue::Use& use) const { |
| absl::optional<RequiredMemoryAssignment> required_assignment; |
| for (const HloPosition& position : use.aliases) { |
| const HloValue* value = |
| &alias_analysis_.dataflow_analysis().GetUniqueValueAt( |
| position.instruction, position.index); |
| int64 time = |
| hlo_live_range_.instruction_schedule().at(position.instruction); |
| absl::optional<RequiredMemoryAssignment> required_assignment_for_alias = |
| RequiredMemoryAssignmentAt(value, time); |
| if (required_assignment == absl::nullopt) { |
| required_assignment = required_assignment_for_alias; |
| } else { |
| CHECK(required_assignment_for_alias == absl::nullopt || |
| required_assignment->equals_ignoring_time( |
| *required_assignment_for_alias)); |
| } |
| } |
| return required_assignment; |
| } |
| |
| void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment( |
| const HloInstruction* instruction, ShapeIndex index, |
| const MemorySpaceAssignment::Allocation* aliased_allocation) { |
| absl::optional<Chunk> chunk; |
| if (aliased_allocation->memory_space() == MemorySpace::kAlternate) { |
| chunk = aliased_allocation->chunk(); |
| } |
| AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(), |
| chunk); |
| } |
| |
| void AlternateMemoryBestFitHeap::AddRequiredAssignment( |
| const HloValue* value, const HloInstruction* instruction, |
| MemorySpaceAssignment::MemorySpace memory_space, int64 time, |
| absl::optional<HeapSimulator::Chunk> chunk) { |
| // Check for existing required assignment at this time and make sure it is the |
| // same as this if there is one. |
| auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time); |
| if (existing_required_assignment) { |
| CHECK(memory_space == existing_required_assignment->memory_space) |
| << "inst = " << instruction->ToString() << " at " << time; |
| CHECK((!chunk && !existing_required_assignment->chunk) || |
| chunk->offset == existing_required_assignment->chunk->offset); |
| VLOG(3) << "Not adding required assignment because there is one already: " |
| << value->ToShortString() << " at " << time << " at " |
| << (memory_space == MemorySpace::kDefault ? "def" : "alt"); |
| } else { |
| VLOG(3) << "Adding required assignment: " << value->ToShortString() |
| << " at " << time << " at " |
| << (memory_space == MemorySpace::kDefault ? "def" : "alt"); |
| RequiredMemoryAssignment required_assignment{memory_space, time, chunk}; |
| required_assignments_[value].push_back(required_assignment); |
| pending_required_assignments_.push_back({value, required_assignment}); |
| } |
| } |
| |
| void AlternateMemoryBestFitHeap::AddRequiredAssignment( |
| const HloInstruction* instruction, ShapeIndex index, |
| MemorySpace memory_space, absl::optional<Chunk> chunk) { |
| const HloValue* value = |
| &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index); |
| int64 instruction_time = |
| hlo_live_range_.instruction_schedule().at(instruction); |
| AddRequiredAssignment(value, instruction, memory_space, instruction_time, |
| chunk); |
| } |
| |
void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
  // Go through the parameters and outputs and pin them to the corresponding
  // memory by adding a required assignment.
  const HloModule& module = alias_analysis_.dataflow_analysis().module();
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  HloComputation* entry_computation = module.entry_computation();
  for (HloInstruction* parameter_instruction :
       entry_computation->parameter_instructions()) {
    int64 parameter_instruction_time =
        instruction_schedule.at(parameter_instruction);
    ShapeUtil::ForEachSubshape(
        parameter_instruction->shape(),
        [&](const Shape& subshape, const ShapeIndex& index) {
          // The memory space is dictated by the subshape's layout; anything
          // without an alternate-memory layout is pinned to default memory.
          MemorySpace memory_space = MemorySpace::kDefault;
          if (subshape.has_layout() && subshape.layout().memory_space() ==
                                           options_.alternate_memory_space) {
            memory_space = MemorySpace::kAlternate;
          }
          // Pin every value of every buffer aliased at this position.
          for (const HloBuffer* buffer :
               alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
            for (const HloValue* value : buffer->values()) {
              VLOG(3) << "Adding required assignment for parameter value = "
                      << value->ToShortString()
                      << " time = " << parameter_instruction_time << " space = "
                      << (memory_space == MemorySpace::kDefault ? "def"
                                                                : "alt");
              // Note: appended directly (not via AddRequiredAssignment), so
              // these are never part of the pending/rollback bookkeeping.
              required_assignments_[value].push_back(
                  {memory_space, /*time=*/parameter_instruction_time});
            }
          }
        });
  }
  // Same treatment for the entry computation's output.
  HloInstruction* root_instruction = entry_computation->root_instruction();
  int64 root_instruction_time = instruction_schedule.at(root_instruction);
  ShapeUtil::ForEachSubshape(
      root_instruction->shape(),
      [&](const Shape& subshape, const ShapeIndex& index) {
        MemorySpace memory_space = MemorySpace::kDefault;
        if (subshape.has_layout() && subshape.layout().memory_space() ==
                                         options_.alternate_memory_space) {
          memory_space = MemorySpace::kAlternate;
        }
        for (const HloBuffer* buffer :
             alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
          for (const HloValue* value : buffer->values()) {
            VLOG(3) << "Adding required assignment for output value = "
                    << value->ToShortString()
                    << " time = " << root_instruction_time << " space = "
                    << (memory_space == MemorySpace::kDefault ? "def" : "alt");
            required_assignments_[value].push_back(
                {memory_space, /*time=*/root_instruction_time});
          }
        }
      });
}
| |
| bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory( |
| absl::Span<const BufferInterval* const> colocated_intervals) const { |
| auto is_position_in_alternate_memory = [&](const HloPosition& position) { |
| const Shape& shape = position.shape(); |
| return shape.has_layout() && |
| shape.layout().memory_space() == options_.alternate_memory_space; |
| }; |
| |
| const HloModule& module = alias_analysis_.dataflow_analysis().module(); |
| const HloComputation* entry_computation = module.entry_computation(); |
| const HloInstruction* root_instruction = |
| entry_computation->root_instruction(); |
| for (const BufferInterval* colocated_interval : colocated_intervals) { |
| const HloValue* value = colocated_interval->buffer; |
| if (value->defining_instruction()->opcode() == HloOpcode::kParameter && |
| value->defining_instruction()->parent() == entry_computation && |
| is_position_in_alternate_memory(value->defining_position())) { |
| return true; |
| } |
| |
| for (const HloPosition& position : value->positions()) { |
| if (position.instruction == root_instruction && |
| is_position_in_alternate_memory(position)) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
// Rolls back all bookkeeping performed during the current (failed) allocation
// attempt: interval-tree chunks, asynchronous copies (both prefetches and
// evictions), and required assignments added in this attempt.
void AlternateMemoryBestFitHeap::UncommitPendingChunks() {
  for (const auto& interval_and_chunk : pending_chunks_) {
    const BufferInterval& interval = interval_and_chunk.first;
    const Chunk& chunk = interval_and_chunk.second.chunk;
    VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end
            << ") off = " << chunk.offset << " size = " << chunk.size;
    interval_tree_.Remove(interval.start, interval.end, chunk);
  }
  for (const auto& interval : pending_async_copies_) {
    if (interval.destination == MemorySpace::kAlternate) {
      // Prefetches are tracked both in the prefetch interval tree and in the
      // async copy ordering; undo both.
      prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
                                     kDummyChunk);
      async_copy_ordering_.RemoveCopy(interval);
    } else {
      // Copies back to default memory are evictions.
      eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
                                     kDummyChunk);
    }
  }
  for (const auto& value_and_required_assignment :
       pending_required_assignments_) {
    auto& required_assignment_vector =
        required_assignments_[value_and_required_assignment.first];
    const RequiredMemoryAssignment& required_assignment =
        value_and_required_assignment.second;
    VLOG(3) << "Removing required assignment: "
            << (required_assignment.memory_space == MemorySpace::kDefault
                    ? "def"
                    : "alt")
            << " time = " << required_assignment.time << " off = "
            << (required_assignment.chunk ? required_assignment.chunk->offset
                                          : -1);
    // Erase the first matching entry (linear scan; the pending list is small
    // relative to the total number of assignments).
    for (auto it = required_assignment_vector.begin();
         it != required_assignment_vector.end(); ++it) {
      if (*it == value_and_required_assignment.second) {
        required_assignment_vector.erase(it);
        break;
      }
    }
  }
  pending_chunks_.clear();
  pending_async_copies_.clear();
  pending_required_assignments_.clear();
}
| |
| void AlternateMemoryBestFitHeap::AddToPendingChunks( |
| const BufferInterval& buffer_interval, |
| const ChunkCandidate& chunk_candidate) { |
| VLOG(3) << "Committing chunk: " << buffer_interval.start << "-" |
| << buffer_interval.end << " : [" << chunk_candidate.chunk.offset |
| << ", " << chunk_candidate.chunk.size << "]"; |
| pending_chunks_.emplace_back(buffer_interval, chunk_candidate); |
| CommitChunk(buffer_interval, chunk_candidate); |
| } |
| |
// Finds an allocation for the segment of the live range described by
// `request` (from request.start_time to request.end_time). Returns true if
// the segment was allocated in either memory space; returns false when the
// allocation failed (e.g. prefetch rejected due to async copy limits on a
// non-final retry, or a required alternate-memory assignment could not be
// satisfied).
bool AlternateMemoryBestFitHeap::AllocateSegment(
    const AllocationRequest& request) {
  auto allocation_sequence = request.allocation_value->allocation_sequence();
  // start_time == end_time is a special case where the value is consumed
  // multiple times by the same instruction. We can just find the previous
  // allocation and use that allocation.
  if (request.start_time == request.end_time) {
    MemorySpaceAssignment::Allocation* allocation =
        GetLiveAllocationAt(*allocation_sequence, request.end_time);
    CHECK_NE(allocation, nullptr);
    allocation->AddUse(request.use->hlo_use);
    return true;
  }

  const HloPosition& defining_position =
      request.allocation_value->defining_position();
  VLOG(2) << "Finding allocation for "
          << request.allocation_value->ToShortString() << " ("
          << request.start_time << ", " << request.end_time
          << ") latest prefetch = " << request.latest_prefetch_time
          << " last use = " << request.allocation_value->uses().back().time
          << " use = " << request.use->hlo_use.ToString()
          << ". Size = " << request.size
          << ", def pos = " << defining_position.ToString();
  CHECK_LE(request.start_time, request.end_time);

  // There could be a requirement to pin this buffer to default memory either
  // because it is a parameter or an output. If the buffer is a parameter, then
  // we're allowed to prefetch. If the use expects the output to be in default
  // memory, we cannot prefetch it because if we did, it would be in alternate
  // memory instead.
  auto required_assignment_at_start = RequiredMemoryAssignmentAt(
      request.allocation_value->value(), request.start_time);
  absl::optional<MemorySpace> required_memory_space_at_start;
  if (required_assignment_at_start) {
    required_memory_space_at_start = required_assignment_at_start->memory_space;
  }
  // Find required assignment both for the use and its aliases. If they are both
  // non-nullopt, then make sure they require the same assignment.
  auto required_assignment_at_end = RequiredMemoryAssignmentAt(
      request.allocation_value->value(), request.end_time);
  auto aliased_required_assignment_at_end =
      AliasedRequiredAssignmentForUse(*request.use);
  if (required_assignment_at_end != aliased_required_assignment_at_end) {
    if (required_assignment_at_end == absl::nullopt) {
      required_assignment_at_end = aliased_required_assignment_at_end;
    } else {
      // Conflicting requirements (modulo time) indicate a bug upstream.
      CHECK(aliased_required_assignment_at_end == absl::nullopt ||
            aliased_required_assignment_at_end->equals_ignoring_time(
                *required_assignment_at_end));
    }
  }
  absl::optional<MemorySpace> required_memory_space_at_end;
  if (required_assignment_at_end) {
    required_memory_space_at_end = required_assignment_at_end->memory_space;
  }

  // Materialize the required assignment at the start time: extend the last
  // allocation if it already matches the required alternate-memory placement
  // (offset is CHECKed), otherwise add a one-instruction-long allocation at
  // the required placement.
  if (required_assignment_at_start) {
    if (!allocation_sequence->empty() &&
        required_assignment_at_start->memory_space == MemorySpace::kAlternate) {
      const auto& prev_allocation = allocation_sequence->back();
      CHECK(prev_allocation->memory_space() ==
            required_assignment_at_start->memory_space);
      CHECK_EQ(prev_allocation->chunk().offset,
               required_assignment_at_start->chunk->offset);
      prev_allocation->Extend(request.start_time);
    } else {
      allocation_sequence->push_back(
          absl::make_unique<MemorySpaceAssignment::Allocation>(
              defining_position, required_assignment_at_start->memory_space,
              required_assignment_at_start->chunk, request.start_time,
              request.start_time));
    }
  }

  // First try keeping the allocation entirely in the alternate memory.
  if (required_memory_space_at_start != MemorySpace::kDefault &&
      required_memory_space_at_end != MemorySpace::kDefault &&
      request.allow_no_copy_alternate_mem_allocation &&
      AllocateInAlternateMemoryNoCopy(request)) {
    return true;
  }

  auto prev_allocation_it = allocation_sequence->rbegin();
  // Find a previous allocation that is in the default memory space (not
  // necessarily the very last allocation).
  auto prev_allocation_in_default_mem_it = std::find_if(
      allocation_sequence->rbegin(), allocation_sequence->rend(),
      [&](const auto& allocation) {
        return allocation->memory_space() == MemorySpace::kDefault &&
               allocation->defining_position() == defining_position;
      });

  if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
      prev_allocation_it != allocation_sequence->rend() &&
      (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
      (*prev_allocation_it)->defining_position() == defining_position) {
    // If there was an allocation for this HloValue that was in the alternate
    // memory space, we also need to perform an eviction.
    if (!Evict(request)) {
      return false;
    }
    prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
  } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) {
    // No default-memory allocation exists yet; create one covering this
    // segment.
    allocation_sequence->push_back(
        absl::make_unique<MemorySpaceAssignment::Allocation>(
            defining_position, MemorySpace::kDefault, /*chunk=*/absl::nullopt,
            request.start_time, request.end_time));
    prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
  }

  CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
  CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
        MemorySpace::kDefault);

  // If the buffer must be in default memory at the end_time, don't prefetch.
  if (required_memory_space_at_end == MemorySpace::kDefault) {
    VLOG(3)
        << "Not trying to prefetch because use requires buffer in default mem.";
    (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
    (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
    return true;
  }

  // Finally, try to prefetch the buffer into alternate memory.
  if (Prefetch(request, **prev_allocation_in_default_mem_it)) {
    return true;
  }
  if (!final_retry_ && prefetch_failed_due_to_async_copy_) {
    // If prefetching failed due to asynchronous copy and we're not in our final
    // try, return false (failure) so that we can retry this interval with
    // larger limits.
    return false;
  }

  // If the end assignment was required to be in alternate memory but that
  // wasn't possible, then this allocation is invalid.
  if (required_memory_space_at_end == MemorySpace::kAlternate) {
    return false;
  }

  // If a copy wasn't inserted, then add this use to the latest allocation in
  // default memory.
  (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
  (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
  return true;
}
| |
| void AlternateMemoryBestFitHeap::AddAsyncCopy( |
| const MemorySpaceAssignment::Allocation& prev_allocation, |
| MemorySpace memory_space, absl::optional<Chunk> chunk, int64 start_time, |
| int64 end_time, int64 copy_done_schedule_before_time, |
| MemorySpaceAssignment::AllocationSequence* allocations) { |
| VLOG(3) << "Copy to " |
| << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault |
| ? "default" |
| : "alternate") |
| << " memory between " << start_time << " and " |
| << copy_done_schedule_before_time << " keeping until " << end_time; |
| CHECK_LT(start_time, copy_done_schedule_before_time); |
| |
| allocations->push_back( |
| absl::make_unique<MemorySpaceAssignment::CopyAllocation>( |
| prev_allocation, memory_space, chunk, start_time, end_time, |
| copy_done_schedule_before_time)); |
| |
| // Register the additional async copy with the interval tree to keep track of |
| // the limit at any given time. |
| pending_async_copies_.push_back( |
| {start_time, copy_done_schedule_before_time, memory_space}); |
| if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) { |
| prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time, |
| kDummyChunk); |
| async_copy_ordering_.AddCopy(pending_async_copies_.back()); |
| } else { |
| eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time, |
| kDummyChunk); |
| } |
| } |
| |
| bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( |
| int64 start_time, int64 end_time, bool is_prefetch, |
| int64 extra_async_copy_limit) const { |
| if (options_.max_outstanding_prefetches < 0 && is_prefetch) { |
| return false; |
| } |
| if (options_.max_outstanding_evictions < 0 && !is_prefetch) { |
| return false; |
| } |
| |
| // Count the prefetches/evictions in the interval tree for the given interval. |
| if (is_prefetch) { |
| int64 num_prefetches = |
| prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time) |
| .size(); |
| return num_prefetches >= |
| options_.max_outstanding_prefetches + extra_async_copy_limit; |
| } else { |
| int64 num_evictions = |
| eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time) |
| .size(); |
| return num_evictions >= |
| options_.max_outstanding_evictions + extra_async_copy_limit; |
| } |
| } |
| |
// Returns true if an async copy during (start_time, end_time) would violate
// the ordering of previously added copies, as tracked by
// async_copy_ordering_.
bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
    int64 start_time, int64 end_time) const {
  return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
}
| |
// Tries to satisfy the requested segment entirely in alternate memory
// without inserting an asynchronous copy. This is possible when either this
// is the first allocation of the value (and the value is allowed in
// alternate memory) or the previous allocation was already in alternate
// memory. Returns true if a chunk was found and committed.
bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
    const AllocationRequest& request) {
  MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
  bool can_eliminate_copy = false;
  if (request.allocation_value->allocation_sequence()->empty()) {
    // There hasn't been any allocations for this interval so far. We can
    // eliminate copy if the value can be placed in the alternate memory.
    can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn(
        *request.allocation_value->value());
  } else {
    // If there has been a previous allocation, we can eliminate the copy if the
    // previous allocation was also in the alternate memory.
    prev_allocation =
        request.allocation_value->allocation_sequence()->back().get();
    can_eliminate_copy =
        (prev_allocation->memory_space() == MemorySpace::kAlternate);
  }

  if (!can_eliminate_copy) {
    return false;
  }

  // Ask the interval picker whether this live range is short enough to keep
  // in alternate memory without a copy.
  const HloPosition& defining_position =
      request.allocation_value->defining_position();
  if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
          defining_position.shape(), request.start_time + 1,
          request.end_time)) {
    return false;
  }

  BufferInterval alternate_mem_interval;
  alternate_mem_interval.buffer = request.allocation_value->value();
  alternate_mem_interval.size = request.size;
  alternate_mem_interval.end = request.end_time;
  alternate_mem_interval.start = request.start_time;

  // Prefer the offset that was previously used for the previous allocation.
  absl::optional<int64> preferred_offset;
  if (prev_allocation != nullptr) {
    preferred_offset = prev_allocation->chunk().offset;
    // If there is a previous allocation, set the start time one after the end
    // of the previous allocation's end.
    alternate_mem_interval.start = prev_allocation->end_time() + 1;
  }

  if (request.preferred_offset) {
    // Sanity check that if there is a preferred offset provided in the request,
    // it matches with the previous allocation.
    CHECK(!preferred_offset || request.preferred_offset == preferred_offset)
        << "preferred_offset = " << *preferred_offset
        << ", request.preferred_offset = " << *request.preferred_offset;
    preferred_offset = request.preferred_offset;
  }

  VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = "
          << (preferred_offset ? *preferred_offset : -1);
  // In case there are additional uses after this use, we rely on the last use
  // time to try to reserve a chunk in the heap simulator. This is to prevent
  // the following scenario:
  //
  //                            +-------+
  //                           /         \
  //                  Producer--->Use1   +-->Use2
  //                       +---------+---------+
  // New buffer:           |         |         |
  //                       +---------+---------+
  //
  //                                     +-----------+
  // Current heap:                       | offset: 0 |
  //           --------------------------+-----------+------
  //
  // Because we allocate buffers greedily, Producer to Use1 segment first, and
  // then Use1 to Use2 segment, it is possible to allocate the first segment at
  // an offset that is available for the first segment (e.g. offset 0) but not
  // for the entire live range. This can result in unnecessary copies. By using
  // the last use time, we try to find an allocation that is available for the
  // entire Producer to Use2 range.
  absl::optional<ChunkCandidate> chunk_candidate = FindBestChunkCandidate(
      request, preferred_offset, &alternate_mem_interval);
  // Check if the new heap size fits within limits. Also ensure if a
  // preferred offset was provided, that offset was used.
  if (chunk_candidate) {
    VLOG(3) << "Keep the buffer in alternate memory. Offset = "
            << chunk_candidate->chunk.offset
            << ", size = " << chunk_candidate->chunk.size
            << ", heap_size = " << chunk_candidate->heap_size
            << ", prefetch picker = "
            << options_.prefetch_interval_picker->ToNoCopyDebugString(
                   defining_position.shape(), request.start_time,
                   request.end_time);
    AddToPendingChunks(alternate_mem_interval, *chunk_candidate);

    // If there was a previous allocation, the buffer location is the
    // same as the previous. Otherwise, it is the operand.
    if (prev_allocation != nullptr &&
        (prev_allocation->is_copy_allocation() ||
         prev_allocation->defining_position() == defining_position)) {
      prev_allocation->Extend(request.end_time);
    } else {
      request.allocation_value->allocation_sequence()->push_back(
          absl::make_unique<MemorySpaceAssignment::Allocation>(
              defining_position, MemorySpace::kAlternate,
              chunk_candidate->chunk, request.start_time, request.end_time));
    }
    request.allocation_value->allocation_sequence()->back()->AddUse(
        request.use->hlo_use);
    return true;
  }
  return false;
}
| |
// Schedules an eviction (an asynchronous copy into default memory) for the
// last allocation in the request's allocation sequence. Returns false if no
// eviction could be scheduled without exceeding the outstanding async copy
// limit, in which case the buffer stays in default memory.
bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
  CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
  MemorySpaceAssignment::Allocation* prev_allocation =
      request.allocation_value->allocation_sequence()->back().get();
  // The eviction copy initially spans the previous allocation's live range.
  int64 eviction_start_time = prev_allocation->start_time();
  int64 eviction_end_time = prev_allocation->end_time();
  CHECK(eviction_start_time <= eviction_end_time);

  int64 preferred_eviction_end_time =
      std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
                   request.allocation_value->defining_position().shape(),
                   eviction_start_time, request.end_time),
               eviction_end_time);
  // Evictions must complete by the time of this use.
  preferred_eviction_end_time =
      std::min(preferred_eviction_end_time, request.latest_prefetch_time);

  BufferInterval eviction_mem_interval;
  eviction_mem_interval.buffer = request.allocation_value->value();
  eviction_mem_interval.size = request.size;
  // Try to reserve a buffer from the end of the previous allocation to the
  // preferred eviction end time.
  eviction_mem_interval.start = eviction_end_time + 1;
  eviction_mem_interval.end = preferred_eviction_end_time;
  int64 preferred_offset = prev_allocation->chunk().offset;
  VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
          << ") preferred end time = " << eviction_mem_interval.end;

  // Shrink the requested extension one step at a time until the previous
  // allocation's offset is available for the whole extension; if it never
  // is, the end converges back to eviction_end_time and nothing is reserved.
  for (; eviction_mem_interval.end > eviction_end_time;
       --eviction_mem_interval.end) {
    ChunkCandidate chunk_candidate =
        FindChunkCandidate(eviction_mem_interval, preferred_offset);
    if (chunk_candidate.chunk.offset == preferred_offset) {
      AddToPendingChunks(eviction_mem_interval, chunk_candidate);
      break;
    }
  }
  eviction_end_time = eviction_mem_interval.end;

  VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
          << eviction_start_time << ", " << eviction_end_time << ")";

  bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
  bool eviction_violates_outstanding_copies =
      ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
                                            eviction_end_time,
                                            /*is_prefetch=*/false);

  // See if this interval would violate the asynchronous copy limit.
  if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
    prev_allocation->Extend(eviction_end_time);
    AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
                 /*chunk=*/absl::nullopt, eviction_start_time,
                 prev_allocation->end_time(), eviction_end_time,
                 request.allocation_value->allocation_sequence());
  } else {
    if (eviction_violates_outstanding_copies) {
      VLOG(3) << "This violates the maximum async copies.";
    } else {
      VLOG(3) << "Eviction interval is too short (" << eviction_start_time
              << ", " << eviction_end_time << ").";
    }
    // If the original interval violated the limit, try sub-intervals within
    // this interval.
    bool eviction_scheduled = false;
    for (int64 time = eviction_start_time; time < eviction_end_time; ++time) {
      VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")";
      if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1,
                                                 /*is_prefetch=*/false)) {
        VLOG(3) << "Eviction successful.";
        AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
                     /*chunk=*/absl::nullopt, time, time + 1, time + 1,
                     request.allocation_value->allocation_sequence());
        eviction_scheduled = true;
        break;
      }
    }

    if (!eviction_scheduled) {
      // If the eviction couldn't be scheduled, then fail. This buffer will be
      // kept in the default memory.
      VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
              << " because we hit the limit of maximum asynchronous copies "
              << "between "
              << hlo_live_range_.flattened_instruction_sequence()
                     .instructions()[eviction_start_time]
              << " and "
              << hlo_live_range_.flattened_instruction_sequence()
                     .instructions()[eviction_end_time];
      return false;
    }
  }
  return true;
}
| |
// Tries to prefetch the buffer from `prev_allocation_in_default_mem` into
// alternate memory, iterating over candidate copy-start times supplied by
// the prefetch interval picker. Returns true when a CopyAllocation was
// added. Leaves prefetch_failed_due_to_async_copy_ set when at least one
// candidate interval was rejected because of async copy limits/ordering and
// no prefetch succeeded.
bool AlternateMemoryBestFitHeap::Prefetch(
    const AllocationRequest& request,
    const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
  // Try partially placing the buffer in the alternate space. The time that is
  // overlapped will be used to asynchronously copy the buffer from the
  // default memory to the alternate memory.
  //
  //                      start                 end
  //                      time                  time
  //                      X---------------------X
  // Alternate:                          +------+
  // Default:             +---------------------+
  //                                     ^      ^
  //                                   Copy    Copy
  //                                   Start   Done
  int64 earliest_prefetch_time =
      prev_allocation_in_default_mem.earliest_available_time();
  if (request.earliest_prefetch_time) {
    earliest_prefetch_time =
        std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
  }
  options_.prefetch_interval_picker->Begin(request.use->hlo_use,
                                           earliest_prefetch_time,
                                           request.latest_prefetch_time);
  VLOG(3) << "Trying prefetch picker = "
          << options_.prefetch_interval_picker->ToDebugString();

  // Create an alternate memory interval that starts at the earliest
  // possible position, given by max_prefetch_interval.
  BufferInterval alternate_mem_interval;
  alternate_mem_interval.buffer = request.allocation_value->value();
  alternate_mem_interval.size = request.size;
  // If any of the prefetch intervals couldn't be used due to number of
  // outstanding async copy limit or async copy ordering, set
  // prefetch_failed_due_to_async_copy_.
  prefetch_failed_due_to_async_copy_ = false;
  // While uses might be allowed to have additional outstanding prefetches.
  int64 extra_async_copy_limit =
      request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
          ? options_.while_use_extra_outstanding_prefetch_limit
          : 0;
  while (!options_.prefetch_interval_picker->Done()) {
    alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
    CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
    VLOG(4) << "Trying alternate memory allocation ("
            << alternate_mem_interval.start << ", " << request.end_time << ")";
    // If this additional asynchronous copy would violate the limit, try a
    // different interval.
    if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
                                  request.latest_prefetch_time)) {
      VLOG(4) << "This would violate asynchronous copy ordering.";
      prefetch_failed_due_to_async_copy_ = true;
      continue;
    }
    if (ViolatesMaximumOutstandingAsyncCopies(
            alternate_mem_interval.start, request.latest_prefetch_time,
            /*is_prefetch=*/true, extra_async_copy_limit)) {
      VLOG(4) << "This would violate the outstanding async copy limit.";
      prefetch_failed_due_to_async_copy_ = true;
      continue;
    }

    auto chunk_candidate = FindBestChunkCandidate(
        request, request.preferred_offset, &alternate_mem_interval);
    // Check if we could find a suitable chunk.
    if (chunk_candidate) {
      VLOG(3) << "Move the buffer to alternate memory at "
              << alternate_mem_interval.start
              << ". Offset = " << chunk_candidate->chunk.offset
              << ", size = " << chunk_candidate->chunk.size
              << ", heap_size = " << chunk_candidate->heap_size
              << ", prefetch picker = "
              << options_.prefetch_interval_picker->ToDebugString();
      AddToPendingChunks(alternate_mem_interval, *chunk_candidate);

      AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
                   chunk_candidate->chunk, alternate_mem_interval.start,
                   request.end_time, request.latest_prefetch_time,
                   request.allocation_value->allocation_sequence());

      request.allocation_value->allocation_sequence()->back()->AddUse(
          request.use->hlo_use);
      // A prefetch was found; clear the failure flag set by earlier rejected
      // candidates.
      prefetch_failed_due_to_async_copy_ = false;
      return true;
    }
  }
  return false;
}
| |
// Finds a chunk candidate for the given interval. If preferred_offset is
// provided, only a chunk at exactly that offset is acceptable. Otherwise,
// searches for a chunk that stays free for as many of the value's
// subsequent uses as possible (iterating backwards from the latest use that
// can be held in alternate memory without a copy), so that later segments
// of the same value are more likely to reuse the offset. The search
// temporarily mutates alternate_mem_interval->end and restores it to
// request.end_time before returning.
absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
AlternateMemoryBestFitHeap::FindBestChunkCandidate(
    const AllocationRequest& request, absl::optional<int64> preferred_offset,
    BufferInterval* alternate_mem_interval) const {
  int64 end_time = request.end_time;
  if (!preferred_offset) {
    // First find the earliest use that is the same or later than the end time.
    // Assumes such a use exists (CHECKed below).
    const auto& uses = request.allocation_value->uses();
    auto use_it = uses.begin();
    for (; use_it->time < end_time; ++use_it) {
    }
    CHECK(use_it != uses.end());
    int64 earliest_use = use_it->time;

    // Then find the latest use that can be allocated contiguously without
    // copies.
    const Shape& shape = request.allocation_value->defining_position().shape();
    for (;
         (use_it + 1) != uses.end() &&
         options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
             shape, use_it->time, (use_it + 1)->time);
         ++use_it) {
    }
    CHECK(use_it != uses.end());
    int64 latest_contiguous_use = use_it->time;

    // Find a chunk that's as long living as possible iterating in reverse over
    // the use times.
    for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
      alternate_mem_interval->end = use_it->time;
      ChunkCandidate chunk_candidate =
          FindChunkCandidate(*alternate_mem_interval);
      if (chunk_candidate.heap_size <= available_heap_size()) {
        // Found a chunk that fits; restore the interval end before returning.
        alternate_mem_interval->end = end_time;
        VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
                << ", latest contiguous use = " << latest_contiguous_use
                << ", use with available mem = " << use_it->time
                << ", offset = " << chunk_candidate.chunk.offset;
        return chunk_candidate;
      }
    }
    alternate_mem_interval->end = end_time;
    return absl::nullopt;
  }
  // If a preferred offset is given, try to find an allocation at that offset
  // only.
  alternate_mem_interval->end = end_time;
  ChunkCandidate chunk_candidate =
      FindChunkCandidate(*alternate_mem_interval, *preferred_offset);
  if (chunk_candidate.chunk.offset == *preferred_offset) {
    return chunk_candidate;
  }
  return absl::nullopt;
}
| |
| StatusOr<MemorySpaceAssignment::AsyncCopyStats> |
| MemorySpaceAssignment::CalculateAsyncCopyStats() const { |
| AsyncCopyStats stats; |
| stats.max_outstanding_async_copies = 0; |
| stats.num_prefetches = 0; |
| stats.prefetch_bytes = 0; |
| stats.num_evictions = 0; |
| stats.eviction_bytes = 0; |
| int64 current_copies = 0; |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis, |
| HloDataflowAnalysis::Run(*module_)); |
| for (const HloComputation* computation : |
| module_->MakeNonfusionComputations()) { |
| for (HloInstruction* instruction : computation->instructions()) { |
| if (instruction->opcode() == HloOpcode::kCopyStart) { |
| current_copies++; |
| } else if (instruction->opcode() == HloOpcode::kCopyDone) { |
| current_copies--; |
| int64 size = |
| options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction)); |
| if (instruction->shape().layout().memory_space() == |
| options_.alternate_memory_space) { |
| ++stats.num_prefetches; |
| stats.prefetch_bytes += size; |
| } else { |
| ++stats.num_evictions; |
| stats.eviction_bytes += size; |
| } |
| } |
| stats.max_outstanding_async_copies = |
| std::max(stats.max_outstanding_async_copies, current_copies); |
| } |
| } |
| return stats; |
| } |
| |
| /*static*/ MemorySpaceAssignment::BufferIntervalCompare |
| MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( |
| const MemorySpaceAssignmentCostAnalysis& cost_analysis, |
| MemorySpaceAssignmentCostAnalysis::Cache* cache) { |
| return [&cost_analysis, cache](const BufferInterval& x, |
| const BufferInterval& y) { |
| float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache); |
| float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache); |
| if (x_memory_boundedness != y_memory_boundedness) { |
| return x_memory_boundedness > y_memory_boundedness; |
| } |
| // Tie-break if the memory boundedness is the same. |
| return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()( |
| x, y); |
| }; |
| } |
| |
| namespace { |
| |
| bool LooksLikeAnActivation(const HloInstruction* inst) { |
| for (HloInstruction* user : inst->users()) { |
| switch (user->opcode()) { |
| case HloOpcode::kConvolution: |
| case HloOpcode::kDot: |
| if (user->operand(0) == inst) { |
| return true; |
| } |
| break; |
| case HloOpcode::kGather: |
| if (user->operand(1) == inst) { |
| return true; |
| } |
| break; |
| case HloOpcode::kFusion: |
| for (int i = 0; i < user->operand_count(); ++i) { |
| if (user->operand(i) == inst && |
| LooksLikeAnActivation(user->fused_parameter(i))) { |
| return true; |
| } |
| } |
| break; |
| case HloOpcode::kBitcast: |
| return LooksLikeAnActivation(user); |
| default: |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| bool IsCrossProgramPrefetchCandidate( |
| const HloValue& value, const MemorySpaceAssignment::Options& options) { |
| return value.instruction()->parent() == |
| value.instruction()->GetModule()->entry_computation() && |
| value.instruction()->opcode() == HloOpcode::kParameter && |
| value.index().size() == 1 && value.shape().IsArray() && |
| !value.uses().empty() && |
| options.size_fn(value) <= options.max_size_in_bytes && |
| absl::c_all_of(value.uses(), [&](const HloUse& use) { |
| const HloInstruction* inst = |
| use.instruction->operand(use.operand_number); |
| |
| // Skip the LooksLikeAnActivation test since we're testing the |
| // parent GTE and its children below. |
| if (inst->opcode() == HloOpcode::kBitcast && |
| inst->operand(0)->opcode() == HloOpcode::kGetTupleElement && |
| inst->operand(0)->operand(0)->opcode() == |
| HloOpcode::kParameter) { |
| return true; |
| } |
| |
| return inst->opcode() == HloOpcode::kGetTupleElement && |
| !LooksLikeAnActivation(inst); |
| }); |
| } |
| |
// Scans all HLO values and returns the best cross-program prefetch candidate
// (see IsCrossProgramPrefetchCandidate), or nullopt if there is none. The
// returned interval spans the whole schedule (start 0 to schedule end time).
absl::optional<MemorySpaceAssignment::BufferInterval>
FindCrossProgramPrefetchCandidate(
    const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range,
    const MemorySpaceAssignment::Options& options) {
  std::vector<MemorySpaceAssignment::BufferInterval> candidates;
  for (HloValue* value : alias_analysis.dataflow_analysis().values()) {
    if (IsCrossProgramPrefetchCandidate(*value, options)) {
      MemorySpaceAssignment::BufferInterval interval;
      interval.buffer = value;
      interval.size = options.size_fn(*value);
      interval.start = 0;
      interval.end = hlo_live_range.schedule_end_time();
      interval.need_allocation = true;
      candidates.emplace_back(interval);
    }
  }

  // The buffer_interval_compare ought to do a good job picking the most
  // appropriate buffer to cross program prefetch, but empirically, it makes
  // worse choices than just picking the largest buffer.
  // TODO(b/152421603): Investigate.
  auto size_compare = [](const auto& x, const auto& y) {
    return x.size < y.size;
  };
  // Use the configured comparison only when the default heuristic is
  // requested and one is provided; otherwise pick the largest candidate by
  // size (c_max_element with a less-than comparison returns the maximum).
  auto& compare = options.default_cross_program_prefetch_heuristic &&
                          options.buffer_interval_compare
                      ? *options.buffer_interval_compare
                      : size_compare;

  auto best_candidate = absl::c_max_element(candidates, compare);
  if (best_candidate == candidates.end()) {
    // No candidates were found.
    return absl::nullopt;
  }
  return *best_candidate;
}
| } // namespace |
| |
| /*static*/ StatusOr<std::unique_ptr<PresetAssignments>> |
| MemorySpaceAssignment::Run(HloModule* module, |
| const HloLiveRange& hlo_live_range, |
| const HloAliasAnalysis& alias_analysis, |
| const Options& options) { |
| CHECK(module->has_schedule()); |
| VLOG(3) << "Module before memory space assignment: "; |
| XLA_VLOG_LINES(3, module->ToString()); |
| VLOG(3) << "Schedule: " << module->schedule().ToString(); |
| MemorySpaceAssignment memory_space_assignment(module, options, |
| hlo_live_range); |
| |
| return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range, |
| alias_analysis); |
| } |
| |
// Runs the full memory space assignment pipeline: find the allocation
// sequence, process allocations, schedule async copies, simplify and fix up
// the schedule, export/color the buffers, then verify the result and log
// async copy statistics. Returns the preset assignments on success.
StatusOr<std::unique_ptr<PresetAssignments>>
MemorySpaceAssignment::RunMemorySpaceAssignment(
    const HloLiveRange& hlo_live_range,
    const HloAliasAnalysis& alias_analysis) {
  TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
  TF_RETURN_IF_ERROR(Process());
  ScheduleAsynchronousCopies();
  TF_RETURN_IF_ERROR(SimplifyGraph());
  TF_RETURN_IF_ERROR(FixSchedule());
  TF_RETURN_IF_ERROR(ExportAndColorBuffers());

  VLOG(3) << "Module after memory space assignment: ";
  XLA_VLOG_LINES(3, module_->ToString());
  TF_CHECK_OK(module_->schedule().Verify());
  // Log summary statistics about the inserted asynchronous copies.
  TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
  VLOG(1) << "Maximum number of outstanding async copies: "
          << stats.max_outstanding_async_copies;
  VLOG(1) << "Number of prefetches: " << stats.num_prefetches
          << ", in bytes: " << stats.prefetch_bytes;
  VLOG(1) << "Number of evictions: " << stats.num_evictions
          << ", in bytes: " << stats.eviction_bytes;

  TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace());

  return std::move(preset_assignments_);
}
| |
| Status MemorySpaceAssignment::FindAllocationSequence( |
| const HloLiveRange& hlo_live_range, |
| const HloAliasAnalysis& alias_analysis) { |
| auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>( |
| &allocations_, options_, alias_analysis, hlo_live_range); |
| |
| if (options_.enable_cross_program_prefetch) { |
| absl::optional<AlternateMemoryBestFitHeap::BufferInterval> |
| prefetch_candiate = FindCrossProgramPrefetchCandidate( |
| alias_analysis, hlo_live_range, options_); |
| algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate); |
| } |
| |
| HeapSimulator::Options heap_simulator_options; |
| heap_simulator_options.may_reuse_operand_buffers = false; |
| TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_, |
| module_->schedule(), alias_analysis, |
| options_.size_fn, |
| heap_simulator_options) |
| .status()); |
| return Status::OK(); |
| } |
| |
// Records a use of this allocation's value. The traversal below walks from
// the use's operand through tuples and GetTupleElement(Tuple()) patterns to
// locate the underlying producing instruction.
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
  HloInstruction* operand =
      use.instruction->mutable_operand(use.operand_number);
  // If the use is a tuple, look inside the tuple to find the actual use.
  for (int64 index : use.operand_index) {
    if (operand->opcode() != HloOpcode::kTuple) {
      break;
    }
    operand = operand->mutable_operand(index);
  }

  // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
  std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
  get_simplified_operand = [&](HloInstruction* instruction) {
    // Repeatedly forward GetTupleElement(Tuple(...), i) to the tuple's i-th
    // operand; stop as soon as the chain no longer matches that pattern.
    while (instruction->opcode() == HloOpcode::kGetTupleElement) {
      HloInstruction* operand =
          get_simplified_operand(instruction->mutable_operand(0));
      if (operand->opcode() == HloOpcode::kTuple) {
        instruction = operand->mutable_operand(instruction->tuple_index());
      } else {
        return instruction;
      }
    }
    return instruction;
  };
  operand = get_simplified_operand(operand);
  // NOTE(review): the simplified `operand` is not consumed below -- only the
  // original HloUse is recorded. Presumably a leftover from a version that
  // tracked bitcasts here; confirm whether the traversal can be removed.

  uses_.push_back(use);
}
| |
| Status MemorySpaceAssignment::Allocation::Process( |
| MemorySpaceAssignment* memory_space_assignment) { |
| HloInstruction* producing_instruction = AddGetTupleElements(); |
| HloComputation* computation = producing_instruction->parent(); |
| for (const HloUse& use : uses_) { |
| Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); |
| HloInstruction* replacement_instruction = producing_instruction; |
| if (operand_shape.IsTuple()) { |
| TF_ASSIGN_OR_RETURN( |
| replacement_instruction, |
| ReplaceTupleWith(producing_instruction, |
| use.instruction->mutable_operand(use.operand_number), |
| use.operand_index)); |
| } else if (operand_shape != producing_instruction->shape()) { |
| VLOG(4) << "Old shape = " << operand_shape.ToString() |
| << ", new shape = " << producing_instruction->shape().ToString() |
| << "; inserting a bitcast."; |
| replacement_instruction = computation->AddInstruction( |
| HloInstruction::CreateBitcast(operand_shape, producing_instruction)); |
| } |
| TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( |
| use.operand_number, replacement_instruction)); |
| } |
| return Status::OK(); |
| } |
| |
// Rebuilds `tuple` so that the element at `shape_index` is replaced by
// `new_instruction` (bridging a shape mismatch with a bitcast), recursing one
// tuple level per call. Returns the newly created tuple instruction.
StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith(
    HloInstruction* new_instruction, HloInstruction* tuple,
    ShapeIndex shape_index) {
  const Shape& tuple_shape = tuple->shape();
  CHECK(tuple->shape().IsTuple())
      << "ReplaceTupleWith was called for a non-tuple. Tuple = "
      << tuple->ToString()
      << ", new_instruction = " << new_instruction->ToString()
      << ", shape_index = " << shape_index.ToString();

  HloComputation* computation = new_instruction->parent();
  std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
  for (int64 i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
    const Shape& subshape = tuple_shape.tuple_shapes(i);
    if (i == shape_index[0]) {
      // If the subshape is still a tuple, recurse and pass a new shape index
      // for the one level deeper.
      if (subshape.IsTuple()) {
        HloInstruction* get_tuple_element = computation->AddInstruction(
            HloInstruction::CreateGetTupleElement(subshape, tuple, i));
        TF_ASSIGN_OR_RETURN(tuple_args[i],
                            ReplaceTupleWith(new_instruction, get_tuple_element,
                                             ShapeIndex(shape_index.begin() + 1,
                                                        shape_index.end())));
      } else {
        if (subshape != new_instruction->shape()) {
          VLOG(4) << "Old shape = " << subshape.ToString()
                  << ", new shape = " << new_instruction->shape().ToString()
                  << "; inserting a bitcast.";
          // Note: rebinds the local `new_instruction` pointer to the bitcast
          // so the element below refers to the shape-corrected value.
          new_instruction = computation->AddInstruction(
              HloInstruction::CreateBitcast(subshape, new_instruction));
        }
        tuple_args[i] = new_instruction;
      }
    } else {
      // Elements off the replacement path are forwarded unchanged via
      // get-tuple-element.
      HloInstruction* get_tuple_element = computation->AddInstruction(
          HloInstruction::CreateGetTupleElement(subshape, tuple, i));
      tuple_args[i] = get_tuple_element;
    }
  }
  return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
}
| |
| HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() { |
| HloInstruction* producing_instruction = defining_position().instruction; |
| CHECK_NE(producing_instruction, nullptr); |
| |
| Shape shape = defining_position().shape(); |
| CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = " |
| << shape.ToString() |
| << " position = " << defining_position().shape(); |
| HloComputation* computation = producing_instruction->parent(); |
| |
| // If the instruction we're processing is a tuple, we (recursively) create |
| // kGetTupleElement instructions and copy that value. Asynchronous copies only |
| // support array types. |
| for (int64 index : defining_position().index) { |
| producing_instruction = |
| computation->AddInstruction(HloInstruction::CreateGetTupleElement( |
| producing_instruction->shape().tuple_shapes(index), |
| producing_instruction, index)); |
| } |
| return producing_instruction; |
| } |
| |
| std::string MemorySpaceAssignment::Allocation::ToString() const { |
| return absl::StrCat("Allocation in ", |
| memory_space_ == MemorySpace::kDefault ? "def" : "alt", |
| " defined at ", defining_position_.ToString()); |
| } |
| |
| std::string MemorySpaceAssignment::CopyAllocation::ToString() const { |
| return absl::StrCat("Copy Allocation in ", |
| memory_space_ == MemorySpace::kDefault ? "def" : "alt", |
| " from ", prev_allocation_.ToString()); |
| } |
| |
// Inserts the CopyStart/CopyDone pair implementing this asynchronous copy and
// rewires all recorded uses to consume the copy's result.
Status MemorySpaceAssignment::CopyAllocation::Process(
    MemorySpaceAssignment* memory_space_assignment) {
  // Copy allocations need to insert asynchronous copy nodes.
  Shape shape = defining_position().shape();
  HloInstruction* producing_instruction = AddGetTupleElements();
  HloComputation* computation = producing_instruction->parent();
  // CopyStart produces a tuple of (destination, source, scalar context);
  // CopyDone extracts the destination once the copy completes.
  copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
      ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
      HloOpcode::kCopyStart, producing_instruction));
  copy_done_ = computation->AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
  // Update the allocation position with the copy done instruction so that if
  // there are further copies from it, it can find the correct position.
  defining_position_ = HloPosition{copy_done_, {}};

  // Replace all the uses with the new copy instruction.
  for (HloUse use : uses_) {
    // If the operand is a tuple, we need to descend to the actual instruction
    // we want to replace.
    HloInstruction* replacement_instruction;
    Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
    if (operand_shape.IsTuple()) {
      TF_ASSIGN_OR_RETURN(
          replacement_instruction,
          ReplaceTupleWith(copy_done_,
                           use.instruction->mutable_operand(use.operand_number),
                           use.operand_index));
    } else if (operand_shape != copy_done_->shape()) {
      VLOG(4) << "Old shape = " << operand_shape.ToString()
              << ", new shape = " << copy_done_->shape().ToString()
              << "; inserting a bitcast.";
      replacement_instruction = computation->AddInstruction(
          HloInstruction::CreateBitcast(operand_shape, copy_done_));
    } else {
      replacement_instruction = copy_done_;
    }
    TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
        use.operand_number, replacement_instruction));
  }

  return Status::OK();
}
| |
| Status MemorySpaceAssignment::Process() { |
| VLOG(1) << "Processing assigned buffers..."; |
| // Insert CopyStart/CopyDone pairs. |
| for (auto& allocation : allocations_) { |
| VLOG(3) << "Processing: " << allocation->ToString(); |
| TF_RETURN_IF_ERROR(allocation->Process(this)); |
| // Add the offset and size of the allocation in the alternate memory to |
| // the output map. |
| if (allocation->memory_space() == MemorySpace::kAlternate) { |
| alternate_memory_assignments_.emplace_back( |
| allocation->defining_position(), allocation->chunk()); |
| alternate_memory_size_ = |
| std::max(alternate_memory_size_, allocation->chunk().chunk_end()); |
| } |
| } |
| return Status::OK(); |
| } |
| |
// Deduplicates the recorded alternate-memory assignments into the preset
// assignments output and colors every aliased position's layout with the
// alternate memory space.
Status MemorySpaceAssignment::ExportAndColorBuffers() {
  VLOG(1) << "Exporting buffers...";
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
  absl::flat_hash_map<int64, int64> seen_buffer_offsets;
  VLOG(3) << "Exported alternate memory allocations:";
  for (const auto& position_and_chunk : alternate_memory_assignments_) {
    const HloPosition& defining_position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
        defining_position.instruction, defining_position.index);
    // Export each aliased buffer only once; every position mapping to the
    // same buffer must have been assigned the same offset.
    auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
    if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
      CHECK_EQ(chunk.offset, seen_buffer_offset_it->second)
          << "Mismatch in offset for positions that map to the same value: "
          << buffer.ToString() << ", pos: " << defining_position.ToString();
    } else {
      VLOG(3) << " [" << chunk.offset << ", " << chunk.size
              << "] : " << defining_position.ToString() << " ("
              << buffer.ToString() << ")";
      preset_assignments_->add_chunk(defining_position, chunk);
      seen_buffer_offsets[buffer.id()] = chunk.offset;
    }
  }

  if (!preset_assignments_->chunks().empty()) {
    // Record the total arena size required in the alternate memory space.
    preset_assignments_
        ->assignment_information_for_space(options_.alternate_memory_space)
        ->size = alternate_memory_size_;
  }

  VLOG(3) << "Exported alternate memory sizes:";
  for (auto& pair : preset_assignments_->assignment_informations()) {
    VLOG(3) << " space: " << pair.first << ", size: " << pair.second.size;
  }

  VLOG(1) << "Coloring buffers...";
  // Color the pending positions and all of their aliased buffers.
  for (const auto& defining_position_and_chunk :
       preset_assignments_->chunks()) {
    const HloPosition& defining_position = defining_position_and_chunk.first;
    for (auto& buffer : alias_analysis->ComputeBuffersAt(
             defining_position.instruction, defining_position.index)) {
      for (auto& value : buffer->values()) {
        for (auto& position : value->positions()) {
          VLOG(4) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          // Stamp the alternate memory space into the layout so later passes
          // see where this value lives.
          shape->mutable_layout()->set_memory_space(
              options_.alternate_memory_space);
        }
      }
    }
  }
  return Status::OK();
}
| |
| void MemorySpaceAssignment::RemoveAssignmentForInstruction( |
| const HloInstruction* instruction) { |
| for (auto& position_and_chunk : alternate_memory_assignments_) { |
| const HloPosition& position = position_and_chunk.first; |
| if (position.instruction == instruction) { |
| VLOG(3) << "Removing instruction from alternate memory assignments."; |
| // Swap the removed position and chunk with the back and pop back. |
| position_and_chunk = alternate_memory_assignments_.back(); |
| alternate_memory_assignments_.pop_back(); |
| break; |
| } |
| } |
| } |
| |
// Performs limited cleanup on the rewritten module: drops control
// dependencies, removes dead instructions, and forwards trivial
// GetTupleElement(Tuple(...)) / Tuple(GetTupleElement(x), ...) patterns.
Status MemorySpaceAssignment::SimplifyGraph() {
  VLOG(1) << "Simplifying graph...";
  for (HloComputation* computation : module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not simplifying " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    // Drop control dependencies. Since the computation is already scheduled, we
    // don't need control dependencies anymore, and having control
    // predecessors/successors prevents us from removing instructions without
    // users (HloComputation::IsSafelyRemovable returns false if there are
    // control dependencies).
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
    }
    // We perform limited DCE and forward the tuple operand in patterns like
    // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
    // assignment is ran late in compilation (after DCE and arithmetic
    // simplification passes) and we don't want to generate redundant code. Run
    // to fixed point.
    bool computation_modified = true;
    while (computation_modified) {
      computation_modified = false;
      VLOG(4) << "Running simplify graph loop over " << computation->name();
      for (HloInstruction* instruction :
           computation->MakeInstructionPostOrder()) {
        if (computation->IsSafelyRemovable(instruction) &&
            instruction->user_count() == 0 && !instruction->HasSideEffect() &&
            instruction != computation->root_instruction() &&
            instruction->opcode() != HloOpcode::kCopyStart &&
            instruction->opcode() != HloOpcode::kCopyDone) {
          VLOG(4) << "Instruction removed: " << instruction->ToString();
          // Ensure the alternate memory assignments don't contain a reference
          // to the removed instruction.
          RemoveAssignmentForInstruction(instruction);
          // Instead of deleting the instruction from the schedule, replace it
          // with a nullptr. This is needed because FixSchedule relies on the
          // logical time that is the index into flattened_instructions_ for
          // scheduling asynchronous copies.
          auto instruction_it =
              absl::c_find(flattened_instructions_, instruction);
          if (instruction_it != flattened_instructions_.end()) {
            *instruction_it = nullptr;
          }
          TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
          computation_modified = true;
        } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
          // Forward GetTupleElement(Tuple(...), i) to the tuple's i-th
          // operand.
          HloInstruction* operand = instruction->mutable_operand(0);
          if (operand->opcode() == HloOpcode::kTuple) {
            HloInstruction* forwarded_instruction =
                operand->mutable_operand(instruction->tuple_index());
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        } else if (instruction->opcode() == HloOpcode::kTuple) {
          // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern
          // with x.
          bool can_replace =
              instruction->operand_count() > 0 &&
              instruction->operand(0)->opcode() ==
                  HloOpcode::kGetTupleElement &&
              instruction->operand(0)
                      ->operand(0)
                      ->shape()
                      .tuple_shapes_size() == instruction->operand_count();
          // Each operand must be GetTupleElement(x, i) at matching index i of
          // the same source x for the whole tuple to collapse to x.
          for (int operand_number = 0;
               operand_number < instruction->operand_count();
               ++operand_number) {
            const HloInstruction* operand =
                instruction->operand(operand_number);
            if (operand->opcode() != HloOpcode::kGetTupleElement ||
                operand->tuple_index() != operand_number ||
                operand->operand(0) != instruction->operand(0)->operand(0)) {
              can_replace = false;
              break;
            }
          }
          if (can_replace) {
            HloInstruction* forwarded_instruction =
                instruction->mutable_operand(0)->mutable_operand(0);
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        }
      }
    }
  }

  return Status::OK();
}
| |
| void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted( |
| HloInstruction* new_instruction, HloInstructionSequence* new_sequence, |
| absl::flat_hash_set<HloInstruction*>* inserted_instructions) const { |
| if (inserted_instructions->contains(new_instruction)) { |
| return; |
| } |
| for (HloInstruction* operand : new_instruction->operands()) { |
| // CopyStart/CopyDone dependencies should always be already inserted; it is |
| // a red flag when they haven't already been inserted. |
| CHECK((operand->opcode() != HloOpcode::kCopyStart && |
| operand->opcode() != HloOpcode::kCopyDone) || |
| inserted_instructions->contains(operand)) |
| << "Inserted instruction " << new_instruction->ToString() |
| << " has un-inserted dependency: " << operand->ToString(); |
| EnsureInstructionAndOperandsInserted(operand, new_sequence, |
| inserted_instructions); |
| } |
| VLOG(4) << "inserting: " << new_instruction->ToShortString(); |
| new_sequence->push_back(new_instruction); |
| inserted_instructions->insert(new_instruction); |
| } |
| |
| void MemorySpaceAssignment::ScheduleAsynchronousCopies() { |
| VLOG(1) << "Scheduling asynchronous copies..."; |
| for (MemorySpace memory_space : |
| {MemorySpace::kDefault, MemorySpace::kAlternate}) { |
| std::vector<CopyAllocation*> copy_allocations; |
| for (auto& allocation : allocations_) { |
| if (allocation->is_copy_allocation()) { |
| auto copy_allocation = static_cast<CopyAllocation*>(allocation.get()); |
| if (copy_allocation->memory_space() == memory_space) { |
| copy_allocations.push_back(copy_allocation); |
| } |
| } |
| } |
| |
| absl::c_stable_sort( |
| copy_allocations, [](CopyAllocation* first, CopyAllocation* second) { |
| return std::forward_as_tuple(first->copy_done_schedule_before(), |
| first->copy_start_schedule_after()) < |
| std::forward_as_tuple(second->copy_done_schedule_before(), |
| second->copy_start_schedule_after()); |
| }); |
| |
| CopyAllocation* prev_copy_allocation = nullptr; |
| for (CopyAllocation* copy_allocation : copy_allocations) { |
| // If the copy start doesn't happen to be scheduled at the correct |
| // computation, delay it until the correct computation starts. |
| int64 copy_start_schedule_after = |
| copy_allocation->copy_start_schedule_after(); |
| // Accessing flattened_instructions_ here without checking if it is |
| // nullptr is safe because this method is called before SimplifyGraph. |
| while (copy_allocation->defining_position().instruction->parent() != |
| flattened_instructions_[copy_start_schedule_after]->parent()) { |
| VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to " |
| << (copy_start_schedule_after + 1) << ") for " |
| << copy_allocation->copy_start()->ToString() |
| << " because it is not in the correct computation."; |
| copy_allocation->set_copy_start_schedule_after( |
| ++copy_start_schedule_after); |
| } |
| |
| schedule_after_[copy_allocation->copy_start_schedule_after()].push_back( |
| copy_allocation->copy_start()); |
| schedule_before_[copy_allocation->copy_done_schedule_before()].push_back( |
| copy_allocation->copy_done()); |
| prev_copy_allocation = copy_allocation; |
| } |
| } |
| } |
| |
// Rebuilds each scheduled computation's instruction sequence, interleaving
// the CopyStart/CopyDone instructions recorded by ScheduleAsynchronousCopies
// at their assigned logical times.
Status MemorySpaceAssignment::FixSchedule() {
  VLOG(1) << "Fixing schedule...";
  CHECK(module_->has_schedule());
  HloSchedule& schedule = module_->schedule();
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not scheduling " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    CHECK(schedule.is_computation_scheduled(computation));
    HloInstructionSequence new_sequence;

    absl::flat_hash_set<HloInstruction*> inserted_instructions;

    VLOG(4) << "Scheduling: " << computation->ToString();

    // Walk the module-wide flattened order; the logical time of an
    // instruction is its index into flattened_instructions_. At each time,
    // first emit instructions registered to go before it, then the
    // instruction itself, then those registered to go after it.
    for (int64 instruction_index = 0;
         instruction_index < flattened_instructions_.size();
         ++instruction_index) {
      auto insts_before_iter = schedule_before_.find(instruction_index);
      if (insts_before_iter != schedule_before_.end()) {
        for (HloInstruction* new_instruction : insts_before_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "before " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
      HloInstruction* instruction = flattened_instructions_[instruction_index];
      // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
      // it was deleted) and not previously inserted. Also bitcasts and tuples
      // are treated specially and only inserted as a result of operand
      // dependencies.
      if (instruction != nullptr &&
          !inserted_instructions.contains(instruction) &&
          instruction->parent() == computation &&
          instruction->opcode() != HloOpcode::kBitcast &&
          instruction->opcode() != HloOpcode::kTuple) {
        VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
        EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                             &inserted_instructions);
      }
      auto insts_after_iter = schedule_after_.find(instruction_index);
      if (insts_after_iter != schedule_after_.end()) {
        for (HloInstruction* new_instruction : insts_after_iter->second) {
          if (new_instruction->parent() == computation) {
            VLOG(4) << "after " << instruction_index << ": "
                    << new_instruction->name();
            EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                                 &inserted_instructions);
          }
        }
      }
    }
    // For rare cases where the original sequence is empty, ensure the root
    // instruction and its dependencies are scheduled.
    EnsureInstructionAndOperandsInserted(computation->root_instruction(),
                                         &new_sequence, &inserted_instructions);
    CHECK_EQ(new_sequence.size(), computation->instruction_count())
        << "New sequence for computation " << computation->name() << " has "
        << new_sequence.size() << " instructions, expects "
        << computation->instruction_count() << ".";
    schedule.set_sequence(computation, new_sequence);
  }

  return Status::OK();
}
| |
// Re-runs alias and live-range analysis on the rewritten module, verifies
// that no two alternate-memory chunks overlap in both space and live time,
// checks that every asynchronous copy changes memory spaces, and exports a
// heap simulator trace for the alternate memory space.
Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
  VLOG(1) << "Verifying...";
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module_));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(module_->schedule(), *alias_analysis,
                                        module_->entry_computation()));

  BufferIntervalTree interval_tree;
  absl::flat_hash_set<int64> seen_buffers;
  // The key for events is: time, is_free, value_id. This is so that the events
  // are sorted first by time, then within the same time, allocations are sorted
  // earlier than frees, and finally the value id as a tie breaker.
  std::map<std::tuple<int64, bool, int64>,
           std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
      events;

  // Records an ALLOC/FREE event pair for `value` and returns an error if
  // `chunk` overlaps -- in both space and time -- any chunk recorded earlier.
  auto add_allocation_and_verify = [&](int64 start_time, int64 end_time,
                                       const Chunk& chunk,
                                       const HloValue* value) {
    events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
        std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
    events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
        std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);

    // Get the chunks overlapping in time and search if they overlap in space
    // as well.
    // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
    // really should check against end_time (inclusive) for cases where the
    // operand can't share buffer with user (see
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
    for (const Chunk& overlapping_chunk :
         interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
      if (chunk.OverlapsWith(overlapping_chunk)) {
        return InternalError(
            ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
             " off: %d size: %d"),
            value->ToShortString(), start_time, end_time, chunk.offset,
            chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
      }
    }
    interval_tree.Add(start_time, end_time - 1, chunk);
    return Status::OK();
  };

  // Go through all instructions in the module to ensure CopyStart/CopyDone
  // instructions copy between alternate memory and default memory.
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopyStart) {
        // CopyStart's shape is a tuple whose element 0 is the destination
        // and element 1 is the source.
        int64 from_memory_space =
            ShapeUtil::GetSubshape(instruction->shape(), {1})
                .layout()
                .memory_space();
        int64 to_memory_space =
            ShapeUtil::GetSubshape(instruction->shape(), {0})
                .layout()
                .memory_space();
        CHECK_NE(from_memory_space, to_memory_space)
            << "Asynchronous copy to the same memory space: "
            << instruction->ToString();
      }
    }
  }

  for (const auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer =
        alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
    CHECK(!seen_buffers.contains(buffer.id()))
        << "Multiple preset assignments for the same buffer: "
        << buffer.ToString() << ", pos: " << position.ToString()
        << ", off: " << chunk.offset << ", size: " << chunk.size;
    seen_buffers.insert(buffer.id());

    for (const HloValue* value : buffer.values()) {
      const HloLiveRange::TimeBound& time_bound =
          hlo_live_range->buffer_live_ranges().at(value);
      // Find the latest-scheduled use of this value.
      const HloInstruction* last_use_instruction = nullptr;
      int64 last_use_time = time_bound.start;
      for (const HloUse& use : value->uses()) {
        int64 use_time =
            hlo_live_range->instruction_schedule().at(use.instruction);
        if (use_time > last_use_time) {
          last_use_time = use_time;
          last_use_instruction = use.instruction;
        }
      }

      if (last_use_instruction &&
          last_use_instruction->opcode() == HloOpcode::kConditional) {
        // Special case when verifying conditional: we internally split the use
        // of alternate memory in conditionals, so fish them out from the
        // conditionals.
        VLOG(3) << " Splitting conditional buffer: " << buffer.ToString()
                << " value: " << value->ToShortString() << ": ("
                << time_bound.start << ", " << time_bound.end
                << ") off: " << chunk.offset << ", size: " << chunk.size;
        int64 earliest_computation_start_time = time_bound.end;
        for (const HloComputation* called_computation :
             last_use_instruction->called_computations()) {
          earliest_computation_start_time =
              std::min(earliest_computation_start_time,
                       hlo_live_range->computation_span_times()
                           .at(called_computation)
                           .start);
          int64 parameter_time = -1;
          // Note: intentionally shadows the outer last_use_time; this one is
          // scoped to the called computation only.
          int64 last_use_time = -1;
          for (const HloPosition& position : value->positions()) {
            if (position.instruction->opcode() == HloOpcode::kParameter &&
                position.instruction->parent() == called_computation) {
              parameter_time = hlo_live_range->instruction_schedule().at(
                  position.instruction);
              break;
            }
          }
          for (const HloUse& use : value->uses()) {
            if (use.instruction->parent() == called_computation) {
              last_use_time = std::max(
                  last_use_time,
                  hlo_live_range->instruction_schedule().at(use.instruction));
            }
          }
          if (last_use_time != -1) {
            CHECK_NE(parameter_time, -1);
            VLOG(3) << " computation: " << called_computation->name() << ": ("
                    << parameter_time << ", " << last_use_time << ")";
            TF_RETURN_IF_ERROR(add_allocation_and_verify(
                parameter_time, last_use_time, chunk, value));
          }
        }
        VLOG(3) << " from beginning until first computation: ("
                << time_bound.start << ", "
                << (earliest_computation_start_time - 1) << ")";
        TF_RETURN_IF_ERROR(add_allocation_and_verify(
            time_bound.start, earliest_computation_start_time - 1, chunk,
            value));
      } else {
        VLOG(3) << " buffer: " << buffer.ToString()
                << " value: " << value->ToShortString() << ": ("
                << time_bound.start << ", " << time_bound.end
                << ") off: " << chunk.offset << ", size: " << chunk.size;
        TF_RETURN_IF_ERROR(add_allocation_and_verify(
            time_bound.start, time_bound.end, chunk, value));
      }
    }
  }

  // Replay the (time-ordered) events to build the heap simulator trace and
  // compute peak usage ignoring fragmentation.
  HeapSimulatorTrace* heap_trace =
      &preset_assignments_
           ->assignment_information_for_space(options_.alternate_memory_space)
           ->heap_simulator_trace;
  int64 memory_usage = 0;
  int64 max_memory_usage = 0;
  for (const auto& event : events) {
    int64 time;
    bool is_free;
    int64 buffer_id;
    std::tie(time, is_free, buffer_id) = event.first;
    const HloValue* value;
    Chunk chunk;
    HeapSimulatorTrace::Event::Kind kind;
    std::tie(value, chunk, kind) = event.second;
    HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
    heap_trace_event->set_kind(kind);
    heap_trace_event->set_buffer_id(buffer_id);
    heap_trace_event->set_instruction_name(value->instruction()->name());
    heap_trace_event->set_computation_name(
        value->instruction()->parent()->name());

    if (kind == HeapSimulatorTrace::Event::ALLOC) {
      memory_usage += chunk.size;
    } else {
      CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
      memory_usage -= chunk.size;
    }
    max_memory_usage = std::max(max_memory_usage, memory_usage);
    VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time;
  }
  VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;

  return Status::OK();
}
| |
| } // namespace xla |