blob: 0230852d0824669f761abb7fdcb743f3d2590ca1 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace {
// Builds the string "<group_key>:<instance_key>:<frame_id>:<iter_id>" for one
// invocation of a collective, taking the frame and iteration ids from
// `ctx->frame_iter()`.
static string CollectiveKey(OpKernelContext* ctx, int32 group_key,
                            int32 instance_key) {
  const auto frame_and_iter = ctx->frame_iter();
  return strings::StrCat(group_key, ":", instance_key, ":",
                         frame_and_iter.frame_id, ":", frame_and_iter.iter_id);
}
// Creates the OpKernel for op `name` on the device of the construction
// context `c`.
//
// `sub_node` serves as the NodeDef of the new kernel; its name and op fields
// are both overwritten with `name`. When `name` is empty or "Id" no kernel is
// built, `sub_node` is left untouched and a null pointer is returned. If
// kernel creation reports a non-OK status, the failure is recorded on `c` as
// an Internal error and whatever CreateOpKernel produced is still returned.
static std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
                                               const string& name,
                                               NodeDef* sub_node) {
  if (name.empty() || name == "Id") return nullptr;

  sub_node->set_name(name);
  sub_node->set_op(name);

  Status build_status;
  std::unique_ptr<OpKernel> kernel = CreateOpKernel(
      c->device_type(), c->device(),
      c->device()->GetAllocator(AllocatorAttributes()), *sub_node,
      c->graph_def_version(), &build_status);
  if (build_status.ok()) return kernel;

  c->CtxFailureWithWarning(
      errors::Internal("Failed to build OpKernel for ", name, " : ",
                       build_status.error_message()));
  return kernel;
}
// Shared base class of the attribute-configured collective kernels in this
// file. It owns the CollectiveParams describing the collective and the logic
// that finishes initializing those params on the first invocation.
class CollectiveOpKernel : public AsyncOpKernel {
 public:
  explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}

  // Returns a string that encodes the group key, instance key, frame and
  // iteration; it is handed to the collective implementation for use in
  // generating RecvBuf keys.
  string GetCollectiveKey(OpKernelContext* c) {
    const int32 group_key = col_params_.group.group_key;
    const int32 instance_key = col_params_.instance.instance_key;
    return CollectiveKey(c, group_key, instance_key);
  }

  // Returns true when `col_params_` already lists at least `group_size`
  // device names, i.e. the caller may go on with the collective.
  //
  // Otherwise returns false, which tells the calling ComputeAsync to return
  // immediately: parameter completion has been scheduled and, once it
  // succeeds, ComputeAsync is invoked again with the same `c` and `done`.
  // If completion fails, the status is set on `c` and `done` is called.
  bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
                             const DoneCallback& done) {
    if (col_params_.group.group_size <=
        col_params_.instance.device_names.size()) {
      return true;
    }

    // First invocation: col_params_ still has to be completed. The
    // `CompleteParamsAsync` call is run through RunClosure, on a work queue
    // that can handle blocking work, because the call is not guaranteed to be
    // non-blocking.
    c->collective_executor()->RunClosure([this, c, done, col_exec]() {
      VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
              << col_params_.name << " device " << c->device()->name()
              << " group " << col_params_.group.group_key << " instance "
              << col_params_.instance.instance_key;

      auto on_params_complete = [this, c, done](const Status& s) {
        if (!s.ok()) {
          c->SetStatus(s);
          done();
          return;
        }
        col_params_.instance.impl_details.dependencies = dependencies_;
        // Re-enter the kernel now that the params are complete.
        ComputeAsync(c, done);
      };
      col_exec->CompleteParamsAsync(c->device()->name(), &col_params_,
                                    c->cancellation_manager(),
                                    on_params_complete);
    });
    return false;
  }

  // Parameters of the collective; derived kernels fill the static parts in
  // their constructors.
  CollectiveParams col_params_;
  // Copied into col_params_.instance.impl_details.dependencies after the
  // params have been completed.
  std::vector<int32> dependencies_;
};
// Kernel for the "CollectiveGather" op. The output has the shape of input 0
// with dimension 0 multiplied by the group size.
class CollectiveGatherOpKernel : public CollectiveOpKernel {
 public:
  // Reads the collective configuration (group_size, group_key, instance_key,
  // T, communication_hint, timeout_seconds) from the node attributes into
  // `col_params_`. Construction fails with InvalidArgument when group_size is
  // not positive.
  explicit CollectiveGatherOpKernel(OpKernelConstruction* c)
      : CollectiveOpKernel(c) {
    col_params_.instance.type = GATHER_COLLECTIVE;
    OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
    OP_REQUIRES(
        c, col_params_.group.group_size > 0,
        errors::InvalidArgument("group_size must be positive integer but got ",
                                col_params_.group.group_size));
    OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
    OP_REQUIRES_OK(
        c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
    OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
    OP_REQUIRES_OK(
        c, c->GetAttr("communication_hint",
                      &col_params_.instance.impl_details.communication_hint));
    OP_REQUIRES_OK(
        c, c->GetAttr("timeout_seconds",
                      &col_params_.instance.impl_details.timeout_seconds));
    const NodeDef& real_node = c->def();
    col_params_.name = strings::StrCat(real_node.name(), ": Gather");
    col_params_.group.device_type = c->device_type();
  }

  // Allocates the gathered output and starts the collective. May be entered a
  // second time for the same `c` once CanProceedWithCompute has completed the
  // params; `done` is invoked exactly once on every path that finishes here.
  void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
    CollectiveExecutor* col_exec = c->collective_executor();
    OP_REQUIRES_ASYNC(
        c, col_exec,
        errors::Internal(
            "Failed to get CollectiveExecutor from OpKernelContext for Op ",
            col_params_.name),
        done);

    auto output_shape = c->input(0).shape();
    // The gather scales dimension 0, so a rank-0 input has nothing to scale.
    // Reject it here as an op error instead of indexing a dimension that does
    // not exist.
    OP_REQUIRES_ASYNC(
        c, output_shape.dims() > 0,
        errors::InvalidArgument("input should have rank > 0, received ",
                                output_shape.dims()),
        done);
    output_shape.set_dim(
        0, output_shape.dim_size(0) * col_params_.group.group_size);
    col_params_.instance.shape = output_shape;

    // Allocate output on the first pass through this function. This must be
    // done immediately, while we're still in the executor thread. Otherwise
    // the memory is not guaranteed to be unused by any concurrently executing
    // GPU kernel.
    if (c->mutable_output(0) == nullptr) {
      // Allocate the output tensor.
      Tensor* output = nullptr;
      OP_REQUIRES_OK_ASYNC(
          c, c->allocate_output(0, col_params_.instance.shape, &output), done);
    }
    if (!CanProceedWithCompute(c, col_exec, done)) return;

    // The keys are captured by value so that the completion log does not read
    // col_params_ from the callback.
    auto actual_done = [c, group_key = col_params_.group.group_key,
                        instance_key = col_params_.instance.instance_key,
                        done](const Status& s) {
      VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective "
              << c->op_kernel().name() << " device " << c->device()->name()
              << " group " << group_key << " instance " << instance_key
              << " status " << s;
      OP_REQUIRES_OK_ASYNC(c, s, done);
      done();
    };
    VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync start for collective "
            << col_params_.name << " device " << c->device()->name()
            << " group " << col_params_.group.group_key << " instance "
            << col_params_.instance.instance_key;
    col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(CollectiveGatherOpKernel);
};
// Register CollectiveGatherOpKernel as the "CollectiveGather" kernel for both
// the CPU and the GPU device types.
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_CPU),
CollectiveGatherOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveGather").Device(DEVICE_GPU),
CollectiveGatherOpKernel);
class CollectiveReduceOpKernel : public CollectiveOpKernel {
public:
explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
col_params_.instance.type = REDUCTION_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(
c, c->GetAttr("subdiv_offsets",
&col_params_.instance.impl_details.subdiv_offsets));
string merge_op_name;
OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
if (merge_op_name == "Max") {
merge_op_name = "Maximum";
} else if (merge_op_name == "Min") {
merge_op_name = "Minimum";
}
string final_op_name;
OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
errors::InvalidArgument(
"final_op must be one of {\"Id\", \"Div\"} but got ",
final_op_name));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("wait_for", &dependencies_));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
VLOG(2) << "CollectiveReduce instance " << col_params_.instance.instance_key
<< " merge_op " << merge_op_name << " final_op " << final_op_name
<< " communication_hint "
<< col_params_.instance.impl_details.communication_hint
<< " timeout " << col_params_.instance.impl_details.timeout_seconds;
const NodeDef& real_node = c->def();
col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
merge_op_name, ",", final_op_name, ")");
col_params_.group.device_type = c->device_type();
// Find the OpKernels by name, type and device type.
NodeDef sub_node;
// The merge_op takes two inputs
sub_node.add_input(real_node.input(0));
sub_node.add_input(real_node.input(0));
sub_node.set_device(real_node.device());
SetAttrValue(col_params_.instance.data_type,
&(*sub_node.mutable_attr())["T"]);
col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
// GPU kernel.
if (c->mutable_output(0) == nullptr) {
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(c,
c->forward_input_or_allocate_output(
{0}, 0, c->input(0).shape(), &output),
done);
col_params_.instance.shape = c->input(0).shape();
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, group_key = col_params_.group.group_key,
instance_key = col_params_.instance.instance_key,
done](const Status& s) {
VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << group_key << " instance " << instance_key
<< " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync start for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
};
// Register CollectiveReduceOpKernel as the "CollectiveReduce" kernel for both
// the CPU and the GPU device types.
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
CollectiveReduceOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
CollectiveReduceOpKernel);
class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
public:
explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};
col_params_.name =
strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
// GPU kernel.
if (c->mutable_output(0) == nullptr) {
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(c,
c->forward_input_or_allocate_output(
{0}, 0, col_params_.instance.shape, &output),
done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
c, col_params_.instance.shape.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
" does not match shape of input"),
done);
auto actual_done = [c, group_key = col_params_.group.group_key,
instance_key = col_params_.instance.instance_key,
done](const Status& s) {
VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << group_key << " instance " << instance_key
<< " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};
// Register CollectiveBcastSendOpKernel as the "CollectiveBcastSend" kernel for
// both the CPU and the GPU device types.
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
CollectiveBcastSendOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
CollectiveBcastSendOpKernel);
class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
public:
explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
: CollectiveOpKernel(c) {
col_params_.instance.type = BROADCAST_COLLECTIVE;
OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
OP_REQUIRES(
c, col_params_.group.group_size > 0,
errors::InvalidArgument("group_size must be positive integer but got ",
col_params_.group.group_size));
OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_.instance.impl_details.communication_hint));
OP_REQUIRES_OK(
c, c->GetAttr("timeout_seconds",
&col_params_.instance.impl_details.timeout_seconds));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};
col_params_.name =
strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
col_params_.group.device_type = c->device_type();
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
// GPU kernel.
if (c->mutable_output(0) == nullptr) {
// No input, so must allocate output.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->allocate_output(0, col_params_.instance.shape, &output), done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, group_key = col_params_.group.group_key,
instance_key = col_params_.instance.instance_key,
done](const Status& s) {
VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective "
<< c->op_kernel().name() << " device " << c->device()->name()
<< " group " << group_key << " instance_key " << instance_key
<< " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
<< col_params_.name << " device " << c->device()->name()
<< " group " << col_params_.group.group_key << " instance "
<< col_params_.instance.instance_key;
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};
// Register CollectiveBcastRecvOpKernel as the "CollectiveBcastRecv" kernel for
// both the CPU and the GPU device types.
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
CollectiveBcastRecvOpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
CollectiveBcastRecvOpKernel);
class CollectiveReduceV2OpKernel : public AsyncOpKernel {
public:
explicit CollectiveReduceV2OpKernel(OpKernelConstruction* c)
: AsyncOpKernel(c) {
col_params_ = std::make_shared<CollectiveParams>();
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_->instance.data_type));
string merge_op_name;
OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
if (merge_op_name == "Max") {
merge_op_name = "Maximum";
} else if (merge_op_name == "Min") {
merge_op_name = "Minimum";
}
string final_op_name;
OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
OP_REQUIRES_OK(
c, c->GetAttr("communication_hint",
&col_params_->instance.impl_details.communication_hint));
// Prepare OpKernels for reduction and final operations.
// The merge_op takes two inputs
NodeDef sub_node;
sub_node.add_input(c->def().input(0));
sub_node.add_input(c->def().input(0));
sub_node.set_device(c->def().device());
SetAttrValue(col_params_->instance.data_type,
&(*sub_node.mutable_attr())["T"]);
col_params_->merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
col_params_->final_op = BuildOpKernel(c, final_op_name, &sub_node);
col_params_->name = strings::StrCat(c->def().name(), ": ReduceV2(",
merge_op_name, ",", final_op_name, ")");
col_params_->group.device_type = c->device_type();
// Add a default value for subdiv offsets, which is the same as the default
// value in the V1 op's attribute.
col_params_->instance.impl_details.subdiv_offsets.push_back(0);
VLOG(2) << "CollectiveReduceV2 " << this << " name " << col_params_->name
<< " communication_hint "
<< col_params_->instance.impl_details.communication_hint;
}
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
CollectiveExecutor* col_exec = c->collective_executor();
OP_REQUIRES_ASYNC(
c, col_exec,
errors::Internal(
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_->name),
done);
const Tensor& input = c->input(0);
const Tensor& group_size = c->input(1);
const Tensor& group_key = c->input(2);
const Tensor& instance_key = c->input(3);
OP_REQUIRES_ASYNC(
c, group_size.dims() == 0,
errors::Internal("Unexpected dimensions on input group_size"), done);
OP_REQUIRES_ASYNC(
c, group_key.dims() == 0,
errors::Internal("Unexpected dimensions on input group_key"), done);
OP_REQUIRES_ASYNC(
c, instance_key.dims() == 0,
errors::Internal("Unexpected dimensions on input instance_key"), done);
auto col_params = std::make_shared<CollectiveParams>();
col_params->name = col_params_->name;
col_params->group.device_type = col_params_->group.device_type;
col_params->group.group_size = group_size.unaligned_flat<int32>()(0);
col_params->group.group_key = group_key.unaligned_flat<int32>()(0);
col_params->instance.type = REDUCTION_COLLECTIVE;
col_params->instance.instance_key = instance_key.unaligned_flat<int32>()(0);
col_params->instance.data_type = col_params_->instance.data_type;
col_params->instance.impl_details.communication_hint =
col_params_->instance.impl_details.communication_hint;
col_params->instance.impl_details.timeout_seconds = 0;
col_params->instance.impl_details.subdiv_offsets =
col_params_->instance.impl_details.subdiv_offsets;
col_params->merge_op = std::move(col_params_->merge_op);
col_params->final_op = std::move(col_params_->final_op);
VLOG(1) << "CollectiveReduceV2 group_size " << col_params->group.group_size
<< " group_key " << col_params->group.group_key << " instance_key "
<< col_params->instance.instance_key;
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
OP_REQUIRES_OK_ASYNC(
c, c->forward_input_or_allocate_output({0}, 0, input.shape(), &output),
done);
col_params->instance.shape = input.shape();
// Store the updated params in this OpKernel.
col_params_ = col_params;
// Resolve the collective params.
// Schedule the `CompleteParamsAsync` call on a work queue that can handle
// blocking work because it's not guaranteed that this call cannot block.
c->collective_executor()->RunClosure([c, done = std::move(done), col_params,
col_exec]() {
VLOG(1) << "CollectiveReduceV2 CompleteParams for collective "
<< col_params->name << " device " << c->device()->name()
<< " group " << col_params->group.group_key << " instance "
<< col_params->instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->name(), col_params.get(), c->cancellation_manager(),
[c, done = std::move(done), col_params, col_exec](const Status& s) {
if (s.ok()) {
auto actual_done = [c, group_key = col_params->group.group_key,
instance_key =
col_params->instance.instance_key,
done = std::move(done)](const Status& s) {
VLOG(1) << "CollectiveReduceV2 ExecuteAsync done for "
"collective "
<< c->op_kernel().name() << " device "
<< c->device()->name() << " group " << group_key
<< " instance " << instance_key << " status " << s;
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
};
VLOG(1) << "CollectiveReduceV2 ExecuteAsync start for "
"collective "
<< col_params->name << " device " << c->device()->name()
<< " group " << col_params->group.group_key
<< " instance " << col_params->instance.instance_key;
col_exec->ExecuteAsync(
c, *col_params,
CollectiveKey(c, col_params->group.group_key,
col_params->instance.instance_key),
actual_done);
} else {
c->SetStatus(s);
done();
}
});
});
}
private:
std::shared_ptr<CollectiveParams> col_params_;
};
// Register CollectiveReduceV2OpKernel as the "CollectiveReduceV2" kernel for
// both the CPU and the GPU device types.
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_CPU),
CollectiveReduceV2OpKernel);
REGISTER_KERNEL_BUILDER(Name("CollectiveReduceV2").Device(DEVICE_GPU),
CollectiveReduceV2OpKernel);
} // namespace
} // namespace tensorflow