blob: c59464cad70019ba15357b7839933dbe7ebedfc6 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace {
// Builds the string key for one execution of a collective. The key is the
// group key, the instance key, and the frame id and iteration id of `ctx`,
// joined with ':'.
static string CollectiveKey(OpKernelContext* ctx, int32_t group_key,
                            int32_t instance_key) {
  const auto& frame_and_iter = ctx->frame_iter();
  return strings::StrCat(group_key, ":", instance_key, ":",
                         frame_and_iter.frame_id, ":", frame_and_iter.iter_id);
}
// Creates an OpKernel for the op called `name`, using `sub_node` as its
// NodeDef (the node's name and op are both overwritten with `name`).
//
// Returns a null pointer without touching `sub_node` when `name` is empty or
// "Id". If kernel creation reports a non-OK status, the failure is recorded on
// `c` and whatever CreateOpKernel returned is passed back to the caller.
static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
                                               const string& name,
                                               NodeDef* sub_node) {
  const bool needs_kernel = !name.empty() && name != "Id";
  if (!needs_kernel) {
    return nullptr;
  }

  sub_node->set_name(name);
  sub_node->set_op(name);

  Status build_status;
  std::unique_ptr<OpKernel> kernel = CreateOpKernel(
      c->device_type(), c->device(),
      c->device()->GetAllocator(AllocatorAttributes()), *sub_node,
      c->graph_def_version(), &build_status);
  if (!build_status.ok()) {
    c->CtxFailureWithWarning(
        errors::Internal("Failed to build OpKernel for ", name, " : ",
                         build_status.error_message()));
  }
  return kernel;
}
class CollectiveOpV1Kernel : public AsyncOpKernel {
public:
explicit CollectiveOpV1Kernel(OpKernelConstruction* c)
: AsyncOpKernel(c), name_(name()), col_params_(new CollectiveParams()) {}
~CollectiveOpV1Kernel() override { col_params_->Unref(); }
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
name_),
done);
const CancellationToken token =
c->cancellation_manager()->get_cancellation_token();
const bool already_cancelled =
!c->cancellation_manager()->RegisterCallback(token, [col_exec]() {
// We must call StartAbort() within the callback. StartAbort() relies
// on resources that may be deallocated if all execution of a graph is
// finished.
col_exec->StartAbort(errors::Cancelled("op cancelled"));
});
OP_REQUIRES_ASYNC(c, !already_cancelled,
errors::Cancelled("op cancelled ", name_), done);
auto deregister_and_done = [c, token, done = std::move(done)]() {
// Once done() is called, StartAbort() won't have any effect, so we
// don't need to block on the deregistration. Also StartAbort() may call
// done() and DeregisterCallback may deadlock.
c->cancellation_manager()->TryDeregisterCallback(token);
done();
};
ComputeAsyncImpl(c, col_exec, std::move(deregister_and_done));
}
// A string encoding instance, frame and iter to be handed off to
// the implementation for use in generating RecvBuf keys.
string GetCollectiveKey(OpKernelContext* c) {
return CollectiveKey(c, col_params_->group.group_key,
col_params_->instance.instance_key);
}
// Returns false if calling invocation of ComputeAsync should return
// immediately.
bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
const DoneCallback& done) {
if (col_params_->group.group_size > col_params_->group.members.size()) {
// This is the first invocation: Finish initializing col_params_.
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
// blocking work because it's not guaranteed that this call cannot block.
c->collective_executor()->RunClosure([this, c, col_exec, done]() {
VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
<< col_params_->name << " device " << c->device()->name()
<< " group " << col_params_->group.group_key << " instance "
<< col_params_->instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->attributes(), col_params_, c->cancellation_manager(),
[this, c, done](const Status& s) {
if (s.ok()) {
col_params_->instance.impl_details.dependencies = dependencies_;
ComputeAsync(c, done);
} else {
c->SetStatus(s);
done();
}
});
});
return false;
}
return true;
}
protected:
virtual void ComputeAsyncImpl(OpKernelContext* c,
CollectiveExecutor* col_exec,
DoneCallback done) = 0;
string name_;
CollectiveParams* col_params_;
std::vector<int32> dependencies_;
};
class CollectiveGatherOpKernel : public CollectiveOpV1Kernel {
public:
explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
: CollectiveOpV1Kernel(c) {
col_params_->instance.type = GATHER_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_->group.group_size));
OP_REQUIRES(
c, col_params_->group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_->group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_->group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_->instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_->instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_->instance.impl_details.timeout_seconds));
const NodeDef& real_node = c->def();
col_params_->name = strings::StrCat(real_node.name(), ": Gather");
col_params_->group.device_type = c->device_type();
}
protected:
void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
DoneCallback done) override {
auto output_shape = c->input(0).shape();
output_shape.set_dim(
0, output_shape.dim_size(0) * col_params_->group.group_size);
col_params_->instance.shape = output_shape;
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
// GPU kernel.
if (c->mutable_output(0) == nullptr) {
// Allocate the output tensor.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params_->instance.shape, &output), done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, col_params = col_params_, done](const Status& s) {
VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key << " status " << s;
col_params->Unref();
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective "
<< col_params_->name << " device " << c->device()->name()
<< " group " << col_params_->group.group_key << " instance "
<< col_params_->instance.instance_key;
col_params_->Ref();
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveGatherOpKernel);
};
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
CollectiveGatherOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
CollectiveGatherOpKernel);
// V1 reduction kernel. Besides the shared collective parameters it builds two
// helper kernels: the merge op and the final op named by the node's
// attributes.
class CollectiveReduceOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    auto& group = col_params_->group;
    auto& instance = col_params_->instance;
    auto& impl_details = instance.impl_details;

    instance.type = REDUCTION_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &group.group_size));
    OP_REQUIRES(
        c, group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &group.group_key));
    OP_REQUIRES_OK(c, c->GetAttr("instance_key", &instance.instance_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("subdiv_offsets", &impl_details.subdiv_offsets));

    // "Max" and "Min" are mapped to the op names "Maximum" and "Minimum".
    string merge_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
    if (merge_op_name == "Max") {
      merge_op_name = "Maximum";
    } else if (merge_op_name == "Min") {
      merge_op_name = "Minimum";
    }

    string final_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
    OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
                errors::InvalidArgument(
                    "final_op must be one of {\"Id\", \"Div\"} but got ",
                    final_op_name));

    OP_REQUIRES_OK(c, c->GetAttr("T", &instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint", &impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds", &impl_details.timeout_seconds));
    VLOG(2) << "CollectiveReduce instance " << instance.instance_key
            << " merge_op " << merge_op_name << " final_op " << final_op_name
            << " communication_hint " << impl_details.communication_hint
            << " timeout " << impl_details.timeout_seconds;

    const NodeDef& real_node = c->def();
    col_params_->name = strings::StrCat(real_node.name(), ": Reduce(",
                                        merge_op_name, ",", final_op_name, ")");
    group.device_type = c->device_type();

    // Template NodeDef for the helper kernels. The merge op takes two inputs,
    // so the node's first input is listed twice.
    NodeDef sub_node;
    for (int i = 0; i < 2; ++i) {
      sub_node.add_input(real_node.input(0));
    }
    sub_node.set_device(real_node.device());
    SetAttrValue(instance.data_type, &(*sub_node.mutable_attr())["T"]);

    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
    col_params_->merge_op = merge_op_.get();
    col_params_->final_op = final_op_.get();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // The output is allocated on the first pass through this function. This
    // has to happen right away, while still on the executor thread; otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    const bool output_missing = (c->mutable_output(0) == nullptr);
    if (output_missing) {
      // Try to reuse the input buffer for the output.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(c,
                           c->forward_input_or_allocate_output(
                               {0}, 0, c->input(0).shape(), &output),
                           done);
      col_params_->instance.shape = c->input(0).shape();
    }
    if (!CanProceedWithCompute(c, col_exec, done)) {
      return;
    }

    // Completion callback: drops the reference taken below, then reports the
    // status through `done`.
    auto on_executed = [c, col_params = col_params_,
                        done](const Status& status) {
      VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << status;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, status, done);
      done();
    };

    VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), on_executed);
  }

 private:
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
};
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
                        CollectiveReduceOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
                        CollectiveReduceOpKernel);
// V1 broadcast kernel for the sending side (is_source == true).
class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    auto& group = col_params_->group;
    auto& instance = col_params_->instance;
    auto& impl_details = instance.impl_details;

    instance.type = BROADCAST_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &group.group_size));
    OP_REQUIRES(
        c, group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &group.group_key));
    OP_REQUIRES_OK(c, c->GetAttr("instance_key", &instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("shape", &instance.shape));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint", &impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds", &impl_details.timeout_seconds));

    col_params_->is_source = true;
    impl_details.subdiv_offsets = {0};
    col_params_->name =
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
    group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // The output is allocated on the first pass through this function. This
    // has to happen right away, while still on the executor thread; otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    const bool output_missing = (c->mutable_output(0) == nullptr);
    if (output_missing) {
      // Try to reuse the input buffer for the output.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(c,
                           c->forward_input_or_allocate_output(
                               {0}, 0, col_params_->instance.shape, &output),
                           done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) {
      return;
    }

    // The shape declared via the "shape" attribute must agree with the input.
    const bool shape_matches =
        col_params_->instance.shape.IsSameSize(c->input(0).shape());
    OP_REQUIRES_ASYNC(
        c, shape_matches,
        errors::Internal("Declared shape of op ", col_params_->name,
                         " does not match shape of input"),
        done);

    // Completion callback: drops the reference taken below, then reports the
    // status through `done`.
    auto on_executed = [c, col_params = col_params_,
                        done](const Status& status) {
      VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance "
              << col_params->instance.instance_key << " status " << status;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, status, done);
      done();
    };

    VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), on_executed);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
                        CollectiveBcastSendOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_DEFAULT),
                        CollectiveBcastSendOpKernel);
// V1 broadcast kernel for the receiving side (is_source == false). It has no
// tensor input; the output shape comes from the "shape" attribute.
class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel {
 public:
  explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
      : CollectiveOpV1Kernel(c) {
    auto& group = col_params_->group;
    auto& instance = col_params_->instance;
    auto& impl_details = instance.impl_details;

    instance.type = BROADCAST_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &group.group_size));
    OP_REQUIRES(
        c, group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &group.group_key));
    OP_REQUIRES_OK(c, c->GetAttr("instance_key", &instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &instance.data_type));
    OP_REQUIRES_OK(c, c->GetAttr("shape", &instance.shape));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint", &impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds", &impl_details.timeout_seconds));

    col_params_->is_source = false;
    impl_details.subdiv_offsets = {0};
    col_params_->name =
        strings::StrCat(name(), ": Broadcast(", col_params_->is_source, ")");
    group.device_type = c->device_type();
  }

 protected:
  void ComputeAsyncImpl(OpKernelContext* c, CollectiveExecutor* col_exec,
                        DoneCallback done) override {
    // The output is allocated on the first pass through this function. This
    // has to happen right away, while still on the executor thread; otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    const bool output_missing = (c->mutable_output(0) == nullptr);
    if (output_missing) {
      // There is no input to reuse, so the output is always freshly allocated.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          c, c->allocate_output(0, col_params_->instance.shape, &output), done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) {
      return;
    }

    // Completion callback: drops the reference taken below, then reports the
    // status through `done`.
    auto on_executed = [c, col_params = col_params_,
                        done](const Status& status) {
      VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << col_params->group.group_key << " instance_key "
              << col_params->instance.instance_key << " status " << status;
      col_params->Unref();
      OP_REQUIRES_OK_ASYNC(c, status, done);
      done();
    };

    VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
            << col_params_->name << " device " << c->device()->name()
            << " group " << col_params_->group.group_key << " instance "
            << col_params_->instance.instance_key;
    col_params_->Ref();
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), on_executed);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
                        CollectiveBcastRecvOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_DEFAULT),
                        CollectiveBcastRecvOpKernel);
// Shared base class of the V2 collective kernels. Unlike the V1 kernels, the
// group size, group key and instance key arrive as input tensors, so a
// CollectiveParams object is filled per invocation rather than once in the
// constructor.
class CollectiveOpV2Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV2Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
    OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    device_type_ = c->device_type();
  }

 protected:
  // Fills the parts of `col_params` that are common to all V2 ops, *excluding
  // the output shape*. Kernels set any further fields they need afterwards.
  //
  // Returns an Internal error if any of the three tensors has more than zero
  // dimensions, and InvalidArgument if the group size is not positive.
  Status FillCollectiveParams(CollectiveParams* col_params,
                              CollectiveType collective_type,
                              const Tensor& group_size, const Tensor& group_key,
                              const Tensor& instance_key) {
    if (group_size.dims() > 0) {
      return errors::Internal("Unexpected dimensions on input group_size, got ",
                              group_size.shape().DebugString());
    }
    if (group_key.dims() > 0) {
      return errors::Internal("Unexpected dimensions on input group_key, got ",
                              group_key.shape().DebugString());
    }
    if (instance_key.dims() > 0) {
      return errors::Internal(
          "Unexpected dimensions on input instance_key, got ",
          instance_key.shape().DebugString());
    }

    auto& group = col_params->group;
    auto& instance = col_params->instance;

    col_params->name = name_;
    group.device_type = device_type_;
    group.group_size = group_size.unaligned_flat<int32>()(0);
    if (group.group_size <= 0) {
      return errors::InvalidArgument(
          "group_size must be positive integer but got ", group.group_size);
    }
    group.group_key = group_key.unaligned_flat<int32>()(0);

    instance.type = collective_type;
    instance.instance_key = instance_key.unaligned_flat<int32>()(0);
    instance.data_type = data_type_;
    instance.impl_details.communication_hint = communication_hint_;
    instance.impl_details.timeout_seconds = timeout_seconds_;
    return Status::OK();
  }

  // Runs a collective: completes `col_params`, then executes. The output
  // tensor must be allocated before calling this method, and `col_params`
  // must stay alive until `done` is called.
  void Run(OpKernelContext* c, CollectiveParams* col_params,
           DoneCallback done) {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);

    // Parameter resolution is scheduled on a work queue that can handle
    // blocking work, because it is not guaranteed that CompleteParamsAsync
    // cannot block.
    col_exec->RunClosure([c, done = std::move(done), col_params, col_exec]() {
      VLOG(1) << "Collective CompleteParams for " << col_params->name
              << " device " << c->device()->name() << " group "
              << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params,
           col_exec](const Status& params_status) {
            if (!params_status.ok()) {
              c->SetStatus(params_status);
              done();
              return;
            }

            auto actual_done = [c, col_params, done = std::move(done)](
                                   const Status& exec_status) {
              VLOG(1) << "Collective ExecuteAsync done for "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key
                      << " status " << exec_status;
              if (!exec_status.ok()) {
                c->SetStatus(exec_status);
              }
              done();
            };
            VLOG(1) << "Collective ExecuteAsync start for "
                    << col_params->name << " device " << c->device()->name()
                    << " group " << col_params->group.group_key
                    << " instance " << col_params->instance.instance_key;
            col_exec->ExecuteAsync(
                c, col_params,
                CollectiveKey(c, col_params->group.group_key,
                              col_params->instance.instance_key),
                actual_done);
          });
    });
  }

 protected:
  string name_;
  DataType data_type_ = DT_INVALID;
  string communication_hint_;
  float timeout_seconds_ = 0;
  DeviceType device_type_;
};
// V2 reduction kernel. The merge op and final op kernels are built once at
// construction time; group size, group key and instance key are read from
// inputs 1..3 on every invocation.
class CollectiveReduceV2OpKernel : public CollectiveOpV2Kernel {
 public:
  explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
      : CollectiveOpV2Kernel(c) {
    // "Max" and "Min" are mapped to the op names "Maximum" and "Minimum".
    string merge_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
    if (merge_op_name == "Max") {
      merge_op_name = "Maximum";
    } else if (merge_op_name == "Min") {
      merge_op_name = "Minimum";
    }
    string final_op_name;
    OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
    OP_REQUIRES_OK(
        c, c->GetAttr("max_subdivs_per_device", &max_subdivs_per_device_));
    // Prepare OpKernels for reduction and final operations.
    // The merge_op takes two inputs
    NodeDef sub_node;
    sub_node.add_input(c->def().input(0));
    sub_node.add_input(c->def().input(0));
    sub_node.set_device(c->def().device());
    SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, merge_op_name, &sub_node);
    final_op_ = BuildOpKernel(c, final_op_name, &sub_node);
    name_ = strings::StrCat(c->def().name(), ": ReduceV2(", merge_op_name, ",",
                            final_op_name, ")");
    VLOG(2) << "CollectiveReduceV2 " << this << " name " << name_
            << " communication_hint " << communication_hint_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    // From here on `done` has been moved from. Every exit path must go
    // through `done_with_cleanup`, which both notifies the caller and
    // releases `col_params`.
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };
    // Use `done_with_cleanup` on failure, not `done`: calling the moved-from
    // `done` would invoke an empty callback and leak `col_params`.
    OP_REQUIRES_OK_ASYNC(c,
                         FillCollectiveParams(col_params, REDUCTION_COLLECTIVE,
                                              /*group_size*/ c->input(1),
                                              /*group_key*/ c->input(2),
                                              /*instance_key*/ c->input(3)),
                         done_with_cleanup);
    col_params->instance.shape = c->input(0).shape();
    col_params->merge_op = merge_op_.get();
    col_params->final_op = final_op_.get();
    VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;
    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }

 private:
  int max_subdivs_per_device_;
  std::unique_ptr<OpKernel> merge_op_;
  std::unique_ptr<OpKernel> final_op_;
};
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU),
                        CollectiveReduceV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("group_size")
                            .HostMemory("group_key")
                            .HostMemory("instance_key"),
                        CollectiveReduceV2OpKernel);
class CollectiveGatherV2OpKernel : public CollectiveOpV2Kernel {
public:
explicit CollectiveGatherV2OpKernel(OpKernelConstruction* c)
: CollectiveOpV2Kernel(c) {
name_ = strings::StrCat(c->def().name(), ": GatherV2");
VLOG(2) << "CollectiveGatherV2 " << this << " name " << name_
<< " communication_hint " << communication_hint_;
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
auto col_params = new CollectiveParams();
auto done_with_cleanup = [col_params, done = std::move(done)]() {
done();
col_params->Unref();
};
OP_REQUIRES_OK_ASYNC(c,
FillCollectiveParams(col_params, GATHER_COLLECTIVE,
/*group_size*/ c->input(1),
/*group_key*/ c->input(2),
/*instance_key*/
c->input(3)),
done_with_cleanup);
auto output_shape = c->input(0).shape();
output_shape.set_dim(
0, output_shape.dim_size(0) * col_params->group.group_size);
col_params->instance.shape = output_shape;
VLOG(1) << "CollectiveGatherV2 group_size " << col_params->group.group_size
<< " group_key " << col_params->group.group_key << " instance_key "
<< col_params->instance.instance_key;
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params->instance.shape, &output),
done_with_cleanup);
Run(c, col_params, std::move(done_with_cleanup));
}
};
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2").Device(DEVICE_CPU),
CollectiveGatherV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveGatherV2")
.Device(DEVICE_DEFAULT)
.HostMemory("group_size")
.HostMemory("group_key")
.HostMemory("instance_key"),
CollectiveGatherV2OpKernel);
class CollectiveBcastSendV2OpKernel : public CollectiveOpV2Kernel {
public:
explicit CollectiveBcastSendV2OpKernel(OpKernelConstruction* c)
: CollectiveOpV2Kernel(c) {
const bool is_source = true;
name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
}
protected:
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
auto col_params = new CollectiveParams();
auto done_with_cleanup = [col_params, done = std::move(done)]() {
done();
col_params->Unref();
};
OP_REQUIRES_OK_ASYNC(c,
FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
/*group_size*/ c->input(1),
/*group_key*/ c->input(2),
/*instance_key*/ c->input(3)),
done_with_cleanup);
col_params->is_source = true;
col_params->instance.shape = c->input(0).shape();
// Add a default value for subdiv offsets, which is the same as the default
// value in the V1 op's attribute.
col_params->instance.impl_details.subdiv_offsets.push_back(0);
VLOG(1) << "CollectiveBcastSendV2 group_size "
<< col_params->group.group_size << " group_key "
<< col_params->group.group_key << " instance_key "
<< col_params->instance.instance_key;
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(c,
c->forward_input_or_allocate_output(
{0}, 0, col_params->instance.shape, &output),
done_with_cleanup);
Run(c, col_params, std::move(done_with_cleanup));
}
};
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2").Device(DEVICE_CPU),
CollectiveBcastSendV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSendV2")
.Device(DEVICE_DEFAULT)
.HostMemory("group_size")
.HostMemory("group_key")
.HostMemory("instance_key"),
CollectiveBcastSendV2OpKernel);
class CollectiveBcastRecvV2OpKernel : public CollectiveOpV2Kernel {
public:
explicit CollectiveBcastRecvV2OpKernel(OpKernelConstruction* c)
: CollectiveOpV2Kernel(c) {
const bool is_source = false;
name_ = strings::StrCat(name(), ": Broadcast(", is_source, ")");
}
protected:
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
auto col_params = new CollectiveParams();
auto done_with_cleanup = [col_params, done = std::move(done)]() {
done();
col_params->Unref();
};
OP_REQUIRES_OK_ASYNC(c,
FillCollectiveParams(col_params, BROADCAST_COLLECTIVE,
/*group_size*/ c->input(0),
/*group_key*/ c->input(1),
/*instance_key*/ c->input(2)),
done_with_cleanup);
col_params->is_source = false;
TensorShape output_shape;
OP_REQUIRES_OK_ASYNC(c, tensor::MakeShape(c->input(3), &output_shape),
done_with_cleanup);
col_params->instance.shape = output_shape;
// Add a default value for subdiv offsets, which is the same as the default
// value in the V1 op's attribute.
col_params->instance.impl_details.subdiv_offsets.push_back(0);
VLOG(1) << "CollectiveBcastRecvV2 group_size "
<< col_params->group.group_size << " group_key "
<< col_params->group.group_key << " instance_key "
<< col_params->instance.instance_key;
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params->instance.shape, &output),
done_with_cleanup);
Run(c, col_params, std::move(done_with_cleanup));
}
};
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2").Device(DEVICE_CPU),
CollectiveBcastRecvV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecvV2")
.Device(DEVICE_DEFAULT)
.HostMemory("group_size")
.HostMemory("group_key")
.HostMemory("instance_key")
.HostMemory("shape"),
CollectiveBcastRecvV2OpKernel);
/*
 * Resource holding the group information for CollectiveOps.
 * This resource is returned from CollectiveInitializeCommunicatorOpKernel.
 * It generates the next instance key for the group for each collective
 * operation.
 */
// Resource describing one collective group: its key, this member's rank, the
// group size, and the communication hint / timeout to use. It also hands out
// instance keys from an atomic counter that starts at zero.
class CollectiveGroupResource : public ResourceBase {
 public:
  CollectiveGroupResource(int32 group_key, int32 rank, int32 group_size,
                          string communication_hint, float timeout_seconds)
      : group_key_(group_key),
        rank_(rank),
        group_size_(group_size),
        communication_hint_(communication_hint),
        timeout_seconds_(timeout_seconds) {}

  // Human-readable summary of the group key, group size and rank.
  std::string DebugString() const override {
    return absl::StrFormat(
        "Collective Group with group_key = %d, group_size = %d, rank = %d",
        group_key_, group_size_, rank_);
  }

  // Returns the current counter value and advances the counter by one.
  int get_next_instance_key() {
    return instance_key_.fetch_add(1, std::memory_order_relaxed);
  }

  // Plain accessors for the values supplied at construction.
  int32 group_key() const { return group_key_; }
  int32 rank() const { return rank_; }
  int32 group_size() const { return group_size_; }
  string communication_hint() const { return communication_hint_; }
  float timeout_seconds() const { return timeout_seconds_; }

 private:
  int32 group_key_;
  int32 rank_;
  int32 group_size_;
  string communication_hint_;
  std::atomic<int> instance_key_{0};
  float timeout_seconds_ = 0;
};
class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel {
public:
explicit CollectiveInitializeCommunicatorOpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c), device_type_(DEVICE_DEFAULT) {
OP_REQUIRES_OK(c, c->GetAttr("communication_hint", &communication_hint_));
OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
device_type_ = c->device_type();
}
Status CheckInputs(Tensor group_size_t, Tensor group_key_t) {
if (group_size_t.dims() > 0) {
return errors::Internal(
"Unexpected dimensions on input group_size. "
"It shoulbe a scalar, got tensor with shape ",
group_size_t.shape().DebugString());
}
if (group_key_t.dims() > 0) {
return errors::Internal("Unexpected dimensions on input group_key, got ",
group_key_t.shape().DebugString());
}
auto group_size = group_size_t.unaligned_flat<int32>()(0);
if (group_size <= 0) {
return errors::InvalidArgument(
"group_size must be positive integer but got ", group_size);
}
return Status::OK();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
auto group_key_t = c->input(0);
auto rank_t = c->input(1);
auto group_size_t = c->input(2);
OP_REQUIRES_OK_ASYNC(c, CheckInputs(group_size_t, group_key_t), done);
auto group_size = group_size_t.unaligned_flat<int32>()(0);
auto group_key = group_key_t.unaligned_flat<int32>()(0);
auto rank = rank_t.unaligned_flat<int32>()(0);
ResourceHandle resource_handle =
MakeResourceHandle<CollectiveGroupResource>(
c, "collective_op_group", absl::StrFormat("%d", group_key));
Tensor* output_handle = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, TensorShape({}), &output_handle), done);
output_handle->scalar<ResourceHandle>()() = resource_handle;
CollectiveGroupResource* resource = new CollectiveGroupResource(
group_key, rank, group_size, this->communication_hint_,
this->timeout_seconds_);
OP_REQUIRES_OK_ASYNC(
c,
CreateResource<CollectiveGroupResource>(c, resource_handle, resource),
done);
auto group_params = new CollGroupParams();
group_params->device_type = device_type_;
group_params->group_size = resource->group_size();
group_params->group_key = resource->group_key();
auto* col_exec = c->collective_executor();
c->collective_executor()->RunClosure([c, done = std::move(done),
group_params, col_exec]() {
VLOG(1) << "Collective Group initialization for "
<< " device " << c->device()->name() << " group "
<< group_params->group_key;
col_exec->CompleteGroupAsync(
c->device()->attributes(), group_params, c->cancellation_manager(),
[c, done = std::move(done), group_params](const Status& s) {
if (s.ok()) {
VLOG(1) << "Collective Group initialization done for device "
<< c->device()->name() << " group "
<< group_params->group_key << " status " << s;
} else {
c->SetStatus(s);
}
delete group_params;
done();
});
});
}
private:
string communication_hint_;
DeviceType device_type_;
float timeout_seconds_ = 0;
};
// Registers CollectiveInitializeCommunicatorOpKernel as the implementation of
// the "CollectiveInitializeCommunicator" op on CPU.
REGISTER_KERNEL_BUILDER(
Name("CollectiveInitializeCommunicator").Device(DEVICE_CPU),
CollectiveInitializeCommunicatorOpKernel);
// Same kernel class registered for GPU, with the group_size, group_key and
// rank arguments declared as host-memory arguments.
REGISTER_KERNEL_BUILDER(Name("CollectiveInitializeCommunicator")
.Device(DEVICE_GPU)
.HostMemory("group_size")
.HostMemory("group_key")
.HostMemory("rank"),
CollectiveInitializeCommunicatorOpKernel);
// Common base for the V3 collective kernels.
//
// Reads the "T" attribute and the optional "timeout_seconds" attribute at
// construction time, and offers two helpers to subclasses:
// FillCollectiveParams() to populate a CollectiveParams from a
// CollectiveGroupResource, and Run() to resolve the params and execute the
// collective asynchronously.
class CollectiveOpV3Kernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpV3Kernel(OpKernelConstruction* c)
      : AsyncOpKernel(c), name_(name()), device_type_(DEVICE_DEFAULT) {
    OP_REQUIRES_OK(c, c->GetAttr("T", &data_type_));
    if (c->HasAttr("timeout_seconds")) {
      OP_REQUIRES_OK(c, c->GetAttr("timeout_seconds", &timeout_seconds_));
    } else {
      // The op carries no timeout attribute: record that with a negative
      // value so FillCollectiveParams() falls back to the resource's timeout.
      timeout_seconds_ = -1;
    }
    device_type_ = c->device_type();
  }

 protected:
  // Fills common parts of CollectiveParams according to the Op, *excluding
  // output_shape*. Kernels should further work on the CollectiveParams if they
  // need to set additional fields.
  //
  // Returns Unimplemented when `group_assignment` is non-empty; OK otherwise.
  // Every successful call consumes one instance key from `resource`.
  Status FillCollectiveParams(CollectiveParams* col_params,
                              const Tensor& group_assignment,
                              CollectiveType collective_type,
                              CollectiveGroupResource* resource) {
    int64 group_id;
    int64 group_size;
    if (group_assignment.NumElements() == 0) {
      // No group assignments, perform collective as a single group.
      group_id = 0;
      group_size = resource->group_size();
    } else {
      return errors::Unimplemented("Group assignments are not supported yet.");
    }

    // Construct instance key with format:
    // <11 bits for group><21 bits for atomic incremented instance key>
    // NOTE(review): the counter value is not masked to 21 bits here, so it
    // would spill into the group bits once it exceeds that range -- confirm
    // whether that can happen in practice.
    int32 instance_key =
        (group_id << 21) | resource->get_next_instance_key();

    col_params->name = name_;
    col_params->group.device_type = device_type_;
    col_params->group.group_size = group_size;
    col_params->group.group_key = resource->group_key();
    col_params->instance.type = collective_type;
    col_params->instance.instance_key = instance_key;
    col_params->instance.data_type = data_type_;
    col_params->instance.impl_details.communication_hint =
        resource->communication_hint();
    // A positive per-op timeout takes precedence; when the op has none (the
    // attribute is absent or non-positive) use the timeout configured on the
    // communicator resource.
    col_params->instance.impl_details.timeout_seconds =
        timeout_seconds_ > 0 ? timeout_seconds_ : resource->timeout_seconds();
    col_params->run_group_initialization = false;
    return Status::OK();
  }

  // Runs a collective. The output tensor must be allocated before calling this
  // method. col_params must live until done is called.
  //
  // `done` is invoked exactly once on every path; on failure the error is
  // recorded on `c` first.
  void Run(OpKernelContext* c, CollectiveParams* col_params,
           DoneCallback done) {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            name_),
        done);

    // Resolve the collective params.
    // Schedule the `CompleteParamsAsync` call on a work queue that can handle
    // blocking work because it's not guaranteed that this call cannot block.
    col_exec->RunClosure([c, done = std::move(done), col_params, col_exec]() {
      VLOG(1) << "Collective CompleteParams for " << col_params->name
              << " device " << c->device()->name() << " group "
              << col_params->group.group_key << " instance "
              << col_params->instance.instance_key;
      col_exec->CompleteParamsAsync(
          c->device()->attributes(), col_params, c->cancellation_manager(),
          [c, done = std::move(done), col_params, col_exec](const Status& s) {
            if (s.ok()) {
              // Params resolved: launch the collective itself and report its
              // final status through `done`.
              auto actual_done = [c, col_params,
                                  done = std::move(done)](const Status& s) {
                VLOG(1) << "Collective ExecuteAsync done for "
                        << col_params->name << " device " << c->device()->name()
                        << " group " << col_params->group.group_key
                        << " instance " << col_params->instance.instance_key
                        << " status " << s;
                if (!s.ok()) {
                  c->SetStatus(s);
                }
                done();
              };
              VLOG(1) << "Collective ExecuteAsync start for "
                      << col_params->name << " device " << c->device()->name()
                      << " group " << col_params->group.group_key
                      << " instance " << col_params->instance.instance_key;
              col_exec->ExecuteAsync(
                  c, col_params,
                  CollectiveKey(c, col_params->group.group_key,
                                col_params->instance.instance_key),
                  actual_done);
            } else {
              // Param resolution failed: surface the error and finish.
              c->SetStatus(s);
              done();
            }
          });
    });
  }

 protected:
  // Name used in log and error messages; subclasses overwrite it with a more
  // descriptive value.
  string name_;
  // Value of the "T" attribute.
  DataType data_type_ = DT_INVALID;
  // Device type reported by the construction context.
  DeviceType device_type_;
  // Per-op timeout from the "timeout_seconds" attribute, or -1 when the op has
  // no such attribute.
  float timeout_seconds_ = 0;
};
// Kernel for the "CollectiveReduceV3" op.
//
// Input 0 is the tensor to reduce, input 1 is a handle to a
// CollectiveGroupResource and input 2 is the group assignment. The reduction
// is selected through the "reduction" attribute.
class CollectiveReduceV3OpKernel : public CollectiveOpV3Kernel {
 public:
  explicit CollectiveReduceV3OpKernel(OpKernelConstruction* c)
      : CollectiveOpV3Kernel(c) {
    string reduction;
    OP_REQUIRES_OK(c, c->GetAttr("reduction", &reduction));
    // Translate the short attribute values to the kernel names that are passed
    // to BuildOpKernel below.
    if (reduction == "Max") {
      reduction = "Maximum";
    } else if (reduction == "Min") {
      reduction = "Minimum";
    }
    // Prepare OpKernels for reduction and final operations.
    // The merge_op takes two inputs
    NodeDef sub_node;
    sub_node.add_input(c->def().input(0));
    sub_node.add_input(c->def().input(0));
    sub_node.set_device(c->def().device());
    SetAttrValue(data_type_, &(*sub_node.mutable_attr())["T"]);
    merge_op_ = BuildOpKernel(c, reduction, &sub_node);
    final_op_ = BuildOpKernel(c, "Id", &sub_node);
    name_ = strings::StrCat(c->def().name(), ": ReduceV3(", reduction, ")");
    VLOG(2) << "CollectiveReduceV3 " << this << " name " << name_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    // From here on `done` has been moved into `done_with_cleanup`; every exit
    // path must go through `done_with_cleanup` so that the caller's callback
    // is actually invoked and the reference on `col_params` is released.
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };

    core::RefCountPtr<CollectiveGroupResource> resource;
    OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
                         done_with_cleanup);

    Tensor group_assignment = c->input(2);
    OP_REQUIRES_OK_ASYNC(
        c,
        FillCollectiveParams(col_params, group_assignment, REDUCTION_COLLECTIVE,
                             resource.get()),
        done_with_cleanup);

    col_params->instance.shape = c->input(0).shape();
    col_params->merge_op = merge_op_.get();
    col_params->final_op = final_op_.get();
    VLOG(1) << "CollectiveReduceV3 group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;

    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }

 private:
  // Kernel that combines two partial results; built from the "reduction"
  // attribute.
  std::unique_ptr<OpKernel> merge_op_;
  // Kernel built under the name "Id", handed to the collective as final_op.
  std::unique_ptr<OpKernel> final_op_;
};
// Registers CollectiveReduceV3OpKernel as the implementation of the
// "CollectiveReduceV3" op on both CPU and GPU.
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_CPU),
CollectiveReduceV3OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV3").Device(DEVICE_GPU),
CollectiveReduceV3OpKernel);
// Kernel for the "CollectiveAllToAllV3" op.
//
// Input 0 is the data tensor, input 1 is a handle to a
// CollectiveGroupResource and input 2 is the group assignment. The output has
// the same shape as input 0.
class CollectiveAllToAllV3OpKernel : public CollectiveOpV3Kernel {
 public:
  explicit CollectiveAllToAllV3OpKernel(OpKernelConstruction* c)
      : CollectiveOpV3Kernel(c) {
    name_ = strings::StrCat(c->def().name(), ": AllToAllV3");
    VLOG(2) << "CollectiveAllToAllV3 " << this << " name " << name_;
  }

  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    auto col_params = new CollectiveParams();
    // From here on `done` has been moved into `done_with_cleanup`; every exit
    // path must go through `done_with_cleanup` so that the caller's callback
    // is actually invoked and the reference on `col_params` is released.
    auto done_with_cleanup = [col_params, done = std::move(done)]() {
      done();
      col_params->Unref();
    };

    core::RefCountPtr<CollectiveGroupResource> resource;
    OP_REQUIRES_OK_ASYNC(c, LookupResource(c, HandleFromInput(c, 1), &resource),
                         done_with_cleanup);

    Tensor group_assignment = c->input(2);
    OP_REQUIRES_OK_ASYNC(
        c,
        FillCollectiveParams(col_params, group_assignment,
                             ALL_TO_ALL_COLLECTIVE, resource.get()),
        done_with_cleanup);

    col_params->instance.shape = c->input(0).shape();
    VLOG(1) << "CollectiveAllToAll group_size " << col_params->group.group_size
            << " group_key " << col_params->group.group_key << " instance_key "
            << col_params->instance.instance_key;

    // Allocate the output tensor, trying to reuse the input.
    Tensor* output = nullptr;
    OP_REQUIRES_OK_ASYNC(c,
                         c->forward_input_or_allocate_output(
                             {0}, 0, col_params->instance.shape, &output),
                         done_with_cleanup);
    Run(c, col_params, std::move(done_with_cleanup));
  }
};
// Registers CollectiveAllToAllV3OpKernel as the implementation of the
// "CollectiveAllToAllV3" op on both CPU and GPU.
REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_CPU),
CollectiveAllToAllV3OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveAllToAllV3").Device(DEVICE_GPU),
CollectiveAllToAllV3OpKernel);
} // namespace
} // namespace tensorflow