/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/core/platform/random.h"
namespace xla {
namespace gpu {
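
// Returns true if a stream has already been assigned to `hlo`.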
bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
  return hlo_to_stream_number_.contains(&hlo);
}
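
// Returns the stream number previously assigned to `hlo`. Dies if `hlo` has no
// assigned stream.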
int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
  return FindOrDie(hlo_to_stream_number_, &hlo);
}
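
// Records that `hlo` should run on stream `stream_num`, growing the stream
// count if necessary. Dies if `hlo` already has an assigned stream.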
void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo,
                                         int stream_num) {
  CHECK_GE(stream_num, 0);
  if (stream_num >= stream_count_) {
    stream_count_ = stream_num + 1;
  }
  InsertOrDie(&hlo_to_stream_number_, hlo, stream_num);
  VLOG(2) << "Assign stream #" << stream_num << " to " << hlo->ToString();
}

namespace {

// Returns whether the two HLOs can run concurrently, i.e., neither is a
// transitive consumer of the other.
bool CanRunConcurrently(const HloInstruction& a, const HloInstruction& b,
                        const HloReachabilityMap& reachability) {
  return !reachability.IsConnected(&a, &b);
}
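
// Sentinel stream number indicating that an instruction does not need a
// stream.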
constexpr int kInvalidStreamNum = -1;

// Returns true iff `stream_num` is a valid stream number.
inline bool IsStreamNumValid(int stream_num) {
  return stream_num != kInvalidStreamNum;
}

// Returns which stream to assign to `hlo`, or kInvalidStreamNum if `hlo` does
// not need a stream (it requires no thunk). Since multi-streaming is not
// supported, every instruction that does need a stream is assigned stream 0.
// `stream_assignment` is the existing stream assignment for all instructions
// topologically before `hlo`. `seen_gemms` contains all GEMMs that are
// topologically before `hlo`.
int ComputeStreamToAssign(
    const HloInstruction& hlo, const StreamAssignment& stream_assignment,
    const HloReachabilityMap& reachability,
    const std::vector<const HloInstruction*>& seen_gemms) {
  if (hlo.opcode() == HloOpcode::kParameter ||
      hlo.opcode() == HloOpcode::kConstant) {
    // kParameter and kConstant do not need a thunk.
    return kInvalidStreamNum;
  }

  const auto& debug_options = hlo.GetModule()->config().debug_options();
  if (!debug_options.xla_gpu_disable_multi_streaming()) {
    LOG(ERROR) << "Multi-streaming is not supported; all instructions will be "
                  "assigned to the default stream.";
  }
  return 0;
}

}  // namespace
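
// Assigns a stream number to every instruction in the entry computation of
// `module` that needs one, and returns the resulting StreamAssignment.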
std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
  auto stream_assignment = absl::make_unique<StreamAssignment>();
  const HloComputation& computation = *module.entry_computation();
  std::unique_ptr<HloReachabilityMap> reachability =
      HloReachabilityMap::Build(&computation);
  std::vector<const HloInstruction*> seen_gemms;
  // The execution of different RNG Hlo instructions in the same module updates
  // a common global variable. To avoid a race condition, we simply assign all
  // RNG kernels to the same stream to make them run sequentially.
  //
  // TODO(b/111791052): If we remove such a common variable, we will need to
  // clean up the code here.
  int stream_num_for_rng = kInvalidStreamNum;
  for (const auto* hlo : computation.MakeInstructionPostOrder()) {
    // If we ever enable fusion of RNG instructions, we will need to extend
    // this code to look inside a fused instruction.
    int stream_num = (hlo->opcode() == HloOpcode::kRng &&
                      IsStreamNumValid(stream_num_for_rng))
                         ? stream_num_for_rng
                         : ComputeStreamToAssign(*hlo, *stream_assignment,
                                                 *reachability, seen_gemms);
    if (IsStreamNumValid(stream_num)) {
      stream_assignment->AssignStreamToHlo(hlo, stream_num);
      if (hlo->opcode() == HloOpcode::kRng &&
          !IsStreamNumValid(stream_num_for_rng)) {
        stream_num_for_rng = stream_num;
      }
    }
    if (IsCublasGemm(*hlo) || IsMatrixMultiplication(*hlo)) {
      seen_gemms.push_back(hlo);
    }
  }
  return stream_assignment;
}

}  // namespace gpu
}  // namespace xla