blob: 3f9402843965f9a04a66e27081062a2a342aca96 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include <map>
#include <memory>
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace eager {
// Registers `function_name` as a component function on the remote device
// named by `options.target`.
//
// The FunctionDef for `function_name`, together with every definition
// reachable from it, is looked up in `options.lib_def` (or in `lib_def` when
// `options.lib_def` is null) and sent to the remote worker as a
// RegisterFunctionOp inside an EnqueueRequest. When that request completes,
// an EagerOperation describing the function is appended to `function_data_`
// (regardless of the request status), `*handle` is set to its index, and
// `done` is called with the request status. Errors detected before the
// request is sent are reported through `done` without writing `*handle`.
void EagerClusterFunctionLibraryRuntime::Instantiate(
    const string& function_name, const FunctionLibraryDefinition& lib_def,
    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
    FunctionLibraryRuntime::LocalHandle* handle,
    FunctionLibraryRuntime::DoneCallback done) {
  const tensorflow::AttrTypeMap* attr_types;
  bool is_function = false;
  Status s = tensorflow::AttrTypeMapForOp(function_name.c_str(), &attr_types,
                                          &is_function);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (!is_function) {
    done(errors::Internal(function_name, " is not a function."));
    return;
  }
  auto target = options.target;
  // Owned locally so that every early return below frees it; ownership is
  // handed to the completion callback only once the request is actually sent.
  std::unique_ptr<EagerOperation> op(
      new EagerOperation(ctx_, function_name.c_str(), is_function, attr_types));
  s = op->SetDeviceName(target.c_str());
  if (!s.ok()) {
    done(s);
    return;
  }
  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << target
          << " (this: " << this << ")";
  core::RefCountPtr<eager::EagerClient> eager_client;
  Device* device;
  s = ctx_->FindDeviceFromName(target.c_str(), &device);
  if (!s.ok()) {
    done(s);
    return;
  }
  s = ctx_->GetClient(device, &eager_client);
  if (!s.ok()) {
    done(s);
    return;
  }
  if (eager_client == nullptr) {
    done(errors::InvalidArgument("Could not find eager client for target: ",
                                 target));
    return;
  }
  const FunctionLibraryDefinition& func_lib_def =
      options.lib_def ? *options.lib_def : lib_def;
  // Find() returns null for unknown names; AttrTypeMapForOp only tells us the
  // name is not a registered op, not that the library actually defines it.
  const FunctionDef* fdef = func_lib_def.Find(function_name);
  if (fdef == nullptr) {
    done(errors::InvalidArgument("Could not find function ", function_name,
                                 " in the function library."));
    return;
  }
  // Both are deleted by the completion callback.
  EnqueueRequest* request = new EnqueueRequest;
  EnqueueResponse* response = new EnqueueResponse;
  request->set_context_id(context_id_);
  RegisterFunctionOp* register_function =
      request->add_queue()->mutable_register_function();
  *register_function->mutable_function_def() = *fdef;
  register_function->set_is_component_function(true);
  *register_function->mutable_library() =
      func_lib_def.ReachableDefinitions(*fdef).ToProto();
  eager_client->EnqueueAsync(
      request, response,
      [this, request, response, handle, released_op = op.release(), target,
       client = eager_client.get(), done](const Status& status) {
        {
          mutex_lock l(mu_);
          *handle = function_data_.size();
          function_data_.emplace_back(target, client,
                                      absl::WrapUnique(released_op));
        }
        done(status);
        delete request;
        delete response;
      });
}
// Tensor-argument overload: not supported by this runtime. It always reports
// an Unimplemented status through `done` and never touches `rets`. Only the
// overload taking eager::RemoteTensorHandle arguments is implemented here.
void EagerClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
  const Status unimplemented = errors::Unimplemented("Not implemented");
  done(unimplemented);
}
// Runs the component function identified by `handle` (as produced by
// Instantiate()) on its remote target.
//
// The entries of `*args` are swapped into the outgoing request, so `*args`
// is consumed by this call. `opts.op_id` must be set; it is reused as the
// remote operation id. `done` is called exactly once: either with a local
// validation error, or with the status of the remote EnqueueAsync request.
void EagerClusterFunctionLibraryRuntime::Run(
    const FunctionLibraryRuntime::Options& opts,
    FunctionLibraryRuntime::LocalHandle handle,
    std::vector<eager::RemoteTensorHandle>* args,
    FunctionLibraryRuntime::DoneCallback done) {
  // Copy the entry's fields while holding `mu_`. Instantiate() appends to
  // `function_data_` concurrently, and appending may relocate entries (e.g.
  // if it is a std::vector), so no pointer into it may outlive the lock. The
  // EagerOperation is held through a pointer in the entry, so its own address
  // is not affected by relocation.
  string target;
  EagerClient* eager_client = nullptr;
  EagerOperation* op = nullptr;
  bool valid_handle = false;
  {
    mutex_lock l(mu_);
    // `handle == size()` is already out of range.
    valid_handle = handle < function_data_.size();
    if (valid_handle) {
      const FunctionData& function_data = function_data_[handle];
      target = function_data.target;
      eager_client = function_data.eager_client;
      op = function_data.op.get();
    }
  }
  // Reported outside the lock so `done` cannot re-enter while `mu_` is held.
  if (!valid_handle) {
    done(errors::InvalidArgument("Invalid function handle: ", handle));
    return;
  }
  if (eager_client == nullptr) {
    done(errors::Internal("Could not find eager client"));
    return;
  }
  Device* device;
  Status s = ctx_->FindDeviceFromName(target.c_str(), &device);
  if (!s.ok()) {
    done(errors::Internal("Failed to get device"));
    return;
  }
  if (!opts.op_id.has_value()) {
    done(
        errors::Internal("op_id is not set for remote function: ", op->Name()));
    // Must stop here: continuing would read the empty optional below and
    // invoke `done` a second time.
    return;
  }
  // Deleted by the completion callback.
  eager::EnqueueRequest* request = new eager::EnqueueRequest;
  request->set_context_id(context_id_);
  eager::Operation* remote_op = request->add_queue()->mutable_operation();
  for (size_t i = 0; i < args->size(); ++i) {
    remote_op->add_inputs()->Swap(&(*args)[i]);
  }
  // The remote component function should use the same op_id as its parent
  // multi-device function's in order to get the globally unique op_id
  // generated by the master context.
  remote_op->set_id(opts.op_id.value());
  remote_op->set_is_function(true);
  remote_op->set_is_component_function(true);
  remote_op->set_func_step_id(opts.step_id);
  remote_op->set_name(op->Name());
  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(target);
  // Keep the op's input handles alive until the remote call completes; the
  // matching Unref happens in the callback.
  for (auto input : op->Inputs()) {
    input->Ref();
  }
  // TODO(yujingzhang): Use RemoteExecuteNode once we enable async execution.
  EnqueueResponse* response = new EnqueueResponse;
  eager_client->EnqueueAsync(request, response,
                             [op, request, response, done](const Status& s) {
                               for (auto input : op->Inputs()) {
                                 input->Unref();
                               }
                               done(s);
                               delete request;
                               delete response;
                             });
}
// Asks the remote worker that hosts the function identified by `handle` to
// clean up per-step state for `step_id`. It does this by sending a
// CleanupFunctionOp through StreamingEnqueueAsync. `done` is called exactly
// once, either with a local validation error or with the remote request's
// status.
void EagerClusterFunctionLibraryRuntime::CleanUp(
    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
    FunctionLibraryRuntime::DoneCallback done) {
  // Read the client pointer while holding `mu_`. A concurrent Instantiate()
  // appends to `function_data_`, which may relocate entries, so a pointer
  // into the container must not be used after the lock is released.
  EagerClient* eager_client = nullptr;
  bool valid_handle = false;
  {
    mutex_lock l(mu_);
    // `handle == size()` is already out of range.
    valid_handle = handle < function_data_.size();
    if (valid_handle) {
      eager_client = function_data_[handle].eager_client;
    }
  }
  // Reported outside the lock so `done` cannot re-enter while `mu_` is held.
  if (!valid_handle) {
    done(errors::InvalidArgument("Invalid function handle: ", handle));
    return;
  }
  if (eager_client == nullptr) {
    done(errors::Internal("Could not find eager client"));
    return;
  }
  // Both are deleted by the completion callback.
  eager::EnqueueRequest* request = new eager::EnqueueRequest;
  EnqueueResponse* response = new EnqueueResponse;
  request->set_context_id(context_id_);
  CleanupFunctionOp* cleanup_function =
      request->add_queue()->mutable_cleanup_function();
  cleanup_function->set_step_id(step_id);
  eager_client->StreamingEnqueueAsync(
      request, response, [request, response, done](const Status& status) {
        done(status);
        delete request;
        delete response;
      });
}
// Chooses the distributed function library runtime for `ctx`.
//
// Unless `ctx` lazily copies remote function inputs, this returns the worker
// session's existing cluster FLR. Otherwise it constructs a new
// EagerClusterFunctionLibraryRuntime bound to `context_id` and the session's
// remote device manager.
// NOTE(review): the two branches return pointers with different ownership:
// a freshly `new`-ed object vs. one held by `worker_session`. Presumably the
// caller only takes ownership in the lazy-copy case — confirm at call sites.
DistributedFunctionLibraryRuntime* CreateClusterFLR(
    const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session) {
  if (!ctx->LazyCopyFunctionRemoteInputs()) {
    return worker_session->cluster_flr();
  }
  return new EagerClusterFunctionLibraryRuntime(
      context_id, ctx, worker_session->remote_device_mgr());
}
} // namespace eager
} // namespace tensorflow