blob: 355185953fe56948a3a77ff905a446677e7f7efb [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include <map>
#include <memory>
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
namespace {
void StripDefaultAttributesInRegisterFunctionOp(
RegisterFunctionOp* register_function) {
StripDefaultAttributes(
*OpRegistry::Global(),
register_function->mutable_function_def()->mutable_node_def());
for (auto& function :
*register_function->mutable_library()->mutable_function()) {
StripDefaultAttributes(*OpRegistry::Global(), function.mutable_node_def());
}
}
} // namespace
void EagerClusterFunctionLibraryRuntime::Instantiate(
const string& function_name, const FunctionLibraryDefinition& lib_def,
AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle,
FunctionLibraryRuntime::DoneCallback done) {
auto target = options.target;
auto released_op = std::make_unique<EagerOperation>(ctx_);
Status s =
released_op->Reset(function_name.c_str(), target.c_str(), true, nullptr);
if (!s.ok()) {
done(s);
return;
}
if (!released_op->is_function()) {
done(errors::Internal(function_name, " is not a function."));
return;
}
VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
<< " (this: " << this << ")";
core::RefCountPtr<eager::EagerClient> eager_client;
s = ctx_->GetClient(target, &eager_client);
if (!s.ok()) {
done(s);
return;
}
if (eager_client == nullptr) {
done(errors::InvalidArgument("Could not find eager client for target: ",
target));
return;
}
const FunctionLibraryDefinition& func_lib_def =
options.lib_def ? *options.lib_def : lib_def;
auto request = std::make_shared<EnqueueRequest>();
auto response = std::make_shared<EnqueueResponse>();
request->set_context_id(context_id_);
RegisterFunctionOp* register_function =
request->add_queue()->mutable_register_function();
*register_function->mutable_function_def() =
*func_lib_def.Find(function_name);
register_function->set_is_component_function(true);
*register_function->mutable_library() =
func_lib_def.ReachableDefinitions(register_function->function_def())
.ToProto();
StripDefaultAttributesInRegisterFunctionOp(register_function);
const absl::optional<std::vector<int>>& ret_indices = options.ret_indices;
eager_client->EnqueueAsync(
/*call_opts=*/nullptr, request.get(), response.get(),
[this, request, response, handle, released_op = released_op.release(),
target, ret_indices, eager_client = eager_client.get(),
done](const Status& s) {
{
mutex_lock l(mu_);
*handle = function_data_.size();
function_data_.emplace_back(target, ret_indices, eager_client,
absl::WrapUnique(released_op));
}
done(s);
});
}
void EagerClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
std::vector<FunctionArg> function_args;
for (const auto& tensor : args) {
function_args.push_back(tensor);
}
std::vector<FunctionRet>* function_rets = new std::vector<FunctionRet>;
Run(opts, handle, function_args, function_rets,
[rets, function_rets, done = std::move(done)](const Status& s) {
Status status = s;
if (status.ok()) {
for (const auto& t : *function_rets) {
if (t.index() == 0) {
rets->push_back(absl::get<Tensor>(t));
} else {
status.Update(
errors::Internal("Expect a Tensor as a remote function "
"output but got a TensorShape."));
break;
}
}
}
delete function_rets;
done(status);
});
}
void EagerClusterFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::LocalHandle handle,
gtl::ArraySlice<FunctionArg> args, std::vector<FunctionRet>* rets,
FunctionLibraryRuntime::DoneCallback done) {
FunctionData* function_data = nullptr;
{
mutex_lock l(mu_);
DCHECK_LE(handle, function_data_.size());
function_data = &function_data_[handle];
}
EagerClient* eager_client = function_data->eager_client.get();
if (eager_client == nullptr) {
done(errors::Internal("Could not find eager client"));
return;
}
EagerOperation* op = function_data->op.get();
if (!op->Inputs().empty()) {
done(errors::Internal("Inputs should not be set during instantiation."));
return;
}
auto request = std::make_shared<RunComponentFunctionRequest>();
auto response = std::make_shared<RunComponentFunctionResponse>();
request->set_context_id(context_id_);
eager::Operation* remote_op = request->mutable_operation();
if (function_data->ret_indices.has_value()) {
for (const int ret_index : function_data->ret_indices.value()) {
request->add_output_num(ret_index);
}
}
for (const auto& arg : args) {
if (arg.index() == 0) {
absl::get<Tensor>(arg).AsProtoTensorContent(
remote_op->add_op_inputs()->mutable_tensor());
} else {
remote_op->add_op_inputs()->mutable_remote_handle()->Swap(
absl::get<RemoteTensorHandle*>(arg));
}
}
// The remote component function should use the same op_id as its parent
// multi-device function's in order to get the global unique op_id generated
// by the master context.
if (opts.op_id.has_value()) {
remote_op->set_id(opts.op_id.value());
} else {
remote_op->set_id(kInvalidOpId);
}
remote_op->set_is_function(true);
remote_op->set_is_component_function(true);
remote_op->set_func_step_id(opts.step_id);
remote_op->set_name(op->Name());
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
remote_op->set_device(function_data->target);
CancellationManager* cm = opts.cancellation_manager;
CancellationToken token = 0;
auto call_opts = std::make_shared<CallOptions>();
call_opts->SetTimeout(
ctx_->session_options().config.operation_timeout_in_ms());
if (cm != nullptr) {
token = cm->get_cancellation_token();
const bool already_cancelled = !cm->RegisterCallback(
token,
[call_opts, request, response, done]() { call_opts->StartCancel(); });
if (already_cancelled) {
done(errors::Cancelled("EagerClusterFunctionLibraryRuntime::Run"));
return;
}
}
// Execute component function on remote worker using RunComponentFunction RPC.
// Different from executing remote functions with Enqueue, this method runs
// a function on remote worker without tying up a thread (i.e., pure
// asynchronously).
eager_client->RunComponentFunctionAsync(
call_opts.get(), request.get(), response.get(),
[request, response, rets, call_opts, cm, token,
done = std::move(done)](const Status& s) {
if (cm != nullptr) {
cm->TryDeregisterCallback(token);
}
if (!s.ok()) {
done(s);
return;
}
if (!response->shape().empty() && !response->tensor().empty()) {
done(errors::Internal(
"Both shape and tensor are specified in the same response"));
return;
}
for (const auto& shape : response->shape()) {
rets->push_back(shape);
}
for (const auto& tensor_proto : response->tensor()) {
Tensor t;
if (t.FromProto(tensor_proto)) {
rets->push_back(std::move(t));
} else {
done(errors::Internal("Could not convert tensor proto: ",
tensor_proto.DebugString()));
return;
}
}
done(OkStatus());
});
}
void EagerClusterFunctionLibraryRuntime::CleanUp(
uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
FunctionLibraryRuntime::DoneCallback done) {
FunctionData* function_data = nullptr;
{
mutex_lock l(mu_);
DCHECK_LE(handle, function_data_.size());
function_data = &function_data_[handle];
}
EagerClient* eager_client = function_data->eager_client.get();
if (eager_client == nullptr) {
done(errors::Internal("Could not find eager client"));
return;
}
auto request = std::make_shared<EnqueueRequest>();
auto response = std::make_shared<EnqueueResponse>();
request->set_context_id(context_id_);
CleanupFunctionOp* cleanup_function =
request->add_queue()->mutable_cleanup_function();
cleanup_function->set_step_id(step_id);
// StreamingEnqueueAsync could be blocking when streaming RPC is disabled.
// CleanUp() needs to be non-blocking since it would be invoked inside the
// enqueue done callback of Run(). So we don't use StreamingEnqueueAsync here.
eager_client->EnqueueAsync(
/*call_opts=*/nullptr, request.get(), response.get(),
[request, response, done](const Status& status) { done(status); });
}
DistributedFunctionLibraryRuntime* CreateClusterFLR(
const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
return new EagerClusterFunctionLibraryRuntime(
context_id, ctx, worker_session->remote_device_mgr());
}
} // namespace eager
} // namespace tensorflow