| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/memory_space_assignment.h" |
| |
| namespace xla { |
| |
| namespace { |
| // Define a dummy chunk for chunks that will be allocated in the default memory
| // space and for keeping track of the number of asynchronous copies.
| const HeapSimulator::Chunk kDummyChunk{-1, -1}; |
| } // namespace |
| |
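| // Returns the elapsed time in seconds due to compute only: the larger of the
| // time spent on floating point operations and the time spent on
| // transcendental operations, each at its respective peak rate.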
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute( |
| const HloInstruction& instruction) const { |
| return std::max( |
| cost_analysis_.flop_count(instruction) / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey), |
| cost_analysis_.transcendental_count(instruction) / |
| cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey)); |
| } |
| |
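| // Returns the elapsed time in seconds due to memory traffic only. Bytes read
| // from or written to the alternate memory are charged at the alternate memory
| // bandwidth; the remaining bytes are charged at the default memory bandwidth.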
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( |
| const HloInstruction& instruction, |
| absl::optional<int64> operand_in_alternate_mem, |
| bool output_in_alternate_mem) const { |
| float bytes_accessed = cost_analysis_.bytes_accessed(instruction);
| // Accumulate the elapsed time due to bytes served from the alternate memory.
| // The operand and output contributions are additive, so requesting both at
| // once doesn't drop either term; the remaining bytes are charged at the
| // default memory bandwidth below.
| float elapsed_due_to_alternate_mem_bytes = 0.0;
| if (operand_in_alternate_mem) {
| // Estimate the elapsed time due to the operand being in the alternate
| // memory space.
| float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
| instruction, *operand_in_alternate_mem);
| elapsed_due_to_alternate_mem_bytes +=
| operand_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
| bytes_accessed -= operand_bytes_accessed;
| }
| if (output_in_alternate_mem) {
| // Estimate the elapsed time due to the output being in the alternate
| // memory space.
| float output_bytes_accessed =
| cost_analysis_.output_bytes_accessed(instruction);
| elapsed_due_to_alternate_mem_bytes +=
| output_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
| bytes_accessed -= output_bytes_accessed;
| }
| return elapsed_due_to_alternate_mem_bytes +
| bytes_accessed /
| cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed( |
| const HloInstruction& instruction, |
| absl::optional<int64> operand_in_alternate_mem, |
| bool output_in_alternate_mem) const { |
| return std::max( |
| GetInstructionElapsedDueToCompute(instruction), |
| GetInstructionElapsedDueToMemory(instruction, operand_in_alternate_mem, |
| output_in_alternate_mem)); |
| } |
| |
| float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed( |
| const Shape& shape) const { |
| int64 size_in_bytes = cost_analysis_.GetShapeSize(shape); |
| return static_cast<float>(size_in_bytes) / |
| async_copy_bandwidth_bytes_per_second_; |
| } |
| |
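| // With the instruction count based picker, a buffer may live in the alternate
| // memory without a copy only if its live range spans at most
| // max_overlap_count_ instructions.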
| bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| return end_time - start_time <= max_overlap_count_; |
| } |
| |
| int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( |
| const Shape& shape, int64 start_time, int64 latest_end_time) const { |
| return std::min(start_time + min_overlap_count_, latest_end_time); |
| } |
| |
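| // Begin/Next/Done iterate over candidate prefetch start times for a use.
| // Typical usage (see AlternateMemoryBestFitHeap::FindAllocation):
| //
| //   picker->Begin(use, start_time, latest_prefetch_time);
| //   while (!picker->Done()) {
| //     int64 candidate_start_time = picker->Next();
| //     // Try allocating with this prefetch start time...
| //   }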
| void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use, |
| int64 start_time, |
| int64 end_time) { |
| end_time_ = end_time; |
| current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_); |
| } |
| |
| int64 InstructionCountPrefetchIntervalPicker::Next() { |
| CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " |
| "Done() is false"; |
| return current_prefetch_time_++; |
| } |
| |
| bool InstructionCountPrefetchIntervalPicker::Done() const { |
| return end_time_ - current_prefetch_time_ <= min_overlap_count_; |
| } |
| |
| std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const { |
| return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_); |
| } |
| |
| std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| return absl::StrCat("Overlapped HLOs = ", end_time - start_time); |
| } |
| |
| void CostAnalysisPrefetchIntervalPicker::SetInstructionSchedule( |
| const absl::flat_hash_map<const HloInstruction*, int64>& |
| instruction_schedule) { |
| // First create a vector of elapsed times of HLO instructions. |
| std::vector<float> instructions_elapsed_time(instruction_schedule.size(), |
| 0.0); |
| |
| for (const auto& instruction_and_logical_time : instruction_schedule) { |
| float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( |
| *instruction_and_logical_time.first); |
| int64 logical_time = instruction_and_logical_time.second; |
| if (logical_time >= instructions_elapsed_time.size()) { |
| instructions_elapsed_time.resize(logical_time + 1, 0.0); |
| } |
| instructions_elapsed_time[logical_time] = elapsed_time; |
| } |
| // As an optimization, create a cumulative sum vector of elapsed time. |
| float cumsum = 0.0; |
| for (float elapsed_time : instructions_elapsed_time) { |
| cumsum += elapsed_time; |
| elapsed_time_cumsum_.push_back(cumsum); |
| } |
| } |
| |
| bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| // Even though this method returns whether we allow the buffer in the
| // alternate memory _without_ asynchronous copies, calculate how long it would
| // have taken to copy it and compare that to the elapsed time in the logical
| // interval.
| float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); |
| float logical_interval_elapsed = |
| GetLogicalIntervalElapsed(start_time, end_time); |
| return max_async_copy_to_overlap_ratio_ * async_copy_elapsed > |
| logical_interval_elapsed; |
| } |
| |
| int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( |
| const Shape& shape, int64 start_time, int64 latest_end_time) const { |
| float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); |
| int64 end_time; |
| for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) { |
| float logical_interval_elapsed = |
| GetLogicalIntervalElapsed(start_time, end_time); |
| if (logical_interval_elapsed >= |
| min_async_copy_to_overlap_ratio_ * async_copy_elapsed) { |
| break; |
| } |
| } |
| return end_time; |
| } |
| |
| void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, |
| int64 start_time, |
| int64 end_time) { |
| const Shape& shape = use.instruction->operand(use.operand_number)->shape(); |
| // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_. |
| async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape); |
| // Estimate the time we would save by having this op in alternate memory. |
| float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); |
| float elapsed_time_in_alternate_mem = cost_analysis_.GetInstructionElapsed( |
| *use.instruction, use.operand_number); |
| inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; |
| end_logical_time_ = end_time; |
| // Find the earliest time we're allowed to start prefetching. |
| for (current_logical_prefetch_time_ = start_time; |
| current_logical_prefetch_time_ <= end_logical_time_ && |
| max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ < |
| GetLogicalIntervalElapsed(current_logical_prefetch_time_, |
| end_logical_time_); |
| ++current_logical_prefetch_time_) { |
| } |
| } |
| |
| int64 CostAnalysisPrefetchIntervalPicker::Next() { |
| CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " |
| "Done() is false"; |
| return current_logical_prefetch_time_++; |
| } |
| |
| bool CostAnalysisPrefetchIntervalPicker::Done() const { |
| // The end time is inclusive, so we're done if the prefetch time is greater |
| // than that. |
| if (current_logical_prefetch_time_ > end_logical_time_) { |
| return true; |
| } |
| float logical_interval_elapsed = GetLogicalIntervalElapsed( |
| current_logical_prefetch_time_, end_logical_time_); |
| return async_copy_elapsed_ * min_async_copy_to_overlap_ratio_ > |
| logical_interval_elapsed + inst_elapsed_reduction_; |
| } |
| |
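| // Returns the elapsed time in seconds spanned by the logical interval
| // (start_time, end_time), excluding the elapsed time of the instruction at
| // start_time itself. For example, with per-instruction elapsed times
| // {1, 2, 3, 4}, GetLogicalIntervalElapsed(0, 3) = cumsum[2] - cumsum[0] =
| // 6 - 1 = 5, the cost of instructions 1 and 2.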
| float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( |
| int64 start_time, int64 end_time) const { |
| return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time]; |
| } |
| |
| std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { |
| float logical_interval_elapsed = GetLogicalIntervalElapsed( |
| current_logical_prefetch_time_, end_logical_time_); |
| return absl::StrCat( |
| "Async copy elapsed (s) = ", async_copy_elapsed_, |
| ", inst elapsed reduction (s) = ", inst_elapsed_reduction_, |
| ", logical interval elapsed (s) = ", logical_interval_elapsed); |
| } |
| |
| std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( |
| const Shape& shape, int64 start_time, int64 end_time) const { |
| float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); |
| float logical_interval_elapsed = |
| GetLogicalIntervalElapsed(start_time, end_time); |
| return absl::StrCat( |
| "Async copy elapsed (s) = ", async_copy_elapsed, |
| ", logical interval elapsed (s) = ", logical_interval_elapsed); |
| } |
| |
| std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*> |
| AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( |
| const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { |
| std::vector<const BufferInterval*> colocated_intervals; |
| std::vector<const BufferInterval*> worklist = {&interval}; |
| while (!worklist.empty()) { |
| const BufferInterval* item = worklist.back(); |
| worklist.pop_back(); |
| colocated_intervals.push_back(item); |
| for (const HloValue* buffer_colocated : item->colocations) { |
| worklist.push_back(&buffer_intervals_.at(buffer_colocated)); |
| } |
| } |
| |
| absl::c_sort(colocated_intervals, [&](const BufferInterval* x, |
| const BufferInterval* y) { |
| return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end); |
| }); |
| return colocated_intervals; |
| } |
| |
| bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory( |
| const BufferInterval& interval) const { |
| // If the buffer is a tuple, don't use this algorithm for now. The buffers |
| // that are pointed to by the tuple will still use this algorithm. Because |
| // tuples are cheap to place in the alternate memory (they are just pointers) |
| // we don't need to use prefetch/evict logic. |
| if (interval.buffer->shape().IsTuple()) { |
| VLOG(4) << "Keeping value " << interval.buffer->ToShortString() |
| << " in default mem because it is a tuple."; |
| return false; |
| } |
| |
| // The semantics of TupleSelect are weird: TupleSelect doesn't define a
| // buffer, but just forwards the buffers of either the left or the right
| // side. This means that the two different inputs to TupleSelect must not
| // alias, yet they should be allocated in the same memory space, and both
| // buffers must be kept alive for the entire live range of TupleSelect.
| // Instead, just don't allocate TupleSelect in the alternate memory space.
| // TODO(berkin): Not allocating add-dependencies either since they need to be
| // treated specially. We should revisit this later.
| for (const HloPosition& position : interval.buffer->positions()) { |
| if (position.instruction->opcode() == HloOpcode::kTupleSelect || |
| position.instruction->opcode() == HloOpcode::kAddDependency) { |
| VLOG(4) << "Keeping value " << interval.buffer->ToShortString() |
| << " in default mem because it has a tuple-select or " |
| << "add-dependency position."; |
| return false; |
| } |
| } |
| |
| // Send and Recv HLOs return a request identifier. These should not be |
| // allocated in the alternate memory. |
| const HloPosition& defining_position = interval.buffer->defining_position(); |
| if ((defining_position.instruction->opcode() == HloOpcode::kSend || |
| defining_position.instruction->opcode() == HloOpcode::kRecv) && |
| defining_position.index == ShapeIndex({1})) { |
| VLOG(4) |
| << "Keeping value " << interval.buffer->ToShortString() |
| << " in default mem because it is a request identifier for send/recv."; |
| return false; |
| } |
| |
| return true; |
| } |
| |
| HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { |
| std::vector<BufferInterval> sorted_buffer_intervals = |
| GetSortedBufferIntervals(); |
| |
| VLOG(1) << "Assigning buffers to alternate memory. Max heap size = " |
| << options_.max_size_in_bytes; |
| |
| AddInputAndOutputRequiredAssignments(); |
| options_.prefetch_interval_picker->SetInstructionSchedule( |
| hlo_live_range_.instruction_schedule()); |
| |
| for (auto& interval : sorted_buffer_intervals) { |
| if (!interval.need_allocation) { |
| continue; |
| } |
| |
| // Skip if we have already allocated for this buffer. |
| if (allocation_map_->contains(interval.buffer)) { |
| continue; |
| } |
| |
| if (!IsIntervalAllowedInAlternateMemory(interval)) { |
| continue; |
| } |
| |
| auto colocated_intervals = GetSortedColocatedIntervals(interval); |
| |
| if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { |
| VLOG(4) << "Interval " << interval.buffer->ToShortString() |
| << " is reserved in the alternate memory. Total reserved bytes = " |
| << reserved_in_bytes_; |
| for (const BufferInterval* colocated_interval : colocated_intervals) { |
| const HloValue* value = colocated_interval->buffer; |
| // Color all of the aliased reserved buffers here because reserved |
| // alternate memory allocations will not have an entry in preset |
| // allocations that is normally used for coloring. |
| for (auto& position : value->positions()) { |
| VLOG(3) << "Coloring " << position.ToString(); |
| Shape* shape = ShapeUtil::GetMutableSubshape( |
| position.instruction->mutable_shape(), position.index); |
| CHECK(shape->IsArray()) << "Coloring a shape that is not an array: " |
| << position.ToString(); |
| shape->mutable_layout()->set_memory_space( |
| options_.alternate_memory_space); |
| } |
| } |
| // Increment the reserved part of alternate memory so that it is not |
| // available for other buffers. Since all colocated intervals should have |
| // the same size, just use the first one. |
| reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer); |
| continue; |
| } |
| |
| if (colocated_intervals.size() > 1 && |
| !options_.allocate_across_sequential_calls) { |
| VLOG(4) << "Not allocating " << interval.buffer->ToShortString() |
| << " because it aliases with another interval and " |
| << " allocate_across_sequential_calls is false."; |
| continue; |
| } |
| |
| const HloComputation* defining_computation = |
| colocated_intervals[0]->buffer->defining_instruction()->parent(); |
| MemorySpaceAssignment::Allocation* aliased_allocation = nullptr; |
| for (const BufferInterval* colocated_interval : colocated_intervals) { |
| const HloValue* value = colocated_interval->buffer; |
| const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); |
| MemorySpaceAssignment::AllocationSequence* allocation_sequence = |
| &(*allocation_map_)[value]; |
| int64 definition_time = |
| instruction_schedule.at(value->defining_instruction()); |
| // Sort the uses by the use time. |
| std::vector<HloUse> uses = value->uses(); |
| absl::c_sort(uses, [&](HloUse use1, HloUse use2) { |
| return instruction_schedule.at(use1.instruction) < |
| instruction_schedule.at(use2.instruction); |
| }); |
| |
| // If there was an aliased allocation for this buffer, propagate that for |
| // this HloValue. |
| if (aliased_allocation != nullptr) { |
| VLOG(3) << "Adding an aliased allocation: (" |
| << aliased_allocation->start_time() << ", " |
| << aliased_allocation->end_time() |
| << ") pos: " << aliased_allocation->defining_position() |
| << " mem space: " |
| << (aliased_allocation->memory_space() == MemorySpace::kDefault |
| ? "default" |
| : "alt"); |
| allocation_sequence->push_back( |
| absl::make_unique<MemorySpaceAssignment::Allocation>( |
| value->defining_instruction(), value->defining_position(), |
| aliased_allocation->memory_space(), aliased_allocation->chunk(), |
| definition_time, definition_time)); |
| } |
| |
| // Iterate over the uses. |
| for (HloUse use : uses) { |
| int64 use_time = instruction_schedule.at(use.instruction); |
| int64 last_use_time = instruction_schedule.at(uses.back().instruction); |
| int64 latest_prefetch_time = use_time; |
| |
| if (use.instruction->parent() != defining_computation) { |
| VLOG(3) << "skip use " << use.ToString() |
| << " because it's in a different computation."; |
| continue; |
| } |
| |
| // Sequential calls include kWhile, kCall, and kConditional opcodes. |
| bool is_sequential_call = |
| (GetInstructionCallContext(use.instruction->opcode()) == |
| CallContext::kSequential); |
| if (is_sequential_call) { |
| for (const HloComputation* called_computation : |
| use.instruction->called_computations()) { |
| const HloLiveRange::TimeBound& computation_span = |
| hlo_live_range_.computation_span_times().at(called_computation); |
| latest_prefetch_time = |
| std::min(computation_span.start, latest_prefetch_time); |
| } |
| } |
| |
| // Bitcasts don't define buffers and don't directly consume buffers. |
| // Skip allocating buffers for bitcast uses. The uses that feed from |
| // bitcasts will be handled specially. |
| if (use.instruction->opcode() != HloOpcode::kBitcast) { |
| if (!FindAllocation(definition_time, use_time, last_use_time, |
| latest_prefetch_time, value->defining_position(), |
| use, value, colocated_interval->size, |
| allocation_sequence)) { |
| // If the allocation finding failed (e.g., due to running out of |
| // asynchronous copies), then fall back to allocating the buffer |
| // entirely in the default memory. |
| pending_chunks_.clear(); |
| pending_async_copies_.clear(); |
| allocation_sequence->clear(); |
| break; |
| } |
| |
| // If there are multiple uses, they can try using the memory |
| // allocation already at the alternate memory. |
| definition_time = use_time; |
| } |
| |
| // If the use is a sequential call (e.g. a while loop), the other
| // colocated intervals must alias with this allocation.
| if (is_sequential_call) { |
| aliased_allocation = |
| GetLiveAllocationAt(*allocation_sequence, use_time); |
| } |
| } |
| } |
| |
| CommitPendingChunks(); |
| } |
| |
| if (VLOG_IS_ON(3)) { |
| for (const auto& alloc_pair : *allocation_map_) { |
| VLOG(3) << "Allocation for " << alloc_pair.first->ToShortString(); |
| for (const auto& alloc : alloc_pair.second) { |
| std::string addr_str = ": default"; |
| if (alloc->memory_space() == MemorySpace::kAlternate) { |
| addr_str = absl::StrCat(": alt ", alloc->chunk().offset); |
| } |
| |
| VLOG(3) << " " << alloc->start_time() << "-" << alloc->end_time() |
| << addr_str << ", " << alloc->uses().size() << " uses"; |
| } |
| } |
| } |
| |
| return result_; |
| } |
| |
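| // Orders asynchronous copies so that a copy is "less" than another if it
| // starts no later and ends no later, with at least one of the two strict.
| // Two copies compare equivalent (neither is less than the other) exactly
| // when they are identical or one interval strictly nests inside the other;
| // AsynchronousCopyOrdering relies on this to detect ordering violations with
| // a set lookup.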
| bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) { |
| return (a.start_time < b.start_time && a.end_time <= b.end_time) || |
| (a.start_time <= b.start_time && a.end_time < b.end_time); |
| } |
| |
| void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) { |
| auto it_and_inserted = ranges_.insert(copy); |
| CHECK(it_and_inserted.second || |
| it_and_inserted.first->start_time == copy.start_time); |
| } |
| |
| bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time, |
| int64 end_time) const { |
| // We allow identical start and end times. If we find a match in ranges_, it
| // is enough to check just the start time: the matching copy either has an
| // identical {start_time, end_time} (which doesn't violate the ordering) or a
| // smaller start_time and a larger end_time (which does).
| auto copy_it = ranges_.find( |
| {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate}); |
| return copy_it != ranges_.end() && copy_it->start_time != start_time; |
| } |
| |
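| // Searches backwards from the most recent allocation and returns the one
| // whose live range contains the given time, or nullptr if no allocation is
| // live at that time.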
| /*static*/ MemorySpaceAssignment::Allocation* |
| AlternateMemoryBestFitHeap::GetLiveAllocationAt( |
| const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) { |
| for (auto allocation_it = allocations.rbegin(); |
| allocation_it != allocations.rend(); ++allocation_it) { |
| if ((*allocation_it)->start_time() <= time && |
| (*allocation_it)->end_time() >= time) { |
| return allocation_it->get(); |
| } |
| } |
| return nullptr; |
| } |
| |
| void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() { |
| // Go through the parameters and outputs and pin them to the corresponding |
| // memory by adding a required assignment. |
| const HloModule& module = alias_analysis_.dataflow_analysis().module(); |
| const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); |
| HloComputation* entry_computation = module.entry_computation(); |
| for (HloInstruction* parameter_instruction : |
| entry_computation->parameter_instructions()) { |
| int64 parameter_instruction_time = |
| instruction_schedule.at(parameter_instruction); |
| ShapeUtil::ForEachSubshape( |
| parameter_instruction->shape(), |
| [&](const Shape& subshape, const ShapeIndex& index) { |
| MemorySpace memory_space = MemorySpace::kDefault; |
| if (subshape.has_layout() && subshape.layout().memory_space() == |
| options_.alternate_memory_space) { |
| memory_space = MemorySpace::kAlternate; |
| } |
| for (const HloBuffer* buffer : |
| alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) { |
| for (const HloValue* value : buffer->values()) { |
| VLOG(3) << "Adding required assignment for parameter value = " |
| << value->ToShortString() |
| << " time = " << parameter_instruction_time << " space = " |
| << (memory_space == MemorySpace::kDefault ? "def" |
| : "alt"); |
| required_assignments_[value].push_back( |
| {memory_space, /*time=*/parameter_instruction_time}); |
| } |
| } |
| }); |
| } |
| HloInstruction* root_instruction = entry_computation->root_instruction(); |
| int64 root_instruction_time = instruction_schedule.at(root_instruction); |
| ShapeUtil::ForEachSubshape( |
| root_instruction->shape(), |
| [&](const Shape& subshape, const ShapeIndex& index) { |
| MemorySpace memory_space = MemorySpace::kDefault; |
| if (subshape.has_layout() && subshape.layout().memory_space() == |
| options_.alternate_memory_space) { |
| memory_space = MemorySpace::kAlternate; |
| } |
| for (const HloBuffer* buffer : |
| alias_analysis_.ComputeBuffersAt(root_instruction, index)) { |
| for (const HloValue* value : buffer->values()) { |
| VLOG(3) << "Adding required assignment for output value = " |
| << value->ToShortString() |
| << " time = " << root_instruction_time << " space = " |
| << (memory_space == MemorySpace::kDefault ? "def" : "alt"); |
| required_assignments_[value].push_back( |
| {memory_space, /*time=*/root_instruction_time}); |
| } |
| } |
| }); |
| } |
| |
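| // Returns true if any of the colocated intervals is already pinned to the
| // alternate memory via its layout: either it is defined by an entry
| // parameter whose position is in the alternate memory, or it has a position
| // at the root instruction that is in the alternate memory.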
| bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory( |
| absl::Span<const BufferInterval* const> colocated_intervals) const { |
| auto is_position_in_alternate_memory = [&](const HloPosition& position) { |
| const Shape& shape = position.shape(); |
| return shape.has_layout() && |
| shape.layout().memory_space() == options_.alternate_memory_space; |
| }; |
| |
| const HloModule& module = alias_analysis_.dataflow_analysis().module(); |
| const HloComputation* entry_computation = module.entry_computation(); |
| const HloInstruction* root_instruction = |
| entry_computation->root_instruction(); |
| for (const BufferInterval* colocated_interval : colocated_intervals) { |
| const HloValue* value = colocated_interval->buffer; |
| if (value->defining_instruction()->opcode() == HloOpcode::kParameter && |
| value->defining_instruction()->parent() == entry_computation && |
| is_position_in_alternate_memory(value->defining_position())) { |
| return true; |
| } |
| |
| for (const HloPosition& position : value->positions()) { |
| if (position.instruction == root_instruction && |
| is_position_in_alternate_memory(position)) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| void AlternateMemoryBestFitHeap::CommitPendingChunks() { |
| for (auto interval_and_chunk : pending_chunks_) { |
| VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-" |
| << interval_and_chunk.first.end << " : [" |
| << interval_and_chunk.second.chunk.offset << ", " |
| << interval_and_chunk.second.chunk.size << "]"; |
| CommitChunk(interval_and_chunk.first, interval_and_chunk.second); |
| } |
| pending_chunks_.clear(); |
| // Also add the pending async copies to the interval tree. |
| for (const auto& interval : pending_async_copies_) { |
| if (options_.max_outstanding_async_copies >= 0) { |
| async_copy_interval_tree_.Add(interval.start_time, interval.end_time, |
| kDummyChunk); |
| } |
| if (interval.destination == MemorySpace::kAlternate) { |
| async_copy_ordering_.AddCopy(interval); |
| } |
| } |
| pending_async_copies_.clear(); |
| } |
| |
| void AlternateMemoryBestFitHeap::AddToPendingChunks( |
| const BufferInterval& buffer_interval, |
| const ChunkCandidate& chunk_candidate) { |
| pending_chunks_.emplace_back(buffer_interval, chunk_candidate); |
| } |
| |
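| // Returns true if there is a required assignment that pins this buffer to
| // the default memory space at the given time.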
| bool AlternateMemoryBestFitHeap::RequiredInDefaultMemory(const HloValue* buffer, |
| int64 time) const { |
| auto required_assignment_it = required_assignments_.find(buffer); |
| return required_assignment_it != required_assignments_.end() && |
| absl::c_any_of( |
| required_assignment_it->second, |
| [&](const RequiredMemoryAssignment& required_assignment) { |
| return required_assignment.memory_space == |
| MemorySpace::kDefault && |
| required_assignment.time == time; |
| }); |
| } |
| |
| bool AlternateMemoryBestFitHeap::FindAllocation( |
| int64 start_time, int64 end_time, int64 last_use_time, |
| int64 latest_prefetch_time, HloPosition defining_position, HloUse use, |
| const HloValue* buffer, int64 size, |
| MemorySpaceAssignment::AllocationSequence* allocations) { |
| HloInstruction* operand = |
| use.instruction->mutable_operand(use.operand_number); |
| // If the operand is a bitcast, we look at bitcast's operand until we find a |
| // non-bitcast operand. |
| HloInstruction* non_bitcast_operand = operand; |
| while (non_bitcast_operand->opcode() == HloOpcode::kBitcast) { |
| non_bitcast_operand = non_bitcast_operand->mutable_operand(0); |
| } |
| // Create an alternate memory interval that starts at the earliest possible
| // position, as determined by the prefetch interval picker.
| BufferInterval alternate_mem_interval; |
| alternate_mem_interval.buffer = buffer; |
| alternate_mem_interval.size = size; |
| alternate_mem_interval.end = end_time; |
| |
| // start_time == end_time is a special case where the value is consumed |
| // multiple times by the same instruction. We can just find the previous |
| // allocation and use that allocation. |
| if (start_time == end_time) { |
| MemorySpaceAssignment::Allocation* allocation = |
| GetLiveAllocationAt(*allocations, end_time); |
| CHECK_NE(allocation, nullptr); |
| allocation->AddUse(use); |
| return true; |
| } |
| |
| VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " (" |
| << start_time << ", " << end_time |
| << ") latest prefetch = " << latest_prefetch_time |
| << " last use = " << last_use_time << " use = " << use.ToString() |
| << ". Size = " << size |
| << ", def pos = " << defining_position.ToString() |
| << ", operand = " << operand->ToShortString() |
| << (non_bitcast_operand != operand |
| ? ", non_bitcast_operand = " + |
| non_bitcast_operand->ToShortString() |
| : ""); |
| CHECK_LE(start_time, end_time); |
| |
| // There could be a requirement to pin this buffer to default memory either
| // because it is a parameter or an output. If the buffer is a parameter, then
| // we're allowed to prefetch. If the use expects the output to be in default
| // memory, we cannot prefetch it because if we did, it would be in alternate
| // memory instead.
| bool in_default_mem_at_start = RequiredInDefaultMemory(buffer, start_time); |
| bool in_default_mem_at_end = RequiredInDefaultMemory(buffer, end_time); |
| |
| // First try keeping the allocation entirely in the alternate memory. |
| if (!in_default_mem_at_start && !in_default_mem_at_end && |
| TryAllocatingInAlternateMemoryNoCopy( |
| start_time, end_time, last_use_time, defining_position, use, |
| alternate_mem_interval, non_bitcast_operand, allocations)) { |
| return true; |
| } |
| |
| auto prev_allocation_it = allocations->rbegin(); |
| // Find a previous allocation that is in the default memory space (not |
| // necessarily the very last allocation). |
| auto prev_allocation_in_default_mem_it = std::find_if( |
| allocations->rbegin(), allocations->rend(), [&](const auto& allocation) { |
| return allocation->memory_space() == MemorySpace::kDefault && |
| allocation->defining_position() == defining_position; |
| }); |
| |
| if (prev_allocation_in_default_mem_it == allocations->rend() && |
| prev_allocation_it != allocations->rend() && |
| (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate && |
| (*prev_allocation_it)->defining_position() == defining_position) { |
| // If there was an allocation for this HloValue that was in the alternate |
| // memory space, we also need to perform an eviction. |
| int64 eviction_start_time = (*prev_allocation_it)->start_time(); |
| int64 eviction_end_time = (*prev_allocation_it)->end_time(); |
| CHECK(eviction_start_time <= eviction_end_time); |
| |
| int64 preferred_eviction_end_time = std::max( |
| options_.prefetch_interval_picker->PreferredEvictionEndTime( |
| non_bitcast_operand->shape(), eviction_start_time, end_time), |
| eviction_end_time); |
| |
| BufferInterval eviction_mem_interval; |
| eviction_mem_interval.buffer = buffer; |
| eviction_mem_interval.size = size; |
| // Try to reserve a buffer from the end of the previous allocation to the |
| // preferred eviction end time. |
| eviction_mem_interval.start = eviction_end_time + 1; |
| eviction_mem_interval.end = preferred_eviction_end_time; |
| int64 preferred_offset = (*prev_allocation_it)->chunk().offset; |
| VLOG(4) << "Eviction (" << eviction_start_time << ", " << eviction_end_time |
| << ") preferred end time = " << eviction_mem_interval.end; |
| |
| for (; eviction_mem_interval.end > eviction_end_time; |
| --eviction_mem_interval.end) { |
| ChunkCandidate chunk_candidate = |
| FindChunkCandidate(eviction_mem_interval, preferred_offset); |
| if (chunk_candidate.chunk.offset == preferred_offset) { |
| AddToPendingChunks(eviction_mem_interval, chunk_candidate); |
| break; |
| } |
| } |
| eviction_end_time = eviction_mem_interval.end; |
| |
| VLOG(3) << "Evicting buffer at " << (*prev_allocation_it)->chunk().offset |
| << " (" << eviction_start_time << ", " << eviction_end_time << ")"; |
| |
| bool eviction_interval_too_short = |
| (eviction_start_time == eviction_end_time); |
| bool eviction_violates_outstanding_copies = |
| ViolatesMaximumOutstandingAsyncCopies(eviction_start_time, |
| eviction_end_time); |
| |
| // See if this interval would violate the asynchronous copy limit. |
| if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) { |
| (*prev_allocation_it)->Extend(eviction_end_time); |
| AddAsyncCopy(**prev_allocation_it, MemorySpace::kDefault, kDummyChunk, |
| eviction_start_time, (*prev_allocation_it)->end_time(), |
| eviction_end_time, allocations); |
| } else { |
| if (eviction_violates_outstanding_copies) { |
| VLOG(3) << "This violates the maximum async copies."; |
| } else { |
| VLOG(3) << "Eviction interval is too short (" << eviction_start_time |
| << ", " << eviction_end_time << ")."; |
| } |
| // If the original interval violated the limit, try sub-intervals within |
| // this interval. |
| bool eviction_scheduled = false; |
| for (int64 time = eviction_start_time; time < eviction_end_time; ++time) { |
| VLOG(3) << "Try evicting (" << time << ", " << time + 1 << ")"; |
| if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1)) { |
| VLOG(3) << "Eviction successful."; |
| AddAsyncCopy(**prev_allocation_it, MemorySpace::kDefault, kDummyChunk, |
| time, time + 1, time + 1, allocations); |
| eviction_scheduled = true; |
| break; |
| } |
| } |
| |
| if (!eviction_scheduled) { |
| // If the eviction couldn't be scheduled, then fail. This buffer will be |
| // kept in the default memory. |
| VLOG(3) << "Bailing: Could not evict " << use.ToString() |
| << " because we hit the limit of maximum asynchronous copies " |
| << "between " |
| << hlo_live_range_.flattened_instruction_sequence() |
| .instructions()[eviction_start_time] |
| << " and " |
| << hlo_live_range_.flattened_instruction_sequence() |
| .instructions()[eviction_end_time]; |
| return false; |
| } |
| } |
| prev_allocation_in_default_mem_it = allocations->rbegin(); |
| } else if (prev_allocation_in_default_mem_it == allocations->rend()) { |
| allocations->push_back(absl::make_unique<MemorySpaceAssignment::Allocation>( |
| non_bitcast_operand, defining_position, MemorySpace::kDefault, |
| kDummyChunk, start_time, end_time)); |
| prev_allocation_in_default_mem_it = allocations->rbegin(); |
| } |
| |
| CHECK(prev_allocation_in_default_mem_it != allocations->rend()); |
| CHECK((*prev_allocation_in_default_mem_it)->memory_space() == |
| MemorySpace::kDefault); |
| |
| // If the buffer must be in default memory at the end_time, don't prefetch. |
| if (in_default_mem_at_end) { |
| VLOG(4) |
| << "Not trying to prefetch because use requires buffer in default mem."; |
| (*prev_allocation_in_default_mem_it)->Extend(end_time); |
| (*prev_allocation_in_default_mem_it)->AddUse(use); |
| return true; |
| } |
| |
| // Try partially placing the buffer in the alternate space. The time that is |
| // overlapped will be used to asynchronously copy the buffer from the |
| // default memory to the alternate memory. |
| // |
| // start end |
| // time time |
| // X---------------------X |
| // Alternate: +------+ |
| // Default: +---------------------+ |
| // ^ ^ |
| // Copy Copy |
| // Start Done |
| options_.prefetch_interval_picker->Begin(use, start_time, |
| latest_prefetch_time); |
| VLOG(4) << "Trying prefetch picker = " |
| << options_.prefetch_interval_picker->ToDebugString(); |
| while (!options_.prefetch_interval_picker->Done()) { |
| alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); |
| VLOG(4) << "Trying alternate memory allocation (" |
| << alternate_mem_interval.start << ", " |
| << alternate_mem_interval.end << ")"; |
| // If this additional asynchronous copy would violate the limit, try a |
| // different interval. |
| if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start, |
| alternate_mem_interval.end)) { |
| VLOG(4) << "This would violate the outstanding async copy limit."; |
| continue; |
| } |
| if (async_copy_ordering_.ViolatesOrdering(alternate_mem_interval.start, |
| alternate_mem_interval.end)) { |
| VLOG(4) << "This would violate asynchronous copy ordering."; |
| continue; |
| } |
| |
| ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval); |
| // Check if the new heap size fits within limits. |
| if (chunk_candidate.heap_size < available_heap_size()) { |
| VLOG(3) << "Move the buffer to alternate memory at " |
| << alternate_mem_interval.start |
| << ". Offset = " << chunk_candidate.chunk.offset |
| << ", size = " << chunk_candidate.chunk.size |
| << ", heap_size = " << chunk_candidate.heap_size |
| << ", prefetch picker = " |
| << options_.prefetch_interval_picker->ToDebugString(); |
| AddToPendingChunks(alternate_mem_interval, chunk_candidate); |
| |
| AddAsyncCopy(**prev_allocation_in_default_mem_it, MemorySpace::kAlternate, |
| chunk_candidate.chunk, alternate_mem_interval.start, |
| end_time, latest_prefetch_time, allocations); |
| |
| allocations->back()->AddUse(use); |
| return true; |
| } |
| } |
| |
| // If a copy wasn't inserted, then add this use to the latest allocation in |
| // default memory. |
| (*prev_allocation_in_default_mem_it)->Extend(end_time); |
| (*prev_allocation_in_default_mem_it)->AddUse(use); |
| return true; |
| } |
| |
| void AlternateMemoryBestFitHeap::AddAsyncCopy( |
| const MemorySpaceAssignment::Allocation& prev_allocation, |
| MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time, |
| int64 copy_done_schedule_before_time, |
| MemorySpaceAssignment::AllocationSequence* allocations) { |
| VLOG(3) << "Copy to " |
| << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault |
| ? "default" |
| : "alternate") |
| << " memory between " << start_time << " and " |
| << copy_done_schedule_before_time << " keeping until " << end_time; |
| |
| allocations->push_back( |
| absl::make_unique<MemorySpaceAssignment::CopyAllocation>( |
| prev_allocation, memory_space, chunk, start_time, end_time, |
| copy_done_schedule_before_time)); |
| |
| // Register the additional async copy with the interval tree to keep track of |
| // the limit at any given time. |
| pending_async_copies_.push_back({start_time, end_time, memory_space}); |
| } |
| |
| bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( |
| int64 start_time, int64 end_time) const { |
| if (options_.max_outstanding_async_copies < 0) { |
| return false; |
| } |
| |
| // Count both the asynchronous copies in the interval tree as well as the |
| // pending asynchronous copies belonging to this buffer. |
| int64 num_async_copies = |
| async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time) |
| .size(); |
| |
| for (const auto& interval : pending_async_copies_) { |
| if (interval.start_time > start_time && interval.end_time < end_time) { |
| num_async_copies++; |
| } |
| } |
| // Add one because we are checking if adding an additional asynchronous copy |
| // would violate the limit. |
| return num_async_copies + 1 > options_.max_outstanding_async_copies; |
| } |
| |
| bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy( |
| int64 start_time, int64 end_time, int64 last_use_time, |
| HloPosition defining_position, HloUse use, |
| BufferInterval alternate_mem_interval, HloInstruction* non_bitcast_operand, |
| MemorySpaceAssignment::AllocationSequence* allocations) { |
| MemorySpaceAssignment::Allocation* prev_allocation = nullptr; |
| bool can_eliminate_copy = false; |
| if (allocations->empty()) { |
| // There haven't been any allocations for this interval so far. We can
| // eliminate the copy if the value can be placed in the alternate memory.
| can_eliminate_copy = |
| options_.is_allowed_in_alternate_mem_fn(*alternate_mem_interval.buffer); |
| } else { |
| // If there has been a previous allocation, we can eliminate the copy if the |
| // previous allocation was also in the alternate memory. |
| prev_allocation = allocations->back().get(); |
| can_eliminate_copy = |
| (prev_allocation->memory_space() == MemorySpace::kAlternate); |
| } |
| |
| if (!can_eliminate_copy) { |
| return false; |
| } |
| |
| if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( |
| non_bitcast_operand->shape(), start_time + 1, end_time)) { |
| return false; |
| } |
| |
| alternate_mem_interval.start = start_time; |
| |
| // Prefer the offset that was previously used for the previous allocation. |
| int64 preferred_offset = -1; |
| if (prev_allocation != nullptr) { |
| preferred_offset = prev_allocation->chunk().offset; |
| // If there is a previous allocation, set the start time to one past the
| // previous allocation's end time.
| alternate_mem_interval.start = prev_allocation->end_time() + 1; |
| } |
| |
| VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = " |
| << preferred_offset; |
| // In case there are additional uses after this use, we rely on the last use |
| // time to try to reserve a chunk in the heap simulator. This is to prevent |
| // the following scenario: |
| // |
| // +-------+ |
| // / \ |
| // Producer--->Use1 +-->Use2 |
| // +---------+---------+ |
| // New buffer: | | | |
| // +---------+---------+ |
| // |
| // +-----------+ |
| // Current heap: | offset: 0 | |
| // --------------------------+-----------+------ |
| // |
| // Because we allocate buffers greedily, Producer to Use1 segment first, and |
| // then Use1 to Use2 segment, it is possible to allocate the first segment at |
| // an offset that is available for the first segment (e.g. offset 0) but not |
| // for the entire live range. This can result in unnecessary copies. By using |
| // the last use time, we try to find an allocation that is available for the |
| // entire Producer to Use2 range. |
| alternate_mem_interval.end = last_use_time; |
| ChunkCandidate chunk_candidate = |
| FindChunkCandidate(alternate_mem_interval, preferred_offset); |
| alternate_mem_interval.end = end_time; |
| // Check if the new heap size fits within limits. Also ensure if a |
| // preferred offset was provided, that offset was used. |
| if (chunk_candidate.heap_size <= available_heap_size() && |
| (preferred_offset == -1 || |
| preferred_offset == chunk_candidate.chunk.offset)) { |
| VLOG(3) << "Keep the buffer in alternate memory. Offset = " |
| << chunk_candidate.chunk.offset |
| << ", size = " << chunk_candidate.chunk.size |
| << ", heap_size = " << chunk_candidate.heap_size |
| << ", prefetch picker = " |
| << options_.prefetch_interval_picker->ToNoCopyDebugString( |
| non_bitcast_operand->shape(), start_time, end_time); |
| AddToPendingChunks(alternate_mem_interval, chunk_candidate); |
| |
| // If there was a previous allocation, the buffer location is the |
| // same as the previous. Otherwise, it is the operand. |
| if (prev_allocation != nullptr && |
| (prev_allocation->is_copy_allocation() || |
| prev_allocation->defining_position() == defining_position)) { |
| prev_allocation->Extend(end_time); |
| } else { |
| allocations->push_back( |
| absl::make_unique<MemorySpaceAssignment::Allocation>( |
| non_bitcast_operand, defining_position, MemorySpace::kAlternate, |
| chunk_candidate.chunk, start_time, end_time)); |
| } |
| allocations->back()->AddUse(use); |
| return true; |
| } |
| return false; |
| } |
| |
| /*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies( |
| const HloModule& module) { |
| int64 max_copies = 0; |
| int64 current_copies = 0; |
| for (HloInstruction* instruction : |
| module.schedule().sequence(module.entry_computation()).instructions()) { |
| if (instruction->opcode() == HloOpcode::kCopyStart) { |
| current_copies++; |
| } else if (instruction->opcode() == HloOpcode::kCopyDone) { |
| current_copies--; |
| } |
| max_copies = std::max(max_copies, current_copies); |
| } |
| return max_copies; |
| } |
| |
| /*static*/ MemorySpaceAssignment::BufferIntervalCompare |
| MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( |
| const MemorySpaceAssignmentCostAnalysis& cost_analysis) { |
| return [&](const BufferInterval& x, const BufferInterval& y) { |
| // Returns a heuristic value that captures how much putting this tensor in
| // the alternate memory would help if the op is memory bound, or otherwise
| // how far off the op is from memory boundedness. The larger this number,
| // the higher the priority with which it will be placed in the alternate
| // memory.
| auto get_alternate_mem_benefit = |
| [&](const HloInstruction& instruction, |
| float elapsed_time_due_to_alternate_mem) { |
| float elapsed_time_due_to_compute = |
| cost_analysis.GetInstructionElapsedDueToCompute(instruction); |
| float elapsed_time_due_to_memory = |
| cost_analysis.GetInstructionElapsedDueToMemory(instruction); |
| if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { |
| // Memory bound, return how much alternate memory is better. |
| return elapsed_time_due_to_memory - |
| elapsed_time_due_to_alternate_mem; |
| } else { |
| // Compute bound, return how far off are we to memory boundedness. |
| return elapsed_time_due_to_memory - elapsed_time_due_to_compute; |
| } |
| }; |
| |
| auto get_memory_boundedness = [&](const BufferInterval& interval) { |
| const HloInstruction& defining_instruction = |
| *interval.buffer->defining_instruction(); |
| float alternate_mem_benefit = get_alternate_mem_benefit( |
| defining_instruction, cost_analysis.GetInstructionElapsedDueToMemory( |
| defining_instruction, |
| /*operand_in_alternate_mem=*/{}, |
| /*output_in_alternate_mem=*/true)); |
| for (const HloUse& use : interval.buffer->uses()) { |
| float use_alternate_mem_benefit = get_alternate_mem_benefit( |
| *use.instruction, cost_analysis.GetInstructionElapsedDueToMemory( |
| *use.instruction, use.operand_number)); |
| // If the benefit is positive (memory bound), add it to this buffer's |
| // benefit. If the benefit is negative (compute bound), calculate the |
| // maximum. |
| if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { |
| alternate_mem_benefit += use_alternate_mem_benefit; |
| } else { |
| alternate_mem_benefit = |
| std::max(alternate_mem_benefit, use_alternate_mem_benefit); |
| } |
| } |
| return alternate_mem_benefit; |
| }; |
| |
| float x_memory_boundedness = get_memory_boundedness(x); |
| float y_memory_boundedness = get_memory_boundedness(y); |
| if (x_memory_boundedness != y_memory_boundedness) { |
| return x_memory_boundedness > y_memory_boundedness; |
| } |
| // Tie-break if the memory boundedness is the same. |
| return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()( |
| x, y); |
| }; |
| } |
| |
| /*static*/ StatusOr<std::unique_ptr<PresetAssignments>> |
| MemorySpaceAssignment::Run(HloModule* module, const Options& options) { |
| CHECK(module->has_schedule()); |
| VLOG(4) << "Module before memory space assignment: "; |
| XLA_VLOG_LINES(4, module->ToString()); |
| VLOG(4) << "Schedule: " << module->schedule().ToString(); |
| TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module)); |
| |
| const HloComputation* entry_computation = module->entry_computation(); |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range, |
| HloLiveRange::Run(module->schedule(), *alias_analysis, |
| entry_computation)); |
| MemorySpaceAssignment memory_space_assignment( |
| module, options.alternate_memory_space, *hlo_live_range); |
| auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>( |
| &memory_space_assignment.allocation_map_, options, *alias_analysis, |
| *hlo_live_range); |
| |
| HeapSimulator::Options heap_simulator_options; |
| heap_simulator_options.may_reuse_operand_buffers = false; |
| TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
| module->schedule(), *alias_analysis,
| options.size_fn, heap_simulator_options)
| .status());
| |
| TF_RETURN_IF_ERROR(memory_space_assignment.Process()); |
| memory_space_assignment.ScheduleAsynchronousCopies(); |
| TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph()); |
| TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule()); |
| |
| VLOG(4) << "Module after memory space assignment: "; |
| XLA_VLOG_LINES(4, module->ToString()); |
| TF_CHECK_OK(module->schedule().Verify()); |
| VLOG(1) << "Maximum number of outstanding async copies: " |
| << CountMaximumOutstandingAsyncCopies(*module); |
| |
| if (options.verify || VLOG_IS_ON(1)) { |
| TF_RETURN_IF_ERROR(memory_space_assignment.Verify()); |
| } |
| |
| return std::move(memory_space_assignment.preset_assignments_); |
| } |
| |
| void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { |
| HloInstruction* operand = |
| use.instruction->mutable_operand(use.operand_number); |
| // If the use is a tuple, look inside the tuple to find the actual use. |
| for (int64 index : use.operand_index) { |
| if (operand->opcode() != HloOpcode::kTuple) { |
| break; |
| } |
| operand = operand->mutable_operand(index); |
| } |
| |
| // Look through the GetTupleElement(Tuple(...)) pattern to find the
| // instruction that actually produces the value.
| std::function<HloInstruction*(HloInstruction*)> get_simplified_operand; |
| get_simplified_operand = [&](HloInstruction* instruction) { |
| while (instruction->opcode() == HloOpcode::kGetTupleElement) { |
| HloInstruction* operand = |
| get_simplified_operand(instruction->mutable_operand(0)); |
| if (operand->opcode() == HloOpcode::kTuple) { |
| instruction = operand->mutable_operand(instruction->tuple_index()); |
| } else { |
| return instruction; |
| } |
| } |
| return instruction; |
| }; |
| operand = get_simplified_operand(operand); |
| |
| uses_.push_back(use); |
| } |
| |
| Status MemorySpaceAssignment::Allocation::Process( |
| MemorySpaceAssignment* memory_space_assignment) { |
| return Status::OK(); |
| } |
| |
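| // Rebuilds `tuple` with `new_instruction` substituted at `shape_index`,
| // recursing into nested tuples and inserting a bitcast when shapes differ.
| // For example, ReplaceTupleWith(new_instr, tuple, {1}) produces
| // Tuple(GetTupleElement(tuple, 0), new_instr, GetTupleElement(tuple, 2), ...).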
| StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith( |
| HloInstruction* new_instruction, HloInstruction* tuple, |
| ShapeIndex shape_index) { |
| const Shape& tuple_shape = tuple->shape(); |
| CHECK(tuple->shape().IsTuple()) |
| << "ReplaceTupleWith was called for a non-tuple. Tuple = " |
| << tuple->ToString() |
| << ", new_instruction = " << new_instruction->ToString() |
| << ", shape_index = " << shape_index.ToString(); |
| |
| HloComputation* computation = new_instruction->parent(); |
| std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size()); |
| for (int64 i = 0; i < tuple_shape.tuple_shapes_size(); ++i) { |
| const Shape& subshape = tuple_shape.tuple_shapes(i); |
| if (i == shape_index[0]) { |
| // If the subshape is still a tuple, recurse and pass a new shape index |
| // for the one level deeper. |
| if (subshape.IsTuple()) { |
| HloInstruction* get_tuple_element = computation->AddInstruction( |
| HloInstruction::CreateGetTupleElement(subshape, tuple, i)); |
| TF_ASSIGN_OR_RETURN(tuple_args[i], |
| ReplaceTupleWith(new_instruction, get_tuple_element, |
| ShapeIndex(shape_index.begin() + 1, |
| shape_index.end()))); |
| } else { |
| if (subshape != new_instruction->shape()) { |
| VLOG(4) << "Old shape = " << subshape.ToString() |
| << ", new shape = " << new_instruction->shape().ToString() |
| << "; inserting a bitcast."; |
| new_instruction = computation->AddInstruction( |
| HloInstruction::CreateBitcast(subshape, new_instruction)); |
| } |
| tuple_args[i] = new_instruction; |
| } |
| } else { |
| HloInstruction* get_tuple_element = computation->AddInstruction( |
| HloInstruction::CreateGetTupleElement(subshape, tuple, i)); |
| tuple_args[i] = get_tuple_element; |
| } |
| } |
| return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args)); |
| } |
| |
| Status MemorySpaceAssignment::CopyAllocation::Process( |
| MemorySpaceAssignment* memory_space_assignment) { |
| // Copy allocations need to insert asynchronous copy nodes. |
| HloInstruction* producing_instruction = defining_position().instruction; |
| CHECK_NE(producing_instruction, nullptr); |
| |
| Shape shape = defining_position().shape(); |
| CHECK(shape.IsArray()) << "CopyAllocation shape is not an array. Shape = " |
| << shape.ToString() |
| << " position = " << defining_position().shape(); |
| HloComputation* computation = producing_instruction->parent(); |
| |
| // If the instruction we're copying from is a tuple, we (recursively) create |
| // kGetTupleElement instructions and copy that value. Asynchronous copies only |
| // support array types. |
| if (!producing_instruction->shape().IsArray()) { |
| producing_instruction = defining_position().instruction; |
| for (int64 index : defining_position().index) { |
| producing_instruction = |
| computation->AddInstruction(HloInstruction::CreateGetTupleElement( |
| producing_instruction->shape().tuple_shapes(index), |
| producing_instruction, index)); |
| } |
| } |
| copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary( |
| ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), |
| HloOpcode::kCopyStart, producing_instruction)); |
| copy_done_ = computation->AddInstruction( |
| HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); |
| // Update the allocation with the copy done instruction so that if there |
| // are further copies from it, it can find the correct instruction. |
| instruction_ = copy_done_; |
| |
| // Also update the defining position. |
| defining_position_ = HloPosition{copy_done_, {}}; |
| |
| // Replace all the uses with the new copy instruction. |
| for (HloUse use : uses_) { |
| // If the operand is a tuple, we need to descend to the actual instruction |
| // we want to replace. |
| HloInstruction* replacement_instruction; |
| Shape operand_shape = use.instruction->operand(use.operand_number)->shape(); |
| if (operand_shape.IsTuple()) { |
| TF_ASSIGN_OR_RETURN( |
| replacement_instruction, |
| ReplaceTupleWith(copy_done_, |
| use.instruction->mutable_operand(use.operand_number), |
| use.operand_index)); |
| } else if (operand_shape != copy_done_->shape()) { |
| VLOG(4) << "Old shape = " << operand_shape.ToString() |
| << ", new shape = " << copy_done_->shape().ToString() |
| << "; inserting a bitcast."; |
| replacement_instruction = computation->AddInstruction( |
| HloInstruction::CreateBitcast(operand_shape, copy_done_)); |
| } else { |
| replacement_instruction = copy_done_; |
| } |
| TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith( |
| use.operand_number, replacement_instruction)); |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status MemorySpaceAssignment::Process() { |
| // Insert CopyStart/CopyDone pairs. |
| int64 alternate_memory_size = 0; |
| for (auto& buffer_and_sequence : allocation_map_) { |
| for (auto& allocation : buffer_and_sequence.second) { |
| TF_RETURN_IF_ERROR(allocation->Process(this)); |
| // Add the offset and size of the allocation in the alternate memory to
| // the output map. Special case for bitcasts: since a bitcast doesn't
| // define its own buffer, its allocation shouldn't be exported as a preset
| // chunk.
| if (allocation->memory_space() == MemorySpace::kAlternate && |
| allocation->instruction()->opcode() != HloOpcode::kBitcast) { |
| preset_assignments_->add_chunk(allocation->defining_position(), |
| allocation->chunk()); |
| alternate_memory_size = |
| std::max(alternate_memory_size, allocation->chunk().chunk_end()); |
| } |
| } |
| } |
| |
| if (!preset_assignments_->chunks().empty()) { |
| preset_assignments_->add_size(alternate_memory_space_, |
| alternate_memory_size); |
| } |
| |
| if (VLOG_IS_ON(3)) { |
| VLOG(3) << "Exported alternate memory allocations:"; |
| for (auto& pair : preset_assignments_->chunks()) { |
| VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size |
| << "] : " << pair.first.ToString(); |
| } |
| VLOG(3) << "Exported alternate memory sizes:"; |
| for (auto& pair : preset_assignments_->sizes()) { |
| VLOG(3) << " space: " << pair.first << ", size: " << pair.second; |
| } |
| } |
| |
| // Color the pending positions and all of their aliased buffers. |
| TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_)); |
| for (const auto& defining_position_and_chunk : |
| preset_assignments_->chunks()) { |
| const HloPosition& defining_position = defining_position_and_chunk.first; |
| for (auto& buffer : alias_analysis->ComputeBuffersAt( |
| defining_position.instruction, defining_position.index)) { |
| for (auto& value : buffer->values()) { |
| for (auto& position : value->positions()) { |
| VLOG(3) << "Coloring " << position.ToString(); |
| Shape* shape = ShapeUtil::GetMutableSubshape( |
| position.instruction->mutable_shape(), position.index); |
| CHECK(shape->IsArray()) << "Coloring a shape that is not an array: " |
| << position.ToString(); |
| shape->mutable_layout()->set_memory_space(alternate_memory_space_); |
| } |
| } |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| void PresetAssignments::RemoveAssignmentForInstruction( |
| const HloInstruction* instruction) { |
| for (auto& position_and_chunk : chunks_) { |
| const HloPosition& position = position_and_chunk.first; |
| if (position.instruction == instruction) { |
| VLOG(3) << "Removing instruction from preset assignments."; |
|       // Overwrite the removed entry with the last one and pop the back; this |
|       // doesn't preserve the order of the remaining chunks. |
| position_and_chunk = chunks_.back(); |
| chunks_.pop_back(); |
| break; |
| } |
| } |
| } |
| |
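| // Performs limited cleanup on every scheduled computation after assignment: |
| // drops control dependencies, removes instructions left without users, and |
| // forwards tuple operands through GetTupleElement(Tuple(...)) patterns. |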
| Status MemorySpaceAssignment::SimplifyGraph() { |
| for (HloComputation* computation : module_->MakeNonfusionComputations()) { |
| // Parallel computations aren't in the schedule and don't need to be |
| // modified. |
| if (!computations_in_schedule_.contains(computation)) { |
| VLOG(4) << "Not simplifying " << computation->name() |
| << " because it's not in the schedule."; |
| continue; |
| } |
| // Drop control dependencies. Since the computation is already scheduled, we |
| // don't need control dependencies anymore, and having control |
| // predecessors/successors prevents us from removing instructions without |
| // users (HloComputation::IsSafelyRemovable returns false if there are |
| // control dependencies). |
| for (HloInstruction* instruction : |
| computation->MakeInstructionPostOrder()) { |
| TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); |
| } |
|     // We perform limited DCE and forward the tuple operand in patterns like |
|     // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space |
|     // assignment runs late in compilation (after the DCE and arithmetic |
|     // simplification passes), so we don't want to leave redundant code |
|     // behind. Iterate to a fixed point. |
| bool computation_modified = true; |
| while (computation_modified) { |
| computation_modified = false; |
| VLOG(4) << "Running simplify graph loop over " << computation->name(); |
| for (HloInstruction* instruction : |
| computation->MakeInstructionPostOrder()) { |
| if (computation->IsSafelyRemovable(instruction) && |
| instruction->user_count() == 0 && !instruction->HasSideEffect() && |
| instruction != computation->root_instruction()) { |
| VLOG(4) << "Instruction removed: " << instruction->ToString(); |
| // Ensure the exported preset assignments don't contain a reference to |
| // the removed instruction. |
| preset_assignments_->RemoveAssignmentForInstruction(instruction); |
|           // Instead of deleting the instruction from the schedule, replace |
|           // it with a nullptr. This is needed because FixSchedule relies on |
|           // logical times, which are indices into flattened_instructions_, |
|           // for scheduling asynchronous copies. |
| auto instruction_it = |
| absl::c_find(flattened_instructions_, instruction); |
| if (instruction_it != flattened_instructions_.end()) { |
| *instruction_it = nullptr; |
| } |
| TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); |
| computation_modified = true; |
| } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { |
| HloInstruction* operand = instruction->mutable_operand(0); |
| if (operand->opcode() == HloOpcode::kTuple) { |
| HloInstruction* forwarded_instruction = |
| operand->mutable_operand(instruction->tuple_index()); |
| VLOG(4) << "Replacing uses of " << instruction->ToString() |
| << " with " << forwarded_instruction->ToString(); |
| TF_RETURN_IF_ERROR( |
| instruction->ReplaceAllUsesWith(forwarded_instruction)); |
| computation_modified = true; |
| } |
| } |
| } |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
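| // Recursively inserts new_instruction into new_sequence, making sure all of |
| // its operands are inserted first; inserted_instructions guards against |
| // inserting the same instruction twice. |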
| void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted( |
| HloInstruction* new_instruction, HloInstructionSequence* new_sequence, |
| absl::flat_hash_set<HloInstruction*>* inserted_instructions) const { |
| if (inserted_instructions->contains(new_instruction)) { |
| return; |
| } |
| for (HloInstruction* operand : new_instruction->operands()) { |
|     // CopyStart/CopyDone dependencies should already have been inserted; it |
|     // is a red flag if they haven't been. |
| CHECK((operand->opcode() != HloOpcode::kCopyStart && |
| operand->opcode() != HloOpcode::kCopyDone) || |
| inserted_instructions->contains(operand)) |
| << "Inserted instruction " << new_instruction->ToString() |
| << " has un-inserted dependency: " << operand->ToString(); |
| EnsureInstructionAndOperandsInserted(operand, new_sequence, |
| inserted_instructions); |
| } |
| VLOG(4) << "inserting: " << new_instruction->ToShortString(); |
| new_sequence->push_back(new_instruction); |
| inserted_instructions->insert(new_instruction); |
| } |
| |
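| // For each destination memory space, collects the copy allocations, orders |
| // them, and records where their CopyStart/CopyDone instructions belong in |
| // the final schedule via the schedule_after_ and schedule_before_ maps, |
| // which FixSchedule later consumes. |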
| void MemorySpaceAssignment::ScheduleAsynchronousCopies() { |
| for (MemorySpace memory_space : |
| {MemorySpace::kDefault, MemorySpace::kAlternate}) { |
| std::vector<CopyAllocation*> copy_allocations; |
| for (auto& buffer_and_sequence : allocation_map_) { |
| for (auto& allocation : buffer_and_sequence.second) { |
| if (allocation->is_copy_allocation()) { |
| auto copy_allocation = static_cast<CopyAllocation*>(allocation.get()); |
| if (copy_allocation->memory_space() == memory_space) { |
| copy_allocations.push_back(copy_allocation); |
| } |
| } |
| } |
| } |
| |
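|     // Order the copies primarily by when their CopyDone must be scheduled |
|     // and secondarily by when their CopyStart may begin; the stable sort |
|     // keeps ties in a deterministic order. |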
| absl::c_stable_sort( |
| copy_allocations, [](CopyAllocation* first, CopyAllocation* second) { |
| return std::forward_as_tuple(first->copy_done_schedule_before(), |
| first->copy_start_schedule_after()) < |
| std::forward_as_tuple(second->copy_done_schedule_before(), |
| second->copy_start_schedule_after()); |
| }); |
| |
| for (CopyAllocation* copy_allocation : copy_allocations) { |
|       // If the CopyStart doesn't happen to be scheduled within the correct |
|       // computation, delay it until that computation starts. |
| int64 copy_start_schedule_after = |
| copy_allocation->copy_start_schedule_after(); |
|       // Accessing entries of flattened_instructions_ without a nullptr check |
|       // is safe here because this method runs before SimplifyGraph, which is |
|       // what replaces removed instructions with nullptr. |
| while (copy_allocation->instruction()->parent() != |
| flattened_instructions_[copy_start_schedule_after]->parent()) { |
| VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to " |
| << (copy_start_schedule_after + 1) << ") for " |
| << copy_allocation->copy_start()->ToString() |
| << " because it is not in the correct computation."; |
| copy_allocation->set_copy_start_schedule_after( |
| ++copy_start_schedule_after); |
| } |
| |
| schedule_after_[copy_allocation->copy_start_schedule_after()].push_back( |
| copy_allocation->copy_start()); |
| schedule_before_[copy_allocation->copy_done_schedule_before()].push_back( |
| copy_allocation->copy_done()); |
| } |
| } |
| } |
| |
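| // Rebuilds the instruction sequence of every scheduled computation from |
| // flattened_instructions_, splicing in the asynchronous copies recorded by |
| // ScheduleAsynchronousCopies and skipping instructions that SimplifyGraph |
| // removed. |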
| Status MemorySpaceAssignment::FixSchedule() { |
| CHECK(module_->has_schedule()); |
| HloSchedule& schedule = module_->schedule(); |
| for (const HloComputation* computation : |
| module_->MakeNonfusionComputations()) { |
| // Parallel computations aren't in the schedule and don't need to be |
| // modified. |
| if (!computations_in_schedule_.contains(computation)) { |
| VLOG(4) << "Not scheduling " << computation->name() |
| << " because it's not in the schedule."; |
| continue; |
| } |
| CHECK(schedule.is_computation_scheduled(computation)); |
| HloInstructionSequence new_sequence; |
| |
| absl::flat_hash_set<HloInstruction*> inserted_instructions; |
| |
| VLOG(4) << "Scheduling: " << computation->ToString(); |
| |
| for (int64 instruction_index = 0; |
| instruction_index < flattened_instructions_.size(); |
| ++instruction_index) { |
| auto insts_before_iter = schedule_before_.find(instruction_index); |
| if (insts_before_iter != schedule_before_.end()) { |
| for (HloInstruction* new_instruction : insts_before_iter->second) { |
| if (new_instruction->parent() == computation) { |
| VLOG(4) << "before " << instruction_index << ": " |
| << new_instruction->name(); |
| EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence, |
| &inserted_instructions); |
| } |
| } |
| } |
| HloInstruction* instruction = flattened_instructions_[instruction_index]; |
|       // Insert the instruction only if it wasn't deleted (SimplifyGraph |
|       // replaces deleted instructions with nullptr in |
|       // flattened_instructions_) and hasn't already been inserted. Bitcasts |
|       // and tuples are treated specially: they are inserted only as operand |
|       // dependencies of other instructions. |
| if (instruction != nullptr && |
| !inserted_instructions.contains(instruction) && |
| instruction->parent() == computation && |
| instruction->opcode() != HloOpcode::kBitcast && |
| instruction->opcode() != HloOpcode::kTuple) { |
| VLOG(4) << "inst " << instruction_index << ": " << instruction->name(); |
| EnsureInstructionAndOperandsInserted(instruction, &new_sequence, |
| &inserted_instructions); |
| } |
| auto insts_after_iter = schedule_after_.find(instruction_index); |
| if (insts_after_iter != schedule_after_.end()) { |
| for (HloInstruction* new_instruction : insts_after_iter->second) { |
| if (new_instruction->parent() == computation) { |
| VLOG(4) << "after " << instruction_index << ": " |
| << new_instruction->name(); |
| EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence, |
| &inserted_instructions); |
| } |
| } |
| } |
| } |
| // For rare cases where the original sequence is empty, ensure the root |
| // instruction and its dependencies are scheduled. |
| EnsureInstructionAndOperandsInserted(computation->root_instruction(), |
| &new_sequence, &inserted_instructions); |
|     CHECK_EQ(new_sequence.size(), computation->instruction_count()) |
|         << "New sequence for computation " << computation->name() << " has " |
|         << new_sequence.size() << " instructions; expected " |
|         << computation->instruction_count() << "."; |
| schedule.set_sequence(computation, new_sequence); |
| } |
| |
| return Status::OK(); |
| } |
| |
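| // Sanity-checks the produced assignment: walks all preset chunks and uses a |
| // BufferIntervalTree to verify that no two buffers that are live at the same |
| // time were assigned overlapping chunks in the alternate memory. |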
| Status MemorySpaceAssignment::Verify() const { |
| VLOG(3) << "Verifying:"; |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, |
| HloAliasAnalysis::Run(module_)); |
| TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range, |
| HloLiveRange::Run(module_->schedule(), *alias_analysis, |
| module_->entry_computation())); |
| |
| BufferIntervalTree interval_tree; |
| absl::flat_hash_set<int64> seen_buffers; |
| |
| for (const auto& position_and_chunk : preset_assignments_->chunks()) { |
| const HloPosition& position = position_and_chunk.first; |
| const Chunk& chunk = position_and_chunk.second; |
| const HloBuffer& buffer = |
| alias_analysis->GetUniqueBufferAt(position.instruction, position.index); |
| if (seen_buffers.contains(buffer.id())) { |
| continue; |
| } |
| seen_buffers.insert(buffer.id()); |
| |
| int64 start_time = INT64_MAX; |
| int64 end_time = -1; |
| for (const HloValue* value : buffer.values()) { |
| const HloLiveRange::TimeBound& time_bound = |
| hlo_live_range->buffer_live_ranges().at(value); |
| VLOG(3) << " value: " << value->ToShortString() << " (" |
| << time_bound.start << ", " << time_bound.end << ")"; |
| start_time = std::min(start_time, time_bound.start); |
| end_time = std::max(end_time, time_bound.end); |
| } |
| CHECK_GE(start_time, 0); |
| CHECK_GT(end_time, 0); |
|     // Get the chunks that overlap in time and check whether any of them also |
|     // overlap in space. |
| // TODO(berkin): For now checking against end_time - 1 (exclusive), but we |
| // really should check against end_time (inclusive) for cases where the |
| // operand can't share buffer with user (see |
| // HloDataflowAnalysis::CanShareOperandBufferWithUser). |
| for (const Chunk& overlapping_chunk : |
| interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) { |
| if (chunk.OverlapsWith(overlapping_chunk)) { |
| return InternalError( |
| ("Buffer %s (%d, %d) off: %d size: %d overlaps with another chunk" |
| " off: %d size: %d"), |
| buffer.ToString(), start_time, end_time, chunk.offset, chunk.size, |
| overlapping_chunk.offset, overlapping_chunk.size); |
| } |
| } |
| interval_tree.Add(start_time, end_time - 1, chunk); |
| VLOG(3) << " buffer: " << buffer.ToString() << ": (" << start_time << ", " |
| << end_time << ") off: " << position_and_chunk.second.offset |
| << ", size: " << position_and_chunk.second.size; |
| } |
| |
| return Status::OK(); |
| } |
| |
| } // namespace xla |