// blob: 692b100752c32acd37a4f997cccf70d5df1fa394 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
namespace xla {
namespace {

// Placeholder chunk whose two fields are both -1. It is attached to
// allocations that live in the default memory space and to entries that only
// exist to keep count of asynchronous copies.
const HeapSimulator::Chunk kDummyChunk = {-1, -1};

}  // namespace
// Collects the buffer intervals colocated with `interval`. It walks a
// worklist seeded with `interval` and visits each item's `colocations`, then
// returns the collected intervals sorted by (start, end).
// NOTE(review): this listing is incomplete. The function-name line after the
// return type is missing; the call site in Finish() suggests it is
// AlternateMemoryBestFitHeap::GetSortedColocatedIntervals. The worklist
// pop/push bookkeeping, the closing of the sort call and the closing braces
// are also not visible. Verify against the full source.
std::vector<const GlobalDecreasingSizeBestFitHeap::BufferInterval*>
const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
std::vector<const BufferInterval*> colocated_intervals;
std::vector<const BufferInterval*> worklist = {&interval};
while (!worklist.empty()) {
const BufferInterval* item = worklist.back();
for (const HloValue* buffer_colocated : item->colocations) {
// Order the result by live range: earlier start first, ties broken by end.
absl::c_sort(colocated_intervals, [&](const BufferInterval* x,
const BufferInterval* y) {
return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
return colocated_intervals;
// Runs the alternate-memory assignment over the sorted buffer intervals and
// returns `result_`.
// For each interval, the checks visible below are:
//   - intervals with `need_allocation` false, buffers already present in
//     `allocation_map_`, and tuple-shaped buffers are special-cased (the
//     control-flow statements following these checks are not visible here);
//   - if any colocated value is a phi, `keep_in_default_memory` is set;
//   - otherwise each non-bitcast use of every colocated value, in sorted use
//     order, is passed to FindAllocation. A failure sets
//     `keep_in_default_memory`, and `definition_time` advances to the use time
//     so later uses can reuse the existing allocation.
// At VLOG level 3, every allocation in `allocation_map_` is dumped.
// NOTE(review): this listing is missing or truncating many lines, including
// the right-hand sides of `sorted_buffer_intervals`, `buffer`,
// `allocation_sequence`, `definition_time`, `use_time` and `last_use_time`,
// the sort comparator body, and all `continue`/closing-brace lines.
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
std::vector<BufferInterval> sorted_buffer_intervals =
VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
<< max_size_in_bytes_
<< ", min prefetch interval = " << min_prefetch_interval_
<< ", max prefetch interval = " << max_prefetch_interval_;
for (auto& interval : sorted_buffer_intervals) {
if (!interval.need_allocation) {
// Skip if we have already allocated for this buffer.
const HloBuffer& buffer =
if (allocation_map_->contains(&buffer)) {
// If the buffer is a tuple, don't use this algorithm for now. The buffers
// that are pointed to by the tuple will still use this algorithm.
// TODO(berkin): Because tuples are cheap to place in the alternate memory
// (they are just pointers) we don't need to use prefetch/evict logic.
if (buffer.values()[0]->shape().IsTuple()) {
VLOG(4) << "Keeping buffer " << buffer.ToString()
<< " in default mem because it is a tuple.";
auto colocated_intervals = GetSortedColocatedIntervals(interval);
bool keep_in_default_memory = false;
for (const BufferInterval* colocated_interval : colocated_intervals) {
const HloValue* value = colocated_interval->buffer;
// If any of the colocated values are phi buffers, we keep them in the
// default memory for now.
if (value->is_phi()) {
keep_in_default_memory = true;
VLOG(4) << "Keeping value " << value->ToShortString()
<< " because it contains a phi node.";
MemorySpaceAssignment::AllocationSequence* allocation_sequence =
// At this point, none of the colocated buffers contain any phi buffers.
for (const BufferInterval* colocated_interval : colocated_intervals) {
if (keep_in_default_memory) {
const HloValue* value = colocated_interval->buffer;
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
// NOTE(review): the following line is garbled in this listing.
int64 definition_time =>defining_instruction());
// Sort the uses by the use time.
std::vector<HloUse> uses = value->uses();
absl::c_sort(uses, [&](HloUse use1, HloUse use2) {
return <;
// Iterate over the uses.
for (HloUse use : uses) {
int64 use_time =;
int64 last_use_time =;
// Bitcasts don't define buffers and don't directly consume buffers.
// Skip allocating buffers for bitcast uses. The uses that feed from
// bitcasts will be handled specially.
if (use.instruction->opcode() != HloOpcode::kBitcast) {
if (!FindAllocation(definition_time, use_time, last_use_time,
value->defining_position(), use, value,
colocated_interval->size, allocation_sequence)) {
// If the allocation finding failed (e.g., due to running out of
// asynchronous copies), then fall back to allocating the buffer
// entirely in the default memory.
keep_in_default_memory = true;
// If there are multiple uses, they can try using the memory
// allocation already at the alternate memory.
definition_time = use_time;
// Debug dump of every allocation: its time range and where it lives.
if (VLOG_IS_ON(3)) {
for (const auto& alloc_pair : *allocation_map_) {
VLOG(3) << "Allocation for " << alloc_pair.first->ToString();
for (const auto& alloc : alloc_pair.second) {
std::string addr_str = ": default";
if (alloc->memory_space() == MemorySpace::kAlternate) {
addr_str = absl::StrCat(": alt ", alloc->chunk().offset);
VLOG(3) << " " << alloc->start_time() << "-" << alloc->end_time()
<< addr_str << ", " << alloc->uses().size() << " uses";
return result_;
// Commits every (interval, chunk) pair queued in `pending_chunks_` via
// CommitChunk. When an async-copy limit is configured
// (max_outstanding_async_copies_ >= 0), it also adds each queued async copy
// interval from `pending_async_copies_` to `async_copy_interval_tree_`.
// NOTE(review): the remaining arguments of the Add() call, any clearing of
// the pending lists, and the closing braces are not visible in this listing.
void AlternateMemoryBestFitHeap::CommitPendingChunks() {
for (auto interval_and_chunk : pending_chunks_) {
VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
<< interval_and_chunk.first.end << " : ["
<< interval_and_chunk.second.chunk.offset << ", "
<< interval_and_chunk.second.chunk.size << "]";
CommitChunk(interval_and_chunk.first, interval_and_chunk.second);
// Also add the pending async copies to the interval tree.
if (max_outstanding_async_copies_ >= 0) {
for (auto interval : pending_async_copies_) {
async_copy_interval_tree_.Add(interval.first, interval.second,
// Queues a (buffer interval, chunk candidate) pair in `pending_chunks_`.
// Queued pairs are committed to the heap by CommitPendingChunks().
// NOTE(review): the closing brace of this function is missing in this listing.
void AlternateMemoryBestFitHeap::AddToPendingChunks(
const BufferInterval& buffer_interval,
const ChunkCandidate& chunk_candidate) {
pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
// Tries to place `buffer` for the use `use` (occurring at `end_time`), with
// the value available from `start_time`. The chosen allocation(s) are
// recorded in `allocations`.
// The attempts visible below are made in this order:
//   1. Keep the value in alternate memory without a copy
//      (TryAllocatingInAlternateMemoryNoCopy).
//   2. Return false if the use is in a different computation from the
//      defining position, or if the use calls computations.
//   3. If the previous allocation is in alternate memory for the same
//      defining position, schedule an eviction to default memory. If the
//      async-copy limit blocks the full range, single-time sub-intervals are
//      tried; if none fits, return false. Otherwise a default-memory
//      allocation is extended or created (those statements are partly
//      missing here).
//   4. Try prefetching into alternate memory. The copy start begins at
//      max(start_time, end_time - max_prefetch_interval_) and is moved later
//      while the interval is long enough.
// All other visible exits return true.
// NOTE(review): many lines are missing or truncated in this listing: the
// `operand` initializer, part of the VLOG(2) expression, the prefetch-loop
// bound, the allocation-construction calls, `continue`/`break` statements and
// closing braces.
bool AlternateMemoryBestFitHeap::FindAllocation(
int64 start_time, int64 end_time, int64 last_use_time,
HloPosition defining_position, HloUse use, const HloValue* buffer,
int64 size, MemorySpaceAssignment::AllocationSequence* allocations) {
HloInstruction* operand =
// If the operand is a bitcast, we look at bitcast's operand until we find a
// non-bitcast operand.
HloInstruction* non_bitcast_operand = operand;
while (non_bitcast_operand->opcode() == HloOpcode::kBitcast) {
non_bitcast_operand = non_bitcast_operand->mutable_operand(0);
// Create an alternate memory interval that starts at the earliest
// possible position, given by max_prefetch_interval.
BufferInterval alternate_mem_interval;
alternate_mem_interval.buffer = buffer;
alternate_mem_interval.size = size;
alternate_mem_interval.start =
std::max(start_time, end_time - max_prefetch_interval_);
alternate_mem_interval.end = end_time;
VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " ("
<< start_time << ", " << end_time << ") last use = " << last_use_time
<< " use = " << use.ToString() << ". Size = " << size
<< ", def pos = " << defining_position.ToString()
<< ", operand = " << operand->ToShortString()
<< (non_bitcast_operand != operand
? ", non_bitcast_operand = " +
: "");
CHECK_LE(start_time, end_time);
// First try keeping the allocation entirely in the alternate memory.
if (TryAllocatingInAlternateMemoryNoCopy(
start_time, end_time, last_use_time, defining_position, use,
alternate_mem_interval, non_bitcast_operand, allocations)) {
return true;
if (defining_position.instruction->parent() != use.instruction->parent() ||
!use.instruction->called_computations().empty()) {
VLOG(3) << "Use is in a different computation or calls a computation.";
// Fail because we do not allow asynchronous copies while in the bodies of
// other computations.
return false;
MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
if (!allocations->empty()) {
prev_allocation = allocations->back().get();
// Since copies couldn't be removed, create an allocation in the default
// memory space.
if (prev_allocation != nullptr &&
prev_allocation->memory_space() == MemorySpace::kAlternate &&
prev_allocation->defining_position() == defining_position) {
// If there was an allocation for this HloValue that was in the alternate
// memory space, we also need to perform an eviction.
// TODO(berkin): For now evictions happen relative to the most recent
// allocation in the alternate memory. We can potentially start evictions
// earlier and end later.
VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
<< prev_allocation->start_time() << ", "
<< prev_allocation->end_time() << ")";
// See if this interval would violate the asynchronous copy limit.
if (!ViolatesMaximumOutstandingAsyncCopies(prev_allocation->start_time(),
prev_allocation->end_time())) {
AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
prev_allocation->start_time(), prev_allocation->end_time(),
} else {
VLOG(3) << "This violates the maximum async copies.";
// If the original interval violated the limit, try sub-intervals within
// this interval.
bool eviction_scheduled = false;
for (int64 time = prev_allocation->start_time();
time <= prev_allocation->end_time(); ++time) {
VLOG(3) << "Try evicting (" << time << ", " << time << ")";
if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) {
VLOG(3) << "Eviction successful.";
AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk,
time, time, allocations);
eviction_scheduled = true;
if (!eviction_scheduled) {
// If the eviction couldn't be scheduled, then fail. This buffer will be
// kept in the default memory.
VLOG(3) << "Bailing: Could not evict " << use.ToString()
<< " because we hit the limit of maximum asynchronous copies "
<< "between "
<< hlo_live_range_.flattened_instruction_sequence()
<< " and "
<< hlo_live_range_.flattened_instruction_sequence()
return false;
} else if (prev_allocation != nullptr &&
prev_allocation->memory_space() == MemorySpace::kDefault &&
prev_allocation->defining_position() == defining_position) {
// If the previous allocation was in the default memory space and was
// defined by the same instruction, extend that. Otherwise, create a new
// allocation.
} else {
non_bitcast_operand, defining_position, MemorySpace::kDefault,
kDummyChunk, start_time, end_time));
// Try partially placing the buffer in the alternate space. The time that is
// overlapped will be used to asynchronously copy the buffer from the
// default memory to the alternate memory.
// start end
// time time
// X---------------------X
// Alternate: +------+
// Default: +---------------------+
// ^ ^
// Copy Copy
// Start Done
for (alternate_mem_interval.start =
std::max(start_time, end_time - max_prefetch_interval_);
alternate_mem_interval.end - alternate_mem_interval.start >
++alternate_mem_interval.start) {
VLOG(4) << "Trying alternate memory allocation ("
<< alternate_mem_interval.start << ", "
<< alternate_mem_interval.end << ")";
// If this additional asynchronous copy would violate the limit, try a
// different interval.
if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
alternate_mem_interval.end)) {
VLOG(4) << "This would violate the outstanding async copy limit.";
ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval);
// Check if the new heap size fits within limits.
if (chunk_candidate.heap_size < max_size_in_bytes_) {
VLOG(3) << "Move the buffer to alternate memory at "
<< alternate_mem_interval.start
<< ". Offset = " << chunk_candidate.chunk.offset
<< ", size = " << chunk_candidate.chunk.size
<< ", heap_size = " << chunk_candidate.heap_size;
AddToPendingChunks(alternate_mem_interval, chunk_candidate);
AddAsyncCopy(*allocations->back().get(), MemorySpace::kAlternate,
chunk_candidate.chunk, alternate_mem_interval.start,
end_time, allocations);
return true;
// If a copy wasn't inserted, then add this use to the latest allocation.
return true;
// Records an asynchronous copy from `prev_allocation` into `memory_space`
// (using `chunk`) between `start_time` and `end_time`. It also queues
// (start_time, end_time) in `pending_async_copies_` so the outstanding-copy
// limit accounts for it.
// NOTE(review): the head of the statement that builds the copy allocation
// (presumably appending to `allocations`) and the closing brace are missing
// from this listing.
void AlternateMemoryBestFitHeap::AddAsyncCopy(
const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time,
MemorySpaceAssignment::AllocationSequence* allocations) {
VLOG(3) << "Copy to "
<< (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
? "default"
: "alternate")
<< " memory between " << start_time << " and " << end_time;
prev_allocation, memory_space, chunk, start_time, end_time));
// Register the additional async copy with the interval tree to keep track of
// the limit at any given time.
pending_async_copies_.emplace_back(start_time, end_time);
// Returns whether adding one more asynchronous copy spanning
// (start_time, end_time) would exceed `max_outstanding_async_copies_`.
// A negative limit disables the check, so the function returns false.
// The count combines two sources:
//   - copies in `async_copy_interval_tree_` overlapping the range;
//   - pending copies whose interval satisfies
//     `second > start_time && first < end_time`.
// NOTE(review): the end of the tree-query expression and the increment inside
// the pending-copies loop are missing from this listing.
bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
int64 start_time, int64 end_time) const {
if (max_outstanding_async_copies_ < 0) {
return false;
// Count both the asynchronous copies in the interval tree as well as the
// pending asynchronous copies belonging to this buffer.
int64 num_async_copies =
async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
for (auto interval : pending_async_copies_) {
if (interval.second > start_time && interval.first < end_time) {
// Add one because we are checking if adding an additional asynchronous copy
// would violate the limit.
return num_async_copies + 1 > max_outstanding_async_copies_;
// Tries to keep the value in alternate memory from `start_time` to `end_time`
// without inserting an asynchronous copy. Returns true on success and false
// otherwise.
// Preconditions checked below:
//   - either there is no previous allocation (the right-hand side of that
//     condition is missing here) or the previous allocation is in alternate
//     memory;
//   - `alternate_mem_interval.start == start_time`.
// If a previous allocation exists, its chunk offset becomes the required
// (preferred) offset, and the search interval starts one step after its end.
// The chunk search temporarily uses `last_use_time` as the interval end, so
// the chosen offset stays free for later uses (see the diagram below).
// On success the chunk is queued via AddToPendingChunks, and the allocation is
// either extended or created (the extending branch body is missing here).
// NOTE(review): several statements and all closing braces are missing from
// this listing.
bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy(
int64 start_time, int64 end_time, int64 last_use_time,
HloPosition defining_position, HloUse use,
BufferInterval alternate_mem_interval, HloInstruction* non_bitcast_operand,
MemorySpaceAssignment::AllocationSequence* allocations) {
MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
bool can_eliminate_copy = false;
if (allocations->empty()) {
// There hasn't been any allocations for this interval so far. We can
// eliminate copy if the value can be placed in the alternate memory.
can_eliminate_copy =
} else {
// If there has been a previous allocation, we can eliminate the copy if the
// previous allocation was also in the alternate memory.
prev_allocation = allocations->back().get();
can_eliminate_copy =
(prev_allocation->memory_space() == MemorySpace::kAlternate);
if (!can_eliminate_copy) {
return false;
if (alternate_mem_interval.start != start_time) {
return false;
// Prefer the offset that was previously used for the previous allocation.
int64 preferred_offset = -1;
if (prev_allocation != nullptr) {
preferred_offset = prev_allocation->chunk().offset;
// If there is a previous allocation, set the start time one after the end
// of the previous allocation's end.
alternate_mem_interval.start = prev_allocation->end_time() + 1;
VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = "
<< preferred_offset;
// In case there are additional uses after this use, we rely on the last use
// time to try to reserve a chunk in the heap simulator. This is to prevent
// the following scenario:
// +-------+
// / \
// Producer--->Use1 +-->Use2
// +---------+---------+
// New buffer: | | |
// +---------+---------+
// +-----------+
// Current heap: | offset: 0 |
// --------------------------+-----------+------
// Because we allocate buffers greedily, Producer to Use1 segment first, and
// then Use1 to Use2 segment, it is possible to allocate the first segment at
// an offset that is available for the first segment (e.g. offset 0) but not
// for the entire live range. This can result in unnecessary copies. By using
// the last use time, we try to find an allocation that is available for the
// entire Producer to Use2 range.
alternate_mem_interval.end = last_use_time;
ChunkCandidate chunk_candidate =
FindChunkCandidate(alternate_mem_interval, preferred_offset);
// Restore the real end time; only the chunk search used last_use_time.
alternate_mem_interval.end = end_time;
// Check if the new heap size fits within limits. Also ensure if a
// preferred offset was provided, that offset was used.
if (chunk_candidate.heap_size < max_size_in_bytes_ &&
(preferred_offset == -1 ||
preferred_offset == chunk_candidate.chunk.offset)) {
VLOG(3) << "Keep the buffer in alternate memory. Offset = "
<< chunk_candidate.chunk.offset
<< ", size = " << chunk_candidate.chunk.size
<< ", heap_size = " << chunk_candidate.heap_size;
AddToPendingChunks(alternate_mem_interval, chunk_candidate);
// If there was a previous allocation, the buffer location is the
// same as the previous. Otherwise, it is the operand.
if (prev_allocation != nullptr &&
(prev_allocation->is_copy_allocation() ||
prev_allocation->defining_position() == defining_position)) {
} else {
non_bitcast_operand, defining_position, MemorySpace::kAlternate,
chunk_candidate.chunk, start_time, end_time));
return true;
return false;
// Walks the entry computation's scheduled instruction sequence and returns
// the largest value reached by `current_copies`. This is intended as the
// maximum number of CopyStart/CopyDone pairs in flight at once.
// NOTE(review): the statements that update `current_copies` in the CopyStart
// and CopyDone branches (presumably an increment and a decrement) are missing
// from this listing, as are the closing braces. As shown, nothing modifies
// `current_copies`.
/*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(
const HloModule& module) {
int64 max_copies = 0;
int64 current_copies = 0;
for (HloInstruction* instruction :
module.schedule().sequence(module.entry_computation()).instructions()) {
if (instruction->opcode() == HloOpcode::kCopyStart) {
} else if (instruction->opcode() == HloOpcode::kCopyDone) {
max_copies = std::max(max_copies, current_copies);
return max_copies;
// Entry point of memory space assignment. The function-name line is missing
// from this listing; it is presumably MemorySpaceAssignment::Run.
// Visible steps:
//   1. Run alias analysis on `module`.
//   2. Build a MemorySpaceAssignment and an HloLiveRange over the schedule.
//   3. Run the heap simulator with an AlternateMemoryBestFitHeap configured
//      from the size, prefetch-interval and async-copy-limit arguments.
//   4. Return the collected preset assignments.
// NOTE(review): several lines are missing, including the end of the
// HloLiveRange::Run call, the definition of `is_allowed_in_alternate_mem`,
// the end of the HeapSimulator::Run call, anything between the simulation and
// the "after" module dump, and the closing brace.
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes,
int64 min_prefetch_interval, int64 max_prefetch_interval,
int64 alternate_memory_space_alignment_in_bytes,
BufferValue::SizeFunction size_fn,
int64 max_outstanding_async_copies) {
VLOG(4) << "Module before memory space assignment: ";
XLA_VLOG_LINES(4, module->ToString());
VLOG(4) << "Schedule: " << module->schedule().ToString();
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
MemorySpaceAssignment memory_space_assignment(module, alternate_memory_space);
const HloComputation* entry_computation = module->entry_computation();
HloLiveRange::Run(module->schedule(), *alias_analysis,
// TODO(berkin): Explore heap algorithms other than kSpatial.
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
&memory_space_assignment.allocation_map_, max_size_in_bytes,
min_prefetch_interval, max_prefetch_interval, *alias_analysis,
is_allowed_in_alternate_mem, max_outstanding_async_copies);
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
*alias_analysis.get(), size_fn)
VLOG(4) << "Module after memory space assignment: ";
XLA_VLOG_LINES(4, module->ToString());
VLOG(1) << "Maximum number of outstanding async copies: "
<< CountMaximumOutstandingAsyncCopies(*module);
return std::move(memory_space_assignment.preset_assignments_);
// Records `use` on this allocation. If the use's operand is a bitcast, it is
// handled separately from ordinary uses (see the comment below).
// NOTE(review): the `operand` initializer, both branch bodies and the closing
// braces are missing from this listing.
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
HloInstruction* operand =
// When the operand of a use is a bitcast, we place the bitcast in a separate
// data structure.
if (operand->opcode() == HloOpcode::kBitcast) {
} else {
// Processes a non-copy allocation. Only allocations in alternate memory need
// work (the statement doing that work is missing from this listing).
// Returns OK.
Status MemorySpaceAssignment::Allocation::Process(
MemorySpaceAssignment* memory_space_assignment) {
// For non-copy allocations, all we need to do is to update the output memory
// space if placed in the alternate memory.
if (memory_space_ == MemorySpace::kAlternate) {
return Status::OK();
// Materializes this copy allocation in the HLO graph:
//   - adds a CopyStart (tuple of the producer's shape and a U32 scalar) and a
//     CopyDone to the producer's computation;
//   - makes CopyDone this allocation's instruction and sets the defining
//     position to {copy_start_, {0}};
//   - rewires recorded uses and bitcasts to read from CopyDone.
// NOTE(review): some statements are truncated in this listing: the wrapper
// around ReplaceOperandWith, the head of the bitcast-rewiring call, the
// alternate-memory branch body and the closing braces.
Status MemorySpaceAssignment::CopyAllocation::Process(
MemorySpaceAssignment* memory_space_assignment) {
// Copy allocations need to insert asynchronous copy nodes.
HloInstruction* producing_instruction = instruction();
CHECK_NE(producing_instruction, nullptr);
Shape shape = producing_instruction->shape();
HloComputation* computation = producing_instruction->parent();
copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, producing_instruction));
copy_done_ = computation->AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
// Update the allocation with the copy done instruction so that if there
// are further copies from it, it can find the correct instruction.
instruction_ = copy_done_;
// Also update the defining position. Note that the output of CopyDone is
// actually defined in the item {0} of CopyStart.
defining_position_ = HloPosition{copy_start_, {0}};
// Replace all the uses with the new copy instruction.
for (HloUse use : uses_) {
use.instruction->ReplaceOperandWith(use.operand_number, copy_done_));
// Replace all the bitcasts with the new copy instruction. Note that if there
// is a chain of bitcasts, their operands will be replaced with copy done.
// For example:
// a = Foo()
// b = Bitcast(a)
// c = Bitcast(b)
// If a is moved to the alternate memory asynchronously, the graph will be
// changed into:
// a = Foo()
// cs = CopyStart(a)
// cd = CopyDone(cs)
// b = Bitcast(cd)
// c = Bitcast(cd)
// Because of the potential shape change in the operand (b -> cd), we use
// ReplaceOperandWithDifferentShape.
for (HloInstruction* bitcast : bitcasts_) {
/*operand_num=*/0, copy_done_));
if (memory_space_ == MemorySpace::kAlternate) {
return Status::OK();
// Finalizes the assignment:
//   - walks every allocation in `allocation_map_`;
//   - for alternate-memory allocations whose instruction is not a bitcast,
//     tracks the largest chunk end in `alternate_memory_size`;
//   - logs the exported preset chunks and sizes at VLOG level 3;
//   - re-runs alias analysis and visits every position aliased with each
//     entry of `pending_positions_in_alternate_mem_` to color its shape,
//     CHECK-failing on non-array shapes.
// NOTE(review): the per-allocation Process call, the statements that export
// chunks/sizes to `preset_assignments_`, the shape-coloring assignment after
// the CHECK, and the closing braces are missing from this listing.
Status MemorySpaceAssignment::Process() {
// Insert CopyStart/CopyDone pairs.
int64 alternate_memory_size = 0;
for (auto& buffer_and_sequence : allocation_map_) {
for (auto& allocation : buffer_and_sequence.second) {
// Add the offset and size of the allocation in the alternate memory to
// the output map. Special case for bitcast: since bitcast doesn't define
// its own buffer, that shouldn't be exported as a preset chunk.
if (allocation->memory_space() == MemorySpace::kAlternate &&
allocation->instruction()->opcode() != HloOpcode::kBitcast) {
alternate_memory_size =
std::max(alternate_memory_size, allocation->chunk().chunk_end());
if (!preset_assignments_->chunks().empty()) {
if (VLOG_IS_ON(3)) {
VLOG(3) << "Exported alternate memory allocations:";
for (auto& pair : preset_assignments_->chunks()) {
VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size
<< "] : " << pair.first.ToString();
VLOG(3) << "Exported alternate memory sizes:";
for (auto& pair : preset_assignments_->sizes()) {
VLOG(3) << " space: " << pair.first << ", size: " << pair.second;
// Color the pending positions and all of their aliased buffers.
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
for (HloPosition defining_position : pending_positions_in_alternate_mem_) {
for (auto& buffer : alias_analysis->ComputeBuffersAt(
defining_position.instruction, defining_position.index)) {
for (auto& value : buffer->values()) {
for (auto& position : value->positions()) {
VLOG(3) << "Coloring " << position.ToString();
Shape* shape = ShapeUtil::GetMutableSubshape(
position.instruction->mutable_shape(), position.index);
CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
<< position.ToString();
return Status::OK();
// Ensures `new_instruction` is placed in `new_sequence` after its operands.
// Instructions already in `inserted_instructions` are not processed again,
// and each operand is handled recursively before `new_instruction` itself.
// NOTE(review): the early return, the last argument of the recursive call,
// the statements that record/append `new_instruction`, and the closing braces
// are missing from this listing.
void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
if (inserted_instructions->contains(new_instruction)) {
for (HloInstruction* operand : new_instruction->operands()) {
EnsureInstructionAndOperandsInserted(operand, new_sequence,
VLOG(4) << "inserting: " << new_instruction->ToShortString();
// Records `position` as living in the alternate memory space.
// NOTE(review): the body is missing from this listing. The name suggests it
// appends to `pending_positions_in_alternate_mem_`, which Process() iterates;
// confirm against the full source.
void MemorySpaceAssignment::AddPositionInAlternateMemorySpace(
HloPosition position) {
// Adjusts when asynchronous copies start, separately for each destination
// memory space. Copy allocations are sorted by (copy_done_schedule_before,
// copy_start_schedule_after). A CopyStart scheduled earlier than that of the
// previous copy in this order is delayed. A CopyStart is also delayed until
// it falls inside the computation that owns the copied instruction.
// NOTE(review): this listing is missing lines, including the push into
// `copy_allocations`, the head and second operand of the sort, the
// initializers of `flattened_instructions` and `copy_start_schedule_after`,
// the statements that apply each delay, and the closing braces.
void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
// For asynchronous copies of both directions (default to alternate and vice
// versa), sort them by their completion time. Then, if in the sorted order we
// see that the start time is earlier than the start time of an asynchronous
// copy that ends earlier, we delay the start of this. As a result, given
// asynchronous copies that might look like:
// CS CD
// a +-----------+
// b +-----------+
// c +---------+
// We'll first sort by completion time:
// c +---------+
// a +-----------+
// b +-----------+
// Then, delay a because c starts later than a despite also ending earlier:
// c +---------+
// a +---------+
// b +-----------+
for (MemorySpace memory_space :
{MemorySpace::kDefault, MemorySpace::kAlternate}) {
std::vector<CopyAllocation*> copy_allocations;
for (auto& buffer_and_sequence : allocation_map_) {
for (auto& allocation : buffer_and_sequence.second) {
if (allocation->is_copy_allocation()) {
auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
if (copy_allocation->memory_space() == memory_space) {
copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
return std::forward_as_tuple(first->copy_done_schedule_before(),
first->copy_start_schedule_after()) <
CopyAllocation* prev_copy_allocation = nullptr;
for (CopyAllocation* copy_allocation : copy_allocations) {
if (prev_copy_allocation &&
prev_copy_allocation->copy_start_schedule_after() >
copy_allocation->copy_start_schedule_after()) {
VLOG(4) << "Delaying CopyStart ("
<< copy_allocation->copy_start_schedule_after() << " to "
<< prev_copy_allocation->copy_start_schedule_after() << ") for "
<< copy_allocation->copy_start()->ToString() << " because of "
<< prev_copy_allocation->copy_start()->ToString();
// If the copy start doesn't happen to be scheduled at the correct
// computation, delay it until the correct computation starts.
const auto& flattened_instructions =
int64 copy_start_schedule_after =
while (copy_allocation->instruction()->parent() !=
flattened_instructions[copy_start_schedule_after]->parent()) {
VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
<< (copy_start_schedule_after + 1) << ") for "
<< copy_allocation->copy_start()->ToString()
<< " because it is not in the correct computation.";
prev_copy_allocation = copy_allocation;
// Rebuilds the schedule of each non-fusion computation whose instruction
// count differs from its scheduled sequence length. The existing sequence is
// walked by index, and each scheduled instruction is emitted surrounded by the
// instructions registered for that index in `schedule_before_` and
// `schedule_after_` (operands first, via EnsureInstructionAndOperandsInserted).
// The new sequence must cover every instruction of the computation
// (CHECK_EQ); it then replaces the old sequence.
// NOTE(review): the `sequence` initializer, the `continue` statements after
// the skip/parent checks, the last argument of each
// EnsureInstructionAndOperandsInserted call, and the closing braces are
// missing from this listing.
Status MemorySpaceAssignment::FixSchedule() {
HloSchedule& schedule = module_->schedule();
for (const HloComputation* computation :
module_->MakeNonfusionComputations()) {
const HloInstructionSequence& sequence =
HloInstructionSequence new_sequence;
absl::flat_hash_set<HloInstruction*> inserted_instructions;
// Schedule the computations only if needed (if there are unscheduled
// instructions in the computation).
if (computation->instruction_count() ==
schedule.sequence(computation).size()) {
VLOG(4) << "Skip scheduling " << computation->name()
<< " because it is already scheduled.";
VLOG(4) << "Scheduling: " << computation->ToString();
for (int64 instruction_index = 0;
instruction_index < sequence.instructions().size();
++instruction_index) {
HloInstruction* instruction = sequence.instructions()[instruction_index];
if (instruction->parent() != computation) {
auto insts_before_iter = schedule_before_.find(instruction_index);
if (insts_before_iter != schedule_before_.end()) {
for (HloInstruction* new_instruction : insts_before_iter->second) {
EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
// Insert only if not previously inserted.
if (!inserted_instructions.contains(instruction)) {
EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
auto insts_after_iter = schedule_after_.find(instruction_index);
if (insts_after_iter != schedule_after_.end()) {
for (HloInstruction* new_instruction : insts_after_iter->second) {
EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
CHECK_EQ(new_sequence.size(), computation->instruction_count())
<< "New sequence for computation " << computation->name() << " has "
<< new_sequence.size() << " instructions, expects "
<< computation->instruction_count() << ".";
schedule.set_sequence(computation, new_sequence);
return Status::OK();
} // namespace xla