/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/collective_nccl_gatherer.h"
#ifdef GOOGLE_CUDA
#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/nccl/nccl_manager.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
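
// Enqueues this device's input/output tensors with the NcclManager's
// all-gather for the group, then coordinates kernel launch with the
// collective executor.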
void NcclGatherer::Run(StatusCallback done) {
  auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
  auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
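  // The gather spans `num_global_devices` across all tasks; `num_local_devices`
  // is the number of participating devices on this task (worker), looked up by
  // this rank's task name.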
  const int num_global_devices = col_params_->group.group_size;
  const int num_local_devices = col_params_->instance.num_devices_per_task.at(
      col_params_->instance.task_names[col_params_->default_rank]);
  string nccl_collective_key =
      NcclCollectiveKey(col_ctx_->exec_key, col_ctx_->step_id);
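  // The Participant bundles this device's compute stream, event manager,
  // input/output tensors and rank.  Ownership passes to the NcclManager, which
  // invokes `done` once the all-gather has completed for this device.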
  auto participant = absl::make_unique<NcclManager::Participant>(
      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
      col_params_->default_rank, std::move(done));
  VLOG(1) << "NcclGatherer calling NcclManager::AddToAllGather num_tasks "
          << col_params_->group.num_tasks << " current task "
          << col_params_->instance.task_names[col_params_->default_rank]
          << " num local devices " << num_local_devices
          << " num global devices " << num_global_devices << " rank "
          << col_params_->default_rank << " device " << col_ctx_->device_name
          << " instance " << col_params_->instance.instance_key;
  NcclManager::instance()->AddToAllGather(
      std::move(participant),
      {std::move(nccl_collective_key), num_local_devices, num_global_devices,
       col_params_->group.runtime_details.communicator_key,
       /*source_rank=*/-1});
  {
    // `WaitForDependencies` may block if the collective instances on which
    // this op depends have not yet launched.  When this function returns, this
    // op is ready to go.
    profiler::TraceMe activity("WaitForDependencies",
                               profiler::TraceMeLevel::kInfo);
    col_ctx_->col_exec->WaitForDependencies(*col_params_);
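    // Signal the NcclManager that this rank is ready to launch the multi-node
    // collective identified by `nccl_collective_key`.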
    NcclManager::instance()->SignalMultiNodeReady(nccl_collective_key);
  }
  {
    // When all devices at this worker have called `SignalMultiNodeReady`, the
    // `NcclManager` will enqueue the NCCL kernel on the NCCL stream.  Thus the
    // implementation of `Launched` keeps track of the number of devices that
    // have launched.
    profiler::TraceMe activity("Schedule", profiler::TraceMeLevel::kInfo);
    col_ctx_->col_exec->Launched(*col_params_);
  }
}

REGISTER_COLLECTIVE(NcclGather, NcclGatherer);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA