| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/memory_space_assignment.h" |
| |
| namespace xla { |
| |
| namespace { |
| // Define a dummy chunk for chunks that will be allocated in the default memory |
| // space and for keeping track of number of asynchronous copies. |
| const HeapSimulator::Chunk kDummyChunk{-1, -1}; |
| } // namespace |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( |
| const HloInstruction& instruction) const { |
| return std::max( |
| cost_analysis_.flop_count(instruction) / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), |
| cost_analysis_.transcendental_count(instruction) / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey)); |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( |
| const HloInstruction& instruction, |
| absl::optional<int64> operand_in_alternate_mem, |
| bool output_in_alternate_mem) const { |
| float bytes_accessed = cost_analysis_.bytes_accessed(instruction); |
| VLOG(4) << " bytes_accessed = " << bytes_accessed; |
| float elapsed_due_to_bytes = |
| bytes_accessed / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); |
| if (operand_in_alternate_mem) { |
| // Estimate the elapsed time due to the operand being in the alternate |
| // memory space. |
| float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed( |
| instruction, *operand_in_alternate_mem); |
| float elapsed_due_to_operand_bytes = |
| operand_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_; |
| bytes_accessed -= operand_bytes_accessed; |
| elapsed_due_to_bytes = |
| elapsed_due_to_operand_bytes + |
| bytes_accessed / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); |
| } |
| if (output_in_alternate_mem) { |
| // Estimate the elapsed time due to the output being in the alternate memory |
| // space. |
| float output_bytes_accessed = |
| cost_analysis_.output_bytes_accessed(instruction); |
| float elapsed_due_to_output_bytes = |
| output_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_; |
| bytes_accessed -= output_bytes_accessed; |
| elapsed_due_to_bytes = |
| elapsed_due_to_output_bytes + |
| bytes_accessed / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey); |
| } |
| return elapsed_due_to_bytes; |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed( |
| const HloInstruction& instruction, |
| absl::optional<int64> operand_in_alternate_mem, |
| bool output_in_alternate_mem) const { |
| return std::max( |
| GetInstructionElapsedDueToCompute(instruction), |
| GetInstructionElapsedDueToMemory(instruction, operand_in_alternate_mem, |
| output_in_alternate_mem)); |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed( |
| const Shape& shape) const { |
| int64 size_in_bytes = cost_analysis_.GetShapeSize(shape); |
| return static_cast<float>(size_in_bytes) / |
| async_copy_bandwidth_bytes_per_second_; |
| } |
| |
| bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| return end_time - start_time <= max_overlap_count_; |
| } |
| |
| void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use, |
| int64 start_time, |
| int64 end_time) { |
| end_time_ = end_time; |
| current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_); |
| } |
| |
| int64 InstructionCountPrefetchIntervalPicker::Next() { |
| CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " |
| "Done() is false"; |
| return current_prefetch_time_++; |
| } |
| |
| bool InstructionCountPrefetchIntervalPicker::Done() const { |
| return end_time_ - current_prefetch_time_ <= min_overlap_count_; |
| } |
| |
| void CostAnalysisPrefetchIntervalPicker::SetInstructionSchedule( |
| const absl::flat_hash_map<const HloInstruction*, int64>& |
| instruction_schedule) { |
| // First create a vector of elapsed times of HLO instructions. |
| std::vector<float> instructions_elapsed_time(instruction_schedule.size(), |
| 0.0); |
| |
| for (const auto& instruction_and_logical_time : instruction_schedule) { |
| float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( |
| *instruction_and_logical_time.first); |
| int64 logical_time = instruction_and_logical_time.second; |
| if (logical_time >= instructions_elapsed_time.size()) { |
| instructions_elapsed_time.resize(logical_time + 1, 0.0); |
| } |
| instructions_elapsed_time[logical_time] = elapsed_time; |
| VLOG(4) << "Elapsed time in seconds [" << logical_time |
| << "] = " << elapsed_time; |
| } |
| // As an optimization, create a cumulative sum vector of elapsed time. |
| float cumsum = 0.0; |
| for (float elapsed_time : instructions_elapsed_time) { |
| cumsum += elapsed_time; |
| elapsed_time_cumsum_.push_back(cumsum); |
| } |
| } |
| |
| bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| // Even though this method returns if we allow the buffer in alternate memory |
| // _without_ asynchronous copies, calculate how long it would have taken to |
| // copy it and compare it to the elapsed time in the logical interval. |
| float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); |
| float logical_interval_elapsed = |
| GetLogicalIntervalElapsed(start_time, end_time); |
| return max_async_copy_to_overlap_ratio_ * async_copy_elapsed > |
| logical_interval_elapsed; |
| } |
| |
| void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, |
| int64 start_time, |
| int64 end_time) { |
| const Shape& shape = use.instruction->operand(use.operand_number)->shape(); |
| // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_. |
| async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape); |
| // Estimate the time we would save by having this op in alternate memory. |
| float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); |
| float elapsed_time_in_alternate_mem = cost_analysis_.GetInstructionElapsed( |
| *use.instruction, use.operand_number); |
| inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; |
| end_logical_time_ = end_time; |
| // Find the earliest time we're allowed to start prefetching. |
| for (current_logical_prefetch_time_ = start_time; |
| max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ < |
| GetLogicalIntervalElapsed(current_logical_prefetch_time_, |
| end_logical_time_); |
| ++current_logical_prefetch_time_) { |
| } |
| } |
| |
| int64 CostAnalysisPrefetchIntervalPicker::Next() { |
| CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " |
| "Done() is false"; |
| return current_logical_prefetch_time_++; |
| } |
| |
| bool CostAnalysisPrefetchIntervalPicker::Done() const { |
| float logical_interval_elapsed = GetLogicalIntervalElapsed( |
| current_logical_prefetch_time_, end_logical_time_); |
| return min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ - |
| inst_elapsed_reduction_ > |
| logical_interval_elapsed; |
| } |
| |
// Returns the total estimated elapsed seconds of the instructions inside the
// logical interval, using the cumulative sums precomputed in
// SetInstructionSchedule().
// NOTE(review): the indexing assumes end_time >= start_time + 1 and that both
// times are within the schedule bounds; callers appear to guarantee this via
// the picker's Begin()/Done() logic -- confirm before changing.
float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
    int64 start_time, int64 end_time) const {
  return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time];
}
| |
| std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*> |
| AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( |
| const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { |
| std::vector<const BufferInterval*> colocated_intervals; |
| std::vector<const BufferInterval*> worklist = {&interval}; |
| while (!worklist.empty()) { |
| const BufferInterval* item = worklist.back(); |
| worklist.pop_back(); |
| colocated_intervals.push_back(item); |
| for (const HloValue* buffer_colocated : item->colocations) { |
| worklist.push_back(&buffer_intervals_.at(buffer_colocated)); |
| } |
| } |
| |
| absl::c_sort(colocated_intervals, [&](const BufferInterval* x, |
| const BufferInterval* y) { |
| return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end); |
| }); |
| return colocated_intervals; |
| } |
| |
// Entry point of the heap algorithm. Walks the buffers in sorted order and,
// for each buffer (together with its colocated buffers), tries to build an
// allocation sequence that uses the alternate memory; on any failure the
// whole colocation group falls back to default memory and all tentative
// chunks/copies are discarded.
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
          << max_size_in_bytes_;

  // Parameters and outputs must stay in default memory; record that before
  // any allocation decisions are made.
  AddInputAndOutputRequiredAssignments();
  prefetch_interval_picker_->SetInstructionSchedule(
      hlo_live_range_.instruction_schedule());

  for (auto& interval : sorted_buffer_intervals) {
    if (!interval.need_allocation) {
      continue;
    }

    // Skip if we have already allocated for this buffer.
    if (allocation_map_->contains(interval.buffer)) {
      continue;
    }

    // If the buffer is a tuple, don't use this algorithm for now. The buffers
    // that are pointed to by the tuple will still use this algorithm. Because
    // tuples are cheap to place in the alternate memory (they are just
    // pointers) we don't need to use prefetch/evict logic.
    if (interval.buffer->shape().IsTuple()) {
      VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
              << " in default mem because it is a tuple.";
      continue;
    }

    auto colocated_intervals = GetSortedColocatedIntervals(interval);
    bool keep_in_default_memory = false;
    for (const BufferInterval* colocated_interval : colocated_intervals) {
      const HloValue* value = colocated_interval->buffer;
      // If any of the colocated values are phi buffers, we keep them in the
      // default memory for now.
      if (value->is_phi()) {
        keep_in_default_memory = true;
        VLOG(4) << "Keeping value " << value->ToShortString()
                << " because it contains a phi node.";
        break;
      }
    }

    // At this point, none of the colocated buffers contain any phi buffers.
    for (const BufferInterval* colocated_interval : colocated_intervals) {
      if (keep_in_default_memory) {
        break;
      }
      const HloValue* value = colocated_interval->buffer;
      const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
      MemorySpaceAssignment::AllocationSequence* allocation_sequence =
          &(*allocation_map_)[value];
      int64 definition_time =
          instruction_schedule.at(value->defining_instruction());
      // Sort the uses by the use time.
      std::vector<HloUse> uses = value->uses();
      absl::c_sort(uses, [&](HloUse use1, HloUse use2) {
        return instruction_schedule.at(use1.instruction) <
               instruction_schedule.at(use2.instruction);
      });
      // Iterate over the uses, allocating each segment from the current
      // definition point to the use.
      for (HloUse use : uses) {
        int64 use_time = instruction_schedule.at(use.instruction);
        int64 last_use_time = instruction_schedule.at(uses.back().instruction);

        // Bitcasts don't define buffers and don't directly consume buffers.
        // Skip allocating buffers for bitcast uses. The uses that feed from
        // bitcasts will be handled specially.
        if (use.instruction->opcode() != HloOpcode::kBitcast) {
          if (!FindAllocation(definition_time, use_time, last_use_time,
                              value->defining_position(), use, value,
                              colocated_interval->size, allocation_sequence)) {
            // If the allocation finding failed (e.g., due to running out of
            // asynchronous copies), then fall back to allocating the buffer
            // entirely in the default memory.
            pending_chunks_.clear();
            pending_async_copies_.clear();
            allocation_sequence->clear();
            keep_in_default_memory = true;
            break;
          }

          // If there are multiple uses, they can try using the memory
          // allocation already at the alternate memory.
          definition_time = use_time;
        }
      }
    }

    // All segments for this colocation group succeeded (or the group was
    // abandoned above); make the tentative chunks permanent.
    CommitPendingChunks();
  }

  if (VLOG_IS_ON(3)) {
    for (const auto& alloc_pair : *allocation_map_) {
      VLOG(3) << "Allocation for " << alloc_pair.first->ToShortString();
      for (const auto& alloc : alloc_pair.second) {
        std::string addr_str = ": default";
        if (alloc->memory_space() == MemorySpace::kAlternate) {
          addr_str = absl::StrCat(": alt ", alloc->chunk().offset);
        }

        VLOG(3) << " " << alloc->start_time() << "-" << alloc->end_time()
                << addr_str << ", " << alloc->uses().size() << " uses";
      }
    }
  }

  return result_;
}
| |
| void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { |
| // Go through the parameters and outputs and pin them to default memory by |
| // adding a required assignment. |
| // TODO(berkin): If these values are already marked alternate memory, use |
| // those instead. |
| const HloDataflowAnalysis& dataflow_analysis = |
| alias_analysis_.dataflow_analysis(); |
| const HloModule& module = dataflow_analysis.module(); |
| const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); |
| HloComputation* entry_computation = module.entry_computation(); |
| for (HloInstruction* parameter_instruction : |
| entry_computation->parameter_instructions()) { |
| int64 parameter_instruction_time = |
| instruction_schedule.at(parameter_instruction); |
| ShapeUtil::ForEachSubshape( |
| parameter_instruction->shape(), |
| [&](const Shape& /*subshape*/, const ShapeIndex& index) { |
| for (const HloValue* value : |
| dataflow_analysis.GetValueSet(parameter_instruction, index) |
| .values()) { |
| VLOG(3) << "Adding required assignment for parameter value = " |
| << value->ToShortString() |
| << " time = " << parameter_instruction_time; |
| required_assignments_[value].push_back( |
| {/*memory_space=*/MemorySpace::kDefault, |
| /*time=*/parameter_instruction_time}); |
| } |
| }); |
| } |
| HloInstruction* root_instruction = entry_computation->root_instruction(); |
| int64 root_instruction_time = instruction_schedule.at(root_instruction); |
| ShapeUtil::ForEachSubshape( |
| root_instruction->shape(), |
| [&](const Shape& /*subshape*/, const ShapeIndex& index) { |
| for (const HloValue* value : |
| dataflow_analysis.GetValueSet(root_instruction, index).values()) { |
| VLOG(3) << "Adding required assignment for output value = " |
| << value->ToShortString() |
| << " time = " << root_instruction_time; |
| required_assignments_[value].push_back( |
| {/*memory_space=*/MemorySpace::kDefault, |
| /*time=*/root_instruction_time}); |
| } |
| }); |
| } |
| |
| void AlternateMemoryBestFitHeap::CommitPendingChunks() { |
| for (auto interval_and_chunk : pending_chunks_) { |
| VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-" |
| << interval_and_chunk.first.end << " : [" |
| << interval_and_chunk.second.chunk.offset << ", " |
| << interval_and_chunk.second.chunk.size << "]"; |
| CommitChunk(interval_and_chunk.first, interval_and_chunk.second); |
| } |
| pending_chunks_.clear(); |
| // Also add the pending async copies to the interval tree. |
| if (max_outstanding_async_copies_ >= 0) { |
| for (auto interval : pending_async_copies_) { |
| async_copy_interval_tree_.Add(interval.first, interval.second, |
| kDummyChunk); |
| } |
| } |
| pending_async_copies_.clear(); |
| } |
| |
| void AlternateMemoryBestFitHeap::AddToPendingChunks( |
| const BufferInterval& buffer_interval, |
| const ChunkCandidate& chunk_candidate) { |
| pending_chunks_.emplace_back(buffer_interval, chunk_candidate); |
| } |
| |
// Tries to build the allocation covering the segment [start_time, end_time]
// of `buffer` that serves `use`. In order of preference: (1) keep the buffer
// in alternate memory with no copy, (2) place it in default memory (possibly
// evicting a prior alternate-memory allocation), then try to prefetch it back
// into alternate memory for this use. Returns true when the use was attached
// to some allocation; false means the caller should fall back to default
// memory for the entire buffer.
bool AlternateMemoryBestFitHeap::FindAllocation(
    int64 start_time, int64 end_time, int64 last_use_time,
    HloPosition defining_position, HloUse use, const HloValue* buffer,
    int64 size, MemorySpaceAssignment::AllocationSequence* allocations) {
  HloInstruction* operand =
      use.instruction->mutable_operand(use.operand_number);
  // If the operand is a bitcast, we look at bitcast's operand until we find a
  // non-bitcast operand.
  HloInstruction* non_bitcast_operand = operand;
  while (non_bitcast_operand->opcode() == HloOpcode::kBitcast) {
    non_bitcast_operand = non_bitcast_operand->mutable_operand(0);
  }
  // Create an alternate memory interval that starts at the earliest
  // possible position, given by max_prefetch_interval.
  BufferInterval alternate_mem_interval;
  alternate_mem_interval.buffer = buffer;
  alternate_mem_interval.size = size;
  alternate_mem_interval.end = end_time;

  VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
          << start_time << ", " << end_time << ") last use = " << last_use_time
          << " use = " << use.ToString() << ". Size = " << size
          << ", def pos = " << defining_position.ToString()
          << ", operand = " << operand->ToShortString()
          << (non_bitcast_operand != operand
                  ? ", non_bitcast_operand = " +
                        non_bitcast_operand->ToShortString()
                  : "");
  CHECK_LE(start_time, end_time);

  // There could be a requirement to pin this buffer to default memory either at
  // the definition site (e.g., parameters) or at the use site (e.g., outputs).
  // If there is a definition requirement, then we're allowed to prefetch, but
  // if it's a use requirement, we cannot prefetch the buffer. If the use
  // expects the buffer to be in default memory, we cannot prefetch it because
  // if we did, it would be in alternate memory instead.
  bool definition_requires_buffer_in_default_mem = false;
  bool use_requires_buffer_in_default_mem = false;
  auto required_assignment_it = required_assignments_.find(buffer);
  if (required_assignment_it != required_assignments_.end()) {
    for (const RequiredMemoryAssignment& required_assignment :
         required_assignment_it->second) {
      VLOG(3) << "Required assignment at time = " << required_assignment.time;
      // TODO(berkin): Handle memory requirements for alternate memory space.
      if (required_assignment.memory_space == MemorySpace::kDefault) {
        // A requirement at the segment start pins the definition; one at the
        // segment end pins the use.
        if (required_assignment.time == start_time) {
          definition_requires_buffer_in_default_mem = true;
          VLOG(3) << "Definition requires buffer in default memory.";
        }
        if (required_assignment.time == end_time) {
          use_requires_buffer_in_default_mem = true;
          VLOG(3) << "Use requires buffer in default memory.";
        }
      }
    }
  }

  // TODO(berkin): This is curently overly restrictive and will fail using
  // alternate memory for any buffer that might leak into a different
  // computation (e.g., while body). Enable more usage of alternate memory
  // across computations.
  if (defining_position.instruction->parent() != use.instruction->parent() ||
      (!use.instruction->called_computations().empty() &&
       use.instruction->opcode() != HloOpcode::kFusion)) {
    VLOG(3) << "Use is in a different computation or calls a computation.";
    // Fail because we do not allow asynchronous copies while in the bodies of
    // other computation.
    return false;
  }

  // First try keeping the allocation entirely in the alternate memory.
  if (!definition_requires_buffer_in_default_mem &&
      !use_requires_buffer_in_default_mem &&
      TryAllocatingInAlternateMemoryNoCopy(
          start_time, end_time, last_use_time, defining_position, use,
          alternate_mem_interval, non_bitcast_operand, allocations)) {
    return true;
  }

  MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
  if (!allocations->empty()) {
    prev_allocation = allocations->back().get();
  }

  // Since copies couldn't be removed, create an allocation in the default
  // memory space.
  if (prev_allocation != nullptr &&
      prev_allocation->memory_space() == MemorySpace::kAlternate &&
      prev_allocation->defining_position() == defining_position) {
    // If there was an allocation for this HloValue that was in the alternate
    // memory space, we also need to perform an eviction.
    // TODO(berkin): For now evictions happen relative to the most recent
    // allocation in the alternate memory. We can potentially start evictions
    // earlier and end later.
    VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
            << prev_allocation->start_time() << ", "
            << prev_allocation->end_time() << ")";

    // See if this interval would violate the asynchronous copy limit.
    if (!ViolatesMaximumOutstandingAsyncCopies(prev_allocation->start_time(),
                                               prev_allocation->end_time())) {
      AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
                   prev_allocation->start_time(), prev_allocation->end_time(),
                   allocations);

    } else {
      VLOG(3) << "This violates the maximum async copies.";
      // If the original interval violated the limit, try sub-intervals within
      // this interval.
      bool eviction_scheduled = false;
      for (int64 time = prev_allocation->start_time();
           time <= prev_allocation->end_time(); ++time) {
        VLOG(3) << "Try evicting (" << time << ", " << time << ")";
        if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) {
          VLOG(3) << "Eviction successful.";
          AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
                       time, time, allocations);
          eviction_scheduled = true;
          break;
        }
      }

      if (!eviction_scheduled) {
        // If the eviction couldn't be scheduled, then fail. This buffer will be
        // kept in the default memory.
        VLOG(3) << "Bailing: Could not evict " << use.ToString()
                << " because we hit the limit of maximum asynchronous copies "
                << "between "
                << hlo_live_range_.flattened_instruction_sequence()
                       .instructions()[prev_allocation->start_time()]
                << " and "
                << hlo_live_range_.flattened_instruction_sequence()
                       .instructions()[prev_allocation->end_time()];
        return false;
      }
    }
  } else if (prev_allocation != nullptr &&
             prev_allocation->memory_space() == MemorySpace::kDefault &&
             prev_allocation->defining_position() == defining_position) {
    // If the previous allocation was in the default memory space and was
    // defined by the same instruction, extend that. Otherwise, create a new
    // allocation.
    prev_allocation->Extend(end_time);
  } else {
    allocations->push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
        non_bitcast_operand, defining_position, MemorySpace::kDefault,
        kDummyChunk, start_time, end_time));
  }

  // If the use requires the buffer to be in default memory, don't try to
  // prefetch.
  if (use_requires_buffer_in_default_mem) {
    VLOG(4)
        << "Not trying to prefetch because use requires buffer in default mem.";
    allocations->back()->AddUse(use);
    return true;
  }

  // Try partially placing the buffer in the alternate space. The time that is
  // overlapped will be used to asynchronously copy the buffer from the
  // default memory to the alternate memory.
  //
  //                      start                 end
  //                      time                  time
  //                      X---------------------X
  // Alternate:                          +------+
  // Default:             +---------------------+
  //                                     ^      ^
  //                                   Copy    Copy
  //                                   Start   Done
  prefetch_interval_picker_->Begin(use, start_time, end_time);
  while (!prefetch_interval_picker_->Done()) {
    alternate_mem_interval.start = prefetch_interval_picker_->Next();
    VLOG(4) << "Trying alternate memory allocation ("
            << alternate_mem_interval.start << ", "
            << alternate_mem_interval.end << ")";
    // If this additional asynchronous copy would violate the limit, try a
    // different interval.
    if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
                                              alternate_mem_interval.end)) {
      VLOG(4) << "This would violate the outstanding async copy limit.";
      continue;
    }
    ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval);
    // Check if the new heap size fits within limits.
    if (chunk_candidate.heap_size < max_size_in_bytes_) {
      VLOG(3) << "Move the buffer to alternate memory at "
              << alternate_mem_interval.start
              << ". Offset = " << chunk_candidate.chunk.offset
              << ", size = " << chunk_candidate.chunk.size
              << ", heap_size = " << chunk_candidate.heap_size;
      AddToPendingChunks(alternate_mem_interval, chunk_candidate);

      AddAsyncCopy(*allocations->back().get(), MemorySpace::kAlternate,
                   chunk_candidate.chunk, alternate_mem_interval.start,
                   end_time, allocations);

      allocations->back()->AddUse(use);
      return true;
    }
  }

  // If a copy wasn't inserted, then add this use to the latest allocation.
  allocations->back()->AddUse(use);
  return true;
}
| |
| void AlternateMemoryBestFitHeap::AddAsyncCopy( |
| const MemorySpaceAssignment::Allocation& prev_allocation, |
| MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time, |
| MemorySpaceAssignment::AllocationSequence* allocations) { |
| VLOG(3) << "Copy to " |
| << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault |
| ? "default" |
| : "alternate") |
| << " memory between " << start_time << " and " << end_time; |
| |
| allocations->push_back( |
| absl::make_unique<MemorySpaceAssignment::CopyAllocation>( |
| prev_allocation, memory_space, chunk, start_time, end_time)); |
| |
| // Register the additional async copy with the interval tree to keep track of |
| // the limit at any given time. |
| pending_async_copies_.emplace_back(start_time, end_time); |
| } |
| |
| bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( |
| int64 start_time, int64 end_time) const { |
| if (max_outstanding_async_copies_ < 0) { |
| return false; |
| } |
| |
| // Count both the asynchronous copies in the interval tree as well as the |
| // pending asynchronous copies belonging to this buffer. |
| int64 num_async_copies = |
| async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time) |
| .size(); |
| |
| for (auto interval : pending_async_copies_) { |
| if (interval.second > start_time && interval.first < end_time) { |
| num_async_copies++; |
| } |
| } |
| // Add one because we are checking if adding an additional asynchronous copy |
| // would violate the limit. |
| return num_async_copies + 1 > max_outstanding_async_copies_; |
| } |
| |
// Attempts to serve the segment [start_time, end_time] entirely from
// alternate memory without inserting any asynchronous copy. This is only
// possible when either no allocation exists yet (and the value is allowed in
// alternate memory) or the previous allocation already lives in alternate
// memory. Returns true when a chunk was reserved and the use attached.
bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
    int64 start_time, int64 end_time, int64 last_use_time,
    HloPosition defining_position, HloUse use,
    BufferInterval alternate_mem_interval, HloInstruction* non_bitcast_operand,
    MemorySpaceAssignment::AllocationSequence* allocations) {
  MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
  bool can_eliminate_copy = false;
  if (allocations->empty()) {
    // There hasn't been any allocations for this interval so far. We can
    // eliminate copy if the value can be placed in the alternate memory.
    can_eliminate_copy =
        is_allowed_in_alternate_mem_(*alternate_mem_interval.buffer);
  } else {
    // If there has been a previous allocation, we can eliminate the copy if the
    // previous allocation was also in the alternate memory.
    prev_allocation = allocations->back().get();
    can_eliminate_copy =
        (prev_allocation->memory_space() == MemorySpace::kAlternate);
  }

  if (!can_eliminate_copy) {
    return false;
  }

  // The prefetch interval picker decides whether the interval is short enough
  // to keep the buffer resident without a copy.
  if (!prefetch_interval_picker_->CanAllocateInAlternateMemoryNoCopy(
          non_bitcast_operand->shape(), start_time, end_time)) {
    return false;
  }

  alternate_mem_interval.start = start_time;

  // Prefer the offset that was previously used for the previous allocation.
  int64 preferred_offset = -1;
  if (prev_allocation != nullptr) {
    preferred_offset = prev_allocation->chunk().offset;
    // If there is a previous allocation, set the start time one after the end
    // of the previous allocation's end.
    alternate_mem_interval.start = prev_allocation->end_time() + 1;
  }

  VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = "
          << preferred_offset;
  // In case there are additional uses after this use, we rely on the last use
  // time to try to reserve a chunk in the heap simulator. This is to prevent
  // the following scenario:
  //
  //                            +-------+
  //                           /         \
  //                  Producer--->Use1   +-->Use2
  //                       +---------+---------+
  // New buffer:           |         |         |
  //                       +---------+---------+
  //
  //                                     +-----------+
  // Current heap:                       | offset: 0 |
  //           --------------------------+-----------+------
  //
  // Because we allocate buffers greedily, Producer to Use1 segment first, and
  // then Use1 to Use2 segment, it is possible to allocate the first segment at
  // an offset that is available for the first segment (e.g. offset 0) but not
  // for the entire live range. This can result in unnecessary copies. By using
  // the last use time, we try to find an allocation that is available for the
  // entire Producer to Use2 range.
  alternate_mem_interval.end = last_use_time;
  ChunkCandidate chunk_candidate =
      FindChunkCandidate(alternate_mem_interval, preferred_offset);
  // Restore the real segment end before the chunk is (possibly) staged.
  alternate_mem_interval.end = end_time;
  // Check if the new heap size fits within limits. Also ensure if a
  // preferred offset was provided, that offset was used.
  if (chunk_candidate.heap_size < max_size_in_bytes_ &&
      (preferred_offset == -1 ||
       preferred_offset == chunk_candidate.chunk.offset)) {
    VLOG(3) << "Keep the buffer in alternate memory. Offset = "
            << chunk_candidate.chunk.offset
            << ", size = " << chunk_candidate.chunk.size
            << ", heap_size = " << chunk_candidate.heap_size;
    AddToPendingChunks(alternate_mem_interval, chunk_candidate);

    // If there was a previous allocation, the buffer location is the
    // same as the previous. Otherwise, it is the operand.
    if (prev_allocation != nullptr &&
        (prev_allocation->is_copy_allocation() ||
         prev_allocation->defining_position() == defining_position)) {
      prev_allocation->Extend(end_time);
    } else {
      allocations->push_back(
          absl::make_unique<MemorySpaceAssignment::Allocation>(
              non_bitcast_operand, defining_position, MemorySpace::kAlternate,
              chunk_candidate.chunk, start_time, end_time));
    }
    allocations->back()->AddUse(use);
    return true;
  }
  return false;
}
| |
| /*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies( |
| const HloModule& module) { |
| int64 max_copies = 0; |
| int64 current_copies = 0; |
| for (HloInstruction* instruction : |
| module.schedule().sequence(module.entry_computation()).instructions()) { |
| if (instruction->opcode() == HloOpcode::kCopyStart) { |
| current_copies++; |
| } else if (instruction->opcode() == HloOpcode::kCopyDone) { |
| current_copies--; |
| } |
| max_copies = std::max(max_copies, current_copies); |
| } |
| return max_copies; |
| } |
| |
| /*static*/ MemorySpaceAssignment::BufferIntervalCompare |
| MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( |
| const MemorySpaceAssignmentCostAnalysis& cost_analysis) { |
| return [&](const BufferInterval& x, const BufferInterval& y) { |
| // Returns a heuristic value that captures how much putting this tensor to |
| // the alternate memory would help if the op is memory bound, or otherwise |
| // how far off is the op to memory boundedness. The larger this number, the |
| // higher priority it will be placed in the alternate memory. |
| auto get_alternate_mem_benefit = |
| [&](const HloInstruction& instruction, |
| float elapsed_time_due_to_alternate_mem) { |
| float elapsed_time_due_to_compute = |
| cost_analysis.GetInstructionElapsedDueToCompute(instruction); |
| float elapsed_time_due_to_memory = |
| cost_analysis.GetInstructionElapsedDueToMemory(instruction); |
| if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { |
| // Memory bound, return how much alternate memory is better. |
| return elapsed_time_due_to_memory - |
| elapsed_time_due_to_alternate_mem; |
| } else { |
| // Compute bound, return how far off are we to memory boundedness. |
| return elapsed_time_due_to_memory - elapsed_time_due_to_compute; |
| } |
| }; |
| |
| auto get_memory_boundedness = [&](const BufferInterval& interval) { |
| const HloInstruction& defining_instruction = |
| *interval.buffer->defining_instruction(); |
| float alternate_mem_benefit = get_alternate_mem_benefit( |
| defining_instruction, cost_analysis.GetInstructionElapsedDueToMemory( |
| defining_instruction, |
| /*operand_in_alternate_mem=*/{}, |
| /*output_in_alternate_mem=*/true)); |
| for (const HloUse& use : interval.buffer->uses()) { |
| float use_alternate_mem_benefit = get_alternate_mem_benefit( |
| *use.instruction, cost_analysis.GetInstructionElapsedDueToMemory( |
| *use.instruction, use.operand_number)); |
| // If the benefit is positive (memory bound), add it to this buffer's |
| // benefit. If the benefit is negative (compute bound), calculate the |
| // maximum. |
| if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { |
| alternate_mem_benefit += use_alternate_mem_benefit; |
| } else { |
| alternate_mem_benefit = |
| std::max(alternate_mem_benefit, use_alternate_mem_benefit); |
| } |
| } |
| return alternate_mem_benefit; |
| }; |
| |
| float x_memory_boundedness = get_memory_boundedness(x); |
| float y_memory_boundedness = get_memory_boundedness(y); |
| if (x_memory_boundedness != y_memory_boundedness) { |
| return x_memory_boundedness > y_memory_boundedness; |
| } |
| // Tie-break if the memory boundedness is the same. |
| return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()( |
| x, y); |
| }; |
| } |
| |
| /*static*/ StatusOr<std::unique_ptr<PresetAssignments>> |
| MemorySpaceAssignment::Run( |
| HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes, |
| absl::optional<BufferIntervalCompare> buffer_interval_compare, |
| PrefetchIntervalPicker* prefetch_interval_picker, |
| int64 alternate_memory_space_alignment_in_bytes, |
| BufferValue::SizeFunction size_fn, |
| AlternateMemoryBestFitHeap::IsAllowedInAlternateMemoryFunction |
| is_allowed_in_alternate_mem, |
| int64 max_outstanding_async_copies) { |
| CHECK(module->has_schedule()); |
| VLOG(4) << "Module before memory space assignment: "; |
| XLA_VLOG_LINES(4, module->ToString()); |
| VLOG(4) << "Schedule: " << module->schedule().ToString(); |
| TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module)); |
| |
| MemorySpaceAssignment memory_space_assignment(module, alternate_memory_space); |
| const HloComputation* entry_computation = module->entry_computation(); |
| TF_ASSIGN_OR_RETURN(memory_space_assignment.hlo_live_range_, |
| HloLiveRange::Run(module->schedule(), *alias_analysis, |
| entry_computation)); |
| auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>( |
| &memory_space_assignment.allocation_map_, max_size_in_bytes, |
| buffer_interval_compare, prefetch_interval_picker, *alias_analysis, |
| *memory_space_assignment.hlo_live_range_, |
| alternate_memory_space_alignment_in_bytes, is_allowed_in_alternate_mem, |
| max_outstanding_async_copies); |
| |
| TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module, |
| module->schedule(), |
| *alias_analysis.get(), size_fn) |
| .status()); |
| |
| TF_RETURN_IF_ERROR(memory_space_assignment.Process()); |
| memory_space_assignment.ScheduleAsynchronousCopies(); |
| TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph()); |
| TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule()); |
| |
| VLOG(4) << "Module after memory space assignment: "; |
| XLA_VLOG_LINES(4, module->ToString()); |
| TF_CHECK_OK(module->schedule().Verify()); |
| VLOG(1) << "Maximum number of outstanding async copies: " |
| << CountMaximumOutstandingAsyncCopies(*module); |
| |
| return std::move(memory_space_assignment.preset_assignments_); |
| } |
| |
| void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { |
| HloInstruction* operand = |
| use.instruction->mutable_operand(use.operand_number); |
| // When the operand of a use is a bitcast, we place the bitcast in a separate |
| // data structure. |
| if (operand->opcode() == HloOpcode::kBitcast) { |
| bitcasts_.push_back(operand); |
| } else { |
| uses_.push_back(use); |
| } |
| } |
| |
| Status MemorySpaceAssignment::Allocation::Process( |
| MemorySpaceAssignment* memory_space_assignment) { |
| // For non-copy allocations, all we need to do is to update the output memory |
| // space if placed in the alternate memory. |
| if (memory_space_ == MemorySpace::kAlternate) { |
| memory_space_assignment->AddPositionInAlternateMemorySpace( |
| defining_position()); |
| } |
| return Status::OK(); |
| } |
| |
| StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith( |
| HloInstruction* new_instruction, HloInstruction* tuple, |
| ShapeIndex shape_index) { |
| const Shape& tuple_shape = tuple->shape(); |
| CHECK(tuple->shape().IsTuple()) |
| << "ReplaceTupleWith was called for a non-tuple. Tuple = " |
| << tuple->ToString() |
| << ", new_instruction = " << new_instruction->ToString() |
| << ", shape_index = " << shape_index.ToString(); |
| |
| HloComputation* computation = new_instruction->parent(); |
| std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size()); |
| for (int64 i = 0; i < tuple_shape.tuple_shapes_size(); ++i) { |
| const Shape& subshape = tuple_shape.tuple_shapes(i); |
| if (i == shape_index[0]) { |
| // If the subshape is still a tuple, recurse and pass a new shape index |
| // for the one level deeper. |
| if (subshape.IsTuple()) { |
| HloInstruction* get_tuple_element = computation->AddInstruction( |
| HloInstruction::CreateGetTupleElement(subshape, tuple, i)); |
| TF_ASSIGN_OR_RETURN(tuple_args[i], |
| ReplaceTupleWith(new_instruction, get_tuple_element, |
| ShapeIndex(shape_index.begin() + 1, |
| shape_index.end()))); |
| } else { |
| tuple_args[i] = new_instruction; |
| } |
| } else { |
| HloInstruction* get_tuple_element = computation->AddInstruction( |
| HloInstruction::CreateGetTupleElement(subshape, tuple, i)); |
| tuple_args[i] = get_tuple_element; |
| } |
| } |
| return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args)); |
| } |
| |
| Status MemorySpaceAssignment::CopyAllocation::Process( |
| MemorySpaceAssignment* memory_space_assignment) { |
| // Copy allocations need to insert asynchronous copy nodes. |
| HloInstruction* producing_instruction = defining_position().instruction; |
| CHECK_NE(producing_instruction, nullptr); |
| |
| Shape shape = defining_position().shape(); |
| CHECK(shape.IsArray()) << "CopyAllocation shape is not an array. Shape = " |
| << shape.ToString() |
| << " position = " << defining_position().shape(); |
| HloComputation* computation = producing_instruction->parent(); |
| |
| // If the instruction we're copying from is a tuple, we (recursively) create |
| // kGetTupleElement instructions and copy that value. Asynchronous copies only |
| // support array types. |
| if (!producing_instruction->shape().IsArray()) { |
| producing_instruction = defining_position().instruction; |
| for (int64 index : defining_position().index) { |
| producing_instruction = |
| computation->AddInstruction(HloInstruction::CreateGetTupleElement( |
| producing_instruction->shape().tuple_shapes(index), |
| producing_instruction, index)); |
| } |
| } |
| copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary( |
| ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), |
| HloOpcode::kCopyStart, producing_instruction)); |
| copy_done_ = computation->AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); |
| // Update the allocation with the copy done instruction so that if there |
| // are further copies from it, it can find the correct instruction. |
| instruction_ = copy_done_; |
| |
| // Also update the defining position. |
| defining_position_ = HloPosition{copy_done_, {}}; |
| |
| // Replace all the uses with the new copy instruction. |
| for (HloUse use : uses_) { |
| // If the operand is a tuple, we need to descend to the actual instruction |
| // we want to replace. |
| HloInstruction* replacement_instruction; |
| if (use.instruction->operand(use.operand_number)->shape().IsTuple()) { |
| TF_ASSIGN_OR_RETURN( |
| replacement_instruction, |
| ReplaceTupleWith(copy_done_, |
| use.instruction->mutable_operand(use.operand_number), |
| use.operand_index)); |
| } else { |
| replacement_instruction = copy_done_; |
| } |
| TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( |
| use.operand_number, replacement_instruction)); |
| } |
| |
| // Replace all the bitcasts with the new copy instruction. Note that if there |
| // is a chain of bitcasts, their operands will be replaced with copy done. |
| // For example: |
| // |
| // a = Foo() |
| // b = Bitcast(a) |
| // c = Bitcast(b) |
| // |
| // If a is moved to the alternate memory asynchronously, the graph will be |
| // changed into: |
| // |
| // a = Foo() |
| // cs = CopyStart(a) |
| // cd = CopyDone(cs) |
| // b = Bitcast(cd) |
| // c = Bitcast(cd) |
| // |
| // Because of the potential shape change in the operand (b -> cd), we use |
| // ReplaceOperandWithDifferentShape. |
| for (HloInstruction* bitcast : bitcasts_) { |
| TF_RETURN_IF_ERROR(bitcast->ReplaceOperandWithDifferentShape( |
| /*operand_num=*/0, copy_done_)); |
| } |
| |
| if (memory_space_ == MemorySpace::kAlternate) { |
| memory_space_assignment->AddPositionInAlternateMemorySpace({copy_done_}); |
| } |
| |
| return Status::OK(); |
| } |
| |
Status MemorySpaceAssignment::Process() {
  // Processes every allocation (copy allocations insert their
  // CopyStart/CopyDone pairs here), exports the alternate-memory chunks as
  // preset assignments, and finally colors the layouts of all positions that
  // alias an alternate-memory buffer.
  // Insert CopyStart/CopyDone pairs.
  int64 alternate_memory_size = 0;
  for (auto& buffer_and_sequence : allocation_map_) {
    for (auto& allocation : buffer_and_sequence.second) {
      TF_RETURN_IF_ERROR(allocation->Process(this));
      // Add the offset and size of the allocation in the alternate memory to
      // the output map. Special case for bitcast: since bitcast doesn't define
      // its own buffer, that shouldn't be exported as a preset chunk.
      if (allocation->memory_space() == MemorySpace::kAlternate &&
          allocation->instruction()->opcode() != HloOpcode::kBitcast) {
        preset_assignments_->add_chunk(allocation->defining_position(),
                                       allocation->chunk());
        // Track the high-water mark of the alternate memory.
        alternate_memory_size =
            std::max(alternate_memory_size, allocation->chunk().chunk_end());
      }
    }
  }

  // Only report a size for the alternate memory space if anything was
  // actually placed there.
  if (!preset_assignments_->chunks().empty()) {
    preset_assignments_->add_size(alternate_memory_space_,
                                  alternate_memory_size);
  }

  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Exported alternate memory allocations:";
    for (auto& pair : preset_assignments_->chunks()) {
      VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size
              << "] : " << pair.first.ToString();
    }
    VLOG(3) << "Exported alternate memory sizes:";
    for (auto& pair : preset_assignments_->sizes()) {
      VLOG(3) << "  space: " << pair.first << ", size: " << pair.second;
    }
  }

  // Color the pending positions and all of their aliased buffers.
  // Alias analysis is re-run here because Process() above may have inserted
  // new instructions (CopyStart/CopyDone) into the graph.
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
  for (HloPosition defining_position : pending_positions_in_alternate_mem_) {
    for (auto& buffer : alias_analysis->ComputeBuffersAt(
             defining_position.instruction, defining_position.index)) {
      for (auto& value : buffer->values()) {
        for (auto& position : value->positions()) {
          VLOG(3) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          shape->mutable_layout()->set_memory_space(alternate_memory_space_);
        }
      }
    }
  }

  return Status::OK();
}
| |
| void PresetAssignments::RemoveAssignmentForInstruction( |
| const HloInstruction* instruction) { |
| for (auto& position_and_chunk : chunks_) { |
| const HloPosition& position = position_and_chunk.first; |
| if (position.instruction == instruction) { |
| VLOG(3) << "Removing instruction from preset assignments."; |
| // Swap the removed position and chunk with the back and pop back. |
| position_and_chunk = chunks_.back(); |
| chunks_.pop_back(); |
| break; |
| } |
| } |
| } |
| |
Status MemorySpaceAssignment::SimplifyGraph() {
  // Cleans up after the graph rewrites: removes dead instructions and
  // forwards GetTupleElement(Tuple(...), i) to the tuple's i-th operand,
  // iterating each computation to a fixed point.
  for (HloComputation* computation : module_->MakeNonfusionComputations()) {
    // We perform limited DCE and forward the tuple operand in patterns like
    // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
    // assignment is run late in compilation (after DCE and arithmetic
    // simplification passes) and we don't want to generate redundant code. Run
    // to fixed point.
    bool computation_modified = true;
    while (computation_modified) {
      computation_modified = false;
      VLOG(4) << "Running simplify graph loop over " << computation->name();
      for (HloInstruction* instruction :
           computation->MakeInstructionPostOrder()) {
        if (computation->IsSafelyRemovable(instruction) &&
            instruction->user_count() == 0 && !instruction->HasSideEffect() &&
            instruction != computation->root_instruction()) {
          VLOG(4) << "Instruction removed: " << instruction->ToString();
          // Ensure the exported preset assignments don't contain a reference
          // to the removed instruction.
          preset_assignments_->RemoveAssignmentForInstruction(instruction);
          TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
          computation_modified = true;
        } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
          // Forward GTE-of-Tuple; the now-unused GTE (and possibly the Tuple)
          // become dead and are removed by the DCE branch on a later
          // iteration of the fixed-point loop.
          HloInstruction* operand = instruction->mutable_operand(0);
          if (operand->opcode() == HloOpcode::kTuple) {
            HloInstruction* forwarded_instruction =
                operand->mutable_operand(instruction->tuple_index());
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        }
      }
    }
  }

  return Status::OK();
}
| |
| void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted( |
| HloInstruction* new_instruction, HloInstructionSequence* new_sequence, |
| absl::flat_hash_set<HloInstruction*>* inserted_instructions) const { |
| if (inserted_instructions->contains(new_instruction)) { |
| return; |
| } |
| for (HloInstruction* operand : new_instruction->operands()) { |
| EnsureInstructionAndOperandsInserted(operand, new_sequence, |
| inserted_instructions); |
| } |
| VLOG(4) << "inserting: " << new_instruction->ToShortString(); |
| new_sequence->push_back(new_instruction); |
| inserted_instructions->insert(new_instruction); |
| } |
| |
| void MemorySpaceAssignment::AddPositionInAlternateMemorySpace( |
| HloPosition position) { |
| pending_positions_in_alternate_mem_.push_back(position); |
| } |
| |
void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
  // Decides the final schedule positions of all CopyStart/CopyDone pairs and
  // records them in schedule_after_ / schedule_before_, which FixSchedule()
  // later consumes.
  //
  // For asynchronous copies of both directions (default to alternate and vice
  // versa), sort them by their completion time. Then, if in the sorted order we
  // see that the start time is earlier than the start time of an asynchronous
  // copy that ends earlier, we delay the start of this. As a result, given
  // asynchronous copies that might look like:
  //
  //      CS          CD
  //   a  +-----------+
  //        b  +-----------+
  //   c   +---------+
  //
  // We'll first sort by completion time:
  //
  //   c   +---------+
  //   a  +-----------+
  //        b  +-----------+
  //
  // Then, delay a because c starts later than a despite also ending earlier:
  //
  //   c   +---------+
  //   a   +----------+
  //        b  +-----------+
  for (MemorySpace memory_space :
       {MemorySpace::kDefault, MemorySpace::kAlternate}) {
    // Collect all copy allocations that copy into this memory space.
    std::vector<CopyAllocation*> copy_allocations;
    for (auto& buffer_and_sequence : allocation_map_) {
      for (auto& allocation : buffer_and_sequence.second) {
        if (allocation->is_copy_allocation()) {
          auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
          if (copy_allocation->memory_space() == memory_space) {
            copy_allocations.push_back(copy_allocation);
          }
        }
      }
    }

    // Sort by completion time, breaking ties by start time.
    absl::c_stable_sort(
        copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
          return std::forward_as_tuple(first->copy_done_schedule_before(),
                                       first->copy_start_schedule_after()) <
                 std::forward_as_tuple(second->copy_done_schedule_before(),
                                       second->copy_start_schedule_after());
        });

    CopyAllocation* prev_copy_allocation = nullptr;
    for (CopyAllocation* copy_allocation : copy_allocations) {
      // Delay this copy's start so that copy starts are issued in the same
      // order their copy dones complete (see the figure above).
      if (prev_copy_allocation &&
          prev_copy_allocation->copy_start_schedule_after() >
              copy_allocation->copy_start_schedule_after()) {
        VLOG(4) << "Delaying CopyStart ("
                << copy_allocation->copy_start_schedule_after() << " to "
                << prev_copy_allocation->copy_start_schedule_after() << ") for "
                << copy_allocation->copy_start()->ToString() << " because of "
                << prev_copy_allocation->copy_start()->ToString();
        copy_allocation->set_copy_start_schedule_after(
            prev_copy_allocation->copy_start_schedule_after());
      }

      // If the copy start doesn't happen to be scheduled at the correct
      // computation, delay it until the correct computation starts.
      const auto& flattened_instructions =
          hlo_live_range_->flattened_instruction_sequence().instructions();
      int64 copy_start_schedule_after =
          copy_allocation->copy_start_schedule_after();
      // NOTE(review): this walk assumes an instruction belonging to the
      // copy's computation appears at or after this index; otherwise it would
      // index past the end of the flattened sequence -- TODO confirm.
      while (copy_allocation->instruction()->parent() !=
             flattened_instructions[copy_start_schedule_after]->parent()) {
        VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
                << (copy_start_schedule_after + 1) << ") for "
                << copy_allocation->copy_start()->ToString()
                << " because it is not in the correct computation.";
        copy_allocation->set_copy_start_schedule_after(
            ++copy_start_schedule_after);
      }

      // Record the final positions for FixSchedule().
      schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
          copy_allocation->copy_start());
      schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
          copy_allocation->copy_done());
      prev_copy_allocation = copy_allocation;
    }
  }
}
| |
Status MemorySpaceAssignment::FixSchedule() {
  // Rebuilds the schedule of every non-fusion computation, splicing in the
  // new CopyStart/CopyDone instructions at the positions recorded in
  // schedule_after_ / schedule_before_ by ScheduleAsynchronousCopies().
  CHECK(module_->has_schedule());
  HloSchedule& schedule = module_->schedule();
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    CHECK(schedule.is_computation_scheduled(computation));
    // Note: this is the module-wide flattened sequence; instructions that
    // belong to other computations are skipped below.
    const HloInstructionSequence& sequence =
        hlo_live_range_->flattened_instruction_sequence();
    HloInstructionSequence new_sequence;

    absl::flat_hash_set<HloInstruction*> inserted_instructions;

    VLOG(4) << "Scheduling: " << computation->ToString();

    for (int64 instruction_index = 0;
         instruction_index < sequence.instructions().size();
         ++instruction_index) {
      HloInstruction* instruction = sequence.instructions()[instruction_index];
      if (!computation->ContainsInstruction(instruction)) {
        continue;
      }
      // Copy dones scheduled to complete before this index are inserted
      // first so they precede their consumers.
      auto insts_before_iter = schedule_before_.find(instruction_index);
      if (insts_before_iter != schedule_before_.end()) {
        for (HloInstruction* new_instruction : insts_before_iter->second) {
          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                               &inserted_instructions);
        }
      }
      // Insert only if not previously inserted.
      if (!inserted_instructions.contains(instruction)) {
        EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
                                             &inserted_instructions);
      }
      // Copy starts scheduled to begin after this index are appended last.
      auto insts_after_iter = schedule_after_.find(instruction_index);
      if (insts_after_iter != schedule_after_.end()) {
        for (HloInstruction* new_instruction : insts_after_iter->second) {
          EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
                                               &inserted_instructions);
        }
      }
    }
    // For rare cases where the original sequence is empty, ensure the root
    // instruction and its dependencies are scheduled.
    EnsureInstructionAndOperandsInserted(computation->root_instruction(),
                                         &new_sequence, &inserted_instructions);
    CHECK_EQ(new_sequence.size(), computation->instruction_count())
        << "New sequence for computation " << computation->name() << " has "
        << new_sequence.size() << " instructions, expects "
        << computation->instruction_count() << ".";
    schedule.set_sequence(computation, new_sequence);
  }

  return Status::OK();
}
| |
| } // namespace xla |