blob: fc6e8671e53f6030f6a2ed946165b7ce7e8097f6 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cctype>
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "grpcpp/channel.h"
#include "grpcpp/create_channel.h"
#include "grpcpp/generic/generic_stub.h"
#include "grpcpp/impl/codegen/client_context.h"
#include "grpcpp/impl/codegen/server_context.h"
#include "grpcpp/impl/codegen/status.h"
#include "grpcpp/security/credentials.h"
#include "grpcpp/server_builder.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
// Needed for encoding and decoding ResourceDeleter Variant.
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/struct.pb.h"
#include "tensorflow/distribute/experimental/rpc/kernels/grpc_credentials.h"
#include "tensorflow/distribute/experimental/rpc/kernels/grpc_rpc_service.h"
#include "tensorflow/distribute/experimental/rpc/proto/tf_rpc_service.pb.h"
namespace tensorflow {
namespace rpc {
// Kernel that instantiates the `f` function attr and registers it, under the
// `method_name` input, with the RpcServer resource passed as input 0.
class RpcServerRegisterOp : public OpKernel {
 public:
  explicit RpcServerRegisterOp(OpKernelConstruction* ctx);
  RpcServerRegisterOp(const RpcServerRegisterOp&) = delete;
  void operator=(const RpcServerRegisterOp&) = delete;

  void Compute(OpKernelContext* ctx) override;

 private:
  // The function to register (the `f` attr).
  NameAttrList func_;
  // Parsed from the `output_specs` attr.
  StructuredValue output_specs_;
  // Parsed from the `input_specs` attr.
  StructuredValue input_specs_;
};
// Kernel that looks up (or creates) the RpcServer resource for the
// `server_address` input and outputs a handle to it. Registered functions are
// stored on that resource.
class RpcServerOp : public OpKernel {
 public:
  explicit RpcServerOp(OpKernelConstruction* ctx);
  RpcServerOp(const RpcServerOp&) = delete;
  void operator=(const RpcServerOp&) = delete;

  void Compute(OpKernelContext* ctx) override;
};
// Kernel that starts the gRPC server of the RpcServer resource passed as
// input 0, serving whatever methods were registered on it beforehand.
class RpcServerStartOp : public OpKernel {
 public:
  explicit RpcServerStartOp(OpKernelConstruction* ctx);
  RpcServerStartOp(const RpcServerStartOp&) = delete;
  void operator=(const RpcServerStartOp&) = delete;

  void Compute(OpKernelContext* ctx) override;
};
// Async kernel that (re)creates an RpcClient resource for `server_address`
// and, if requested, lists the methods the server has registered.
class RpcClientOp : public AsyncOpKernel {
 public:
  explicit RpcClientOp(OpKernelConstruction* ctx);
  RpcClientOp(const RpcClientOp&) = delete;
  void operator=(const RpcClientOp&) = delete;

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 private:
  // The `shared_name` attr; prefixed to the address to name the resource.
  std::string name_;
  // The `list_registered_methods` attr.
  bool list_registered_methods_;
};
// Kernel that issues a remote call through the client handle in input 0.
// It returns immediately with a handle to a future resource (plus a deleter
// for it). RpcCheckStatus / RpcGetValue use that handle to read the result.
class RpcCallOp : public OpKernel {
 public:
  explicit RpcCallOp(OpKernelConstruction* ctx);
  RpcCallOp(const RpcCallOp&) = delete;
  void operator=(const RpcCallOp&) = delete;

  void Compute(OpKernelContext* ctx) override;
};
// Async kernel that completes once the call behind a future handle finishes.
// It outputs that call's error code and error message.
class RpcCheckStatusOp : public AsyncOpKernel {
 public:
  explicit RpcCheckStatusOp(OpKernelConstruction* ctx);
  RpcCheckStatusOp(const RpcCheckStatusOp&) = delete;
  void operator=(const RpcCheckStatusOp&) = delete;

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
};
// Async kernel that waits for the call behind a future handle and outputs the
// tensors in its response.
class RpcGetValueOp : public AsyncOpKernel {
 public:
  explicit RpcGetValueOp(OpKernelConstruction* ctx);
  RpcGetValueOp(const RpcGetValueOp&) = delete;
  void operator=(const RpcGetValueOp&) = delete;

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
};
class DeleteRpcFutureResourceOp : public OpKernel {
public:
explicit DeleteRpcFutureResourceOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
protected:
void Compute(OpKernelContext* ctx) override {
const ResourceHandle& handle = ctx->input(0).flat<ResourceHandle>()(0);
// The resource is guaranteed to exist because the variant tensor
// wrapping the deleter is provided as an unused input to this op, which
// guarantees that it has not run yet.
OP_REQUIRES_OK(ctx, ctx->resource_manager()->Delete(handle));
}
};
// Everything needed to execute one registered RPC method.
struct FunctionMetadata {
  // Handle of the instantiated function inside `lib`. It is value-initialized
  // so a default-constructed FunctionMetadata (e.g. one a failed lookup did
  // not fill in) never holds an indeterminate value.
  FunctionLibraryRuntime::Handle handle{};
  // Runtime the function was instantiated in; not owned.
  FunctionLibraryRuntime* lib = nullptr;
  // Tensors appended after the caller's arguments on every invocation.
  std::vector<Tensor> captured_inputs;
  StructuredValue input_specs;
  StructuredValue output_specs;
};
// Mutex-protected map from RPC method name to the function that implements it.
class FunctionRegistry {
 public:
  // Returns "Registered methods: [<name>, <name>, ...]" (map iteration order).
  std::string DebugString() const {
    mutex_lock l(mu_);
    std::string debug_string = "Registered methods: [";
    // An absl::StrJoin formatter must append to `out`; whatever it returns is
    // ignored. Returning the name would leave every entry empty.
    debug_string.append(absl::StrJoin(
        registered_methods_, ", ", [](std::string* out, const auto& pair) {
          absl::StrAppend(out, pair.first);
        }));
    debug_string.append("]");
    return debug_string;
  }

  // Registers `method`. Returns InvalidArgument if that name is already
  // registered.
  tensorflow::Status Register(const std::string& method,
                              FunctionLibraryRuntime* lib,
                              FunctionLibraryRuntime::Handle fn_handle,
                              std::vector<Tensor> captured_inputs,
                              const StructuredValue& input_specs,
                              const StructuredValue& output_specs) {
    FunctionMetadata fn_metadata;
    fn_metadata.handle = fn_handle;
    fn_metadata.lib = lib;
    fn_metadata.captured_inputs = std::move(captured_inputs);
    fn_metadata.input_specs = input_specs;
    fn_metadata.output_specs = output_specs;

    mutex_lock l(mu_);
    auto result = registered_methods_.insert(
        std::pair<std::string, FunctionMetadata>(method,
                                                 std::move(fn_metadata)));
    if (!result.second) {
      return tensorflow::errors::InvalidArgument(
          absl::StrCat(method, " is already registered."));
    }
    return tensorflow::Status::OK();
  }

  // Copies the metadata for `method` into `*output`. Returns InvalidArgument
  // (leaving `*output` untouched) if the method is not registered.
  tensorflow::Status LookUp(const std::string& method,
                            FunctionMetadata* output) const {
    mutex_lock l(mu_);
    auto it = registered_methods_.find(method);
    if (it == registered_methods_.end()) {
      return tensorflow::errors::InvalidArgument(
          absl::StrCat(method, " is not registered."));
    }
    *output = it->second;
    return tensorflow::Status::OK();
  }

  // Returns a snapshot of all registered methods. The copy is taken under the
  // lock, so callers never read the guarded map while Register() writes it.
  gtl::FlatMap<std::string, FunctionMetadata> List() const {
    mutex_lock l(mu_);
    return registered_methods_;
  }

 private:
  mutable mutex mu_;
  gtl::FlatMap<std::string, FunctionMetadata> registered_methods_
      TF_GUARDED_BY(mu_);
};
// gRPC service that runs functions from a FunctionRegistry on behalf of
// remote clients.
class RpcServiceImpl : public grpc::RpcService::Service {
 public:
  // `registry` is held by reference and must outlive this service.
  explicit RpcServiceImpl(const FunctionRegistry& registry)
      : registry_(registry) {}

  // Runs the function registered as `request->method()`. Its arguments are
  // the request tensors followed by the method's captured inputs. On success
  // the function outputs are copied into `response`. Blocks until the
  // function finishes.
  ::grpc::Status Call(::grpc::ServerContext* context,
                      const CallRequest* request,
                      CallResponse* response) override {
    FunctionMetadata fn_metadata;
    const Status lookup_status =
        registry_.LookUp(request->method(), &fn_metadata);
    // Check before reading fn_metadata: LookUp only fills it on success.
    if (!lookup_status.ok()) {
      return ToGrpcStatus(lookup_status);
    }

    std::vector<Tensor> args;
    args.reserve(request->input_tensors_size() +
                 fn_metadata.captured_inputs.size());
    for (const auto& t : request->input_tensors()) {
      Tensor tensor;
      if (!tensor.FromProto(t)) {
        return ::grpc::Status(::grpc::StatusCode::INVALID_ARGUMENT,
                              "Failed to parse input tensor from proto.");
      }
      args.push_back(std::move(tensor));
    }
    // Captured inputs follow the request arguments. fn_metadata is a local
    // copy, so its tensors can be moved rather than copied.
    for (Tensor& t : fn_metadata.captured_inputs) {
      args.push_back(std::move(t));
    }

    FunctionLibraryRuntime::Options opts;
    std::vector<Tensor> rets;
    Status run_status;
    Notification notification;
    fn_metadata.lib->Run(opts, fn_metadata.handle, args, &rets,
                         [&run_status, &notification](const Status& st) {
                           run_status = st;
                           notification.Notify();
                         });
    // `rets` and `run_status` live on this frame. Waiting keeps them alive
    // until the callback has run.
    notification.WaitForNotification();
    if (run_status.ok()) {
      for (const Tensor& t : rets) {
        t.AsProtoField(response->add_output_tensors());
      }
    }
    return ToGrpcStatus(run_status);
  }

  // Reports every registered method together with its input/output specs.
  ::grpc::Status List(::grpc::ServerContext* context,
                      const rpc::ListRequest* request,
                      rpc::ListResponse* response) override {
    const auto methods = registry_.List();
    for (const auto& entry : methods) {
      auto* registered_method = response->add_registered_methods();
      registered_method->set_method(entry.first);
      *registered_method->mutable_output_specs() = entry.second.output_specs;
      *registered_method->mutable_input_specs() = entry.second.input_specs;
    }
    return ::grpc::Status(::grpc::Status::OK);
  }

 private:
  const FunctionRegistry& registry_;
};
class RpcServer : public ResourceBase {
public:
explicit RpcServer(std::string server_address)
: server_address_(server_address),
server_(nullptr),
server_started_(false) {
service_ = std::make_unique<RpcServiceImpl>(registry_);
}
~RpcServer() override {
if (server_) {
LOG(INFO) << "Shutting down server listening on: " << server_address_;
server_->Shutdown();
}
}
std::string DebugString() const override {
return absl::StrCat("RpcServer resource with ", registry_.DebugString());
}
tensorflow::Status Register(const std::string& method,
FunctionLibraryRuntime* lib,
FunctionLibraryRuntime::Handle fn_handle,
std::vector<Tensor> captured_inputs,
const StructuredValue& input_specs,
const StructuredValue& output_specs) {
mutex_lock m(mu_);
if (server_started_) {
return tensorflow::errors::FailedPrecondition(
"All methods must be registered before starting the server. Method "
"registration after starting the server is not supported.");
}
return registry_.Register(method, lib, fn_handle, captured_inputs,
input_specs, output_specs);
}
void StartServer() {
mutex_lock l(mu_);
::grpc::ServerBuilder builder;
std::shared_ptr<::grpc::ServerCredentials> creds =
GetDefaultServerCredentials();
builder.AddListeningPort(server_address_, creds);
builder.RegisterService(service_.get());
server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on: " << server_address_;
server_started_ = true;
}
private:
FunctionRegistry registry_;
std::unique_ptr<RpcServiceImpl> service_;
std::string server_address_;
std::unique_ptr<::grpc::Server> server_;
bool server_started_ TF_GUARDED_BY(mu_);
mutex mu_;
};
// Owns a gRPC completion queue and a thread that polls it. Each event is
// dispatched to the GrpcClientCQTag it was tagged with.
class GrpcPollingThread {
 public:
  explicit GrpcPollingThread(std::string thread_name) {
    // Thread names may only contain alphanumeric characters, so strip
    // everything else. The predicate takes unsigned char because passing a
    // negative char (e.g. a UTF-8 byte) to std::isalnum is undefined behavior.
    thread_name.erase(
        std::remove_if(thread_name.begin(), thread_name.end(),
                       [](unsigned char c) { return !std::isalnum(c); }),
        thread_name.end());
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), absl::StrCat("GrpcPollingThread", thread_name),
        [this]() {
          void* tag;
          bool ok;
          // Next() returns false only after Shutdown() and once the queue has
          // been drained, which ends this loop.
          while (completion_queue_.Next(&tag, &ok)) {
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
          }
        }));
  }

  ~GrpcPollingThread() {
    completion_queue_.Shutdown();
    // Thread's destructor blocks until the polling loop above exits.
    thread_.reset();
  }

  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  // Declared before thread_ so it exists before the polling thread starts.
  ::grpc::CompletionQueue completion_queue_;
  std::unique_ptr<Thread> thread_;
};
// Resource holding a channel to an RpcServer plus the machinery for async
// calls: a completion-queue polling thread and a callback threadpool.
class RpcClient : public ResourceBase {
 public:
  explicit RpcClient(std::string address, std::string resource_name,
                     int64 timeout_in_ms)
      : server_address_(address),
        thread_(resource_name),
        timeout_in_ms_(timeout_in_ms) {
    std::shared_ptr<::grpc::ChannelCredentials> creds =
        GetDefaultChannelCredentials();
    channel_ = ::grpc::CreateChannel(server_address_, creds);
    stub_ = std::make_unique<::grpc::GenericStub>(channel_);
    cq_ = thread_.completion_queue();
    callback_threadpool_ = std::make_unique<thread::ThreadPool>(
        Env::Default(), ThreadOptions(), "RPC_Client_threadpool", 5,
        /*low_latency_hint=*/false, /*allocator=*/nullptr);
  }

  std::string DebugString() const override {
    return absl::StrCat("Rpc client for address: ", server_address_);
  }

  // Issues an async Call RPC for `method_name` with `inputs`. The reply is
  // written into `response` (which must stay alive until `callback` runs).
  // A positive `timeout_in_ms` overrides the client's default timeout.
  void CallAsync(const std::string& method_name,
                 const std::vector<Tensor>& inputs, CallResponse* response,
                 StatusCallback callback, int64 timeout_in_ms) {
    CallRequest request;
    request.set_method(method_name);
    for (const auto& t : inputs) {
      t.AsProtoField(request.add_input_tensors());
    }
    int64 timeout = timeout_in_ms > 0 ? timeout_in_ms : timeout_in_ms_;
    // RPCState manages its own lifetime.
    new RPCState<CallResponse>(
        stub_.get(), cq_, "/tensorflow.rpc.RpcService/Call", request, response,
        /*done=*/std::move(callback),
        /*call_opts=*/nullptr,
        /*threadpool=*/callback_threadpool_.get(),
        /*fail_fast=*/false, /*timeout_in_ms=*/timeout,
        /*max_retries=*/0, /*target=*/nullptr);
  }

  // Issues an async List RPC. The reply is written into `response`, which
  // must stay alive until `callback` runs.
  void ListAsync(rpc::ListResponse* response, StatusCallback callback) {
    rpc::ListRequest request;
    // fail_fast=false sets wait_for_ready to true in GRPC call.
    // ListAsync is called during Client creation thus, we want to wait till
    // server is ready for issuing RPC.
    new RPCState<rpc::ListResponse>(
        stub_.get(), cq_, "/tensorflow.rpc.RpcService/List", request, response,
        /*done=*/std::move(callback),
        /*call_opts=*/nullptr,
        /*threadpool=*/callback_threadpool_.get(),
        /*fail_fast=*/false, /*timeout_in_ms=*/timeout_in_ms_,
        /*max_retries=*/0, /*target=*/nullptr);
  }

 private:
  std::shared_ptr<::grpc::Channel> channel_;
  std::string server_address_;
  std::unique_ptr<::grpc::GenericStub> stub_;
  // Points at thread_'s queue; not owned.
  ::grpc::CompletionQueue* cq_;
  // Declared before thread_ so it is destroyed AFTER it. ~GrpcPollingThread
  // drains the completion queue, and each pending RPCState was handed this
  // pool for its completion work. The pool must therefore still be alive
  // while the queue drains.
  std::unique_ptr<thread::ThreadPool> callback_threadpool_;
  GrpcPollingThread thread_;
  int64 timeout_in_ms_;
};
class RpcFutureResource : public ResourceBase {
typedef std::function<void(const Status&, const CallResponse&)>
FutureCallBack;
public:
RpcFutureResource() : done_(false) {}
std::string DebugString() const override { return "Wait Resource"; }
void AddDoneCallback(FutureCallBack cb) {
mutex_lock l(mu_);
if (!done_) {
call_backs_.push_back(cb);
} else {
cb(status_, response_);
}
}
void OperationFinished() {
mutex_lock l(mu_);
for (const auto& cb : call_backs_) {
cb(status_, response_);
}
done_ = true;
}
void set_status(Status status) { status_.Update(status); }
Status get_status() { return status_; }
CallResponse* get_response() { return &response_; }
private:
CallResponse response_;
bool done_ TF_GUARDED_BY(mu_);
Status status_;
std::vector<FutureCallBack> call_backs_ TF_GUARDED_BY(mu_);
mutable mutex mu_;
};
// Reads the single-element string input "server_address" into `*address`.
// Returns InvalidArgument instead of CHECK-failing (which is what scalar<>()
// does) when the input does not hold exactly one element.
Status ExtractServerAddressFromInput(OpKernelContext* ctx,
                                     std::string* address) {
  const Tensor* server_address;
  TF_RETURN_IF_ERROR(ctx->input("server_address", &server_address));
  if (server_address->NumElements() != 1) {
    return errors::InvalidArgument(
        "server_address must have exactly one element, got shape ",
        server_address->shape().DebugString());
  }
  *address = server_address->scalar<tstring>()();
  return Status::OK();
}
// RpcServerOp reads no attributes.
RpcServerOp::RpcServerOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {}
// Outputs a scalar handle to the RpcServer named by the `server_address`
// input, creating the server resource on first use. The handle is keyed
// (container "rpc_server") by the address.
void RpcServerOp::Compute(OpKernelContext* ctx) {
  std::string server_address = "";
  OP_REQUIRES_OK(ctx, ExtractServerAddressFromInput(ctx, &server_address));

  AllocatorAttributes host_attr;
  host_attr.set_on_host(true);
  const ResourceHandle server_handle =
      MakeResourceHandle<RpcServer>(ctx, "rpc_server", server_address);

  Tensor handle_t;
  OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
                                         &handle_t, host_attr));
  handle_t.scalar<ResourceHandle>()() = server_handle;

  auto make_server = [server_address](RpcServer** created) {
    *created = new RpcServer(server_address);
    return Status::OK();
  };
  core::RefCountPtr<RpcServer> rpc_server;
  OP_REQUIRES_OK(ctx, LookupOrCreateResource<RpcServer>(
                          ctx, server_handle, &rpc_server, make_server));

  ctx->set_output(0, handle_t);
}
// Reads the `shared_name` and `list_registered_methods` attrs.
RpcClientOp::RpcClientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  constexpr char kSharedNameAttr[] = "shared_name";
  constexpr char kListMethodsAttr[] = "list_registered_methods";
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kSharedNameAttr, &name_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kListMethodsAttr, &list_registered_methods_));
}
// Recreates the RpcClient resource for (shared_name + server_address) and
// outputs its handle.
// - Output 1 is a scalar "" when list_registered_methods is false.
// - Otherwise output 1 is a vector of serialized registered-method protos,
//   fetched with an async List RPC.
void RpcClientOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  std::string address = "";
  OP_REQUIRES_OK_ASYNC(ctx, ExtractServerAddressFromInput(ctx, &address), done);

  const Tensor* timeout;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("timeout_in_ms", &timeout), done);
  // scalar<>() CHECK-fails unless the tensor has exactly one element.
  OP_REQUIRES_ASYNC(
      ctx, timeout->NumElements() == 1,
      errors::InvalidArgument(
          "timeout_in_ms must have exactly one element, got shape ",
          timeout->shape().DebugString()),
      done);
  auto timeout_in_ms = timeout->scalar<int64_t>()();

  // Create resource handle
  AllocatorAttributes attr;
  attr.set_on_host(true);
  auto resource_name = absl::StrCat(name_, address);
  ResourceHandle resource_handle =
      MakeResourceHandle<RpcClient>(ctx, "rpc_client", resource_name);
  Tensor handle;
  OP_REQUIRES_OK_ASYNC(
      ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr),
      done);
  handle.scalar<ResourceHandle>()() = resource_handle;

  // Delete old client handle if exists, to clear old client resource state.
  DeleteResource(ctx, resource_handle).IgnoreError();

  // The creator runs synchronously inside LookupOrCreateResource, so
  // capturing the locals by reference is safe.
  auto creator = [&address, &resource_name, timeout_in_ms](RpcClient** client) {
    *client = new RpcClient(address, resource_name, timeout_in_ms);
    return Status::OK();
  };
  core::RefCountPtr<RpcClient> client;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      LookupOrCreateResource<RpcClient>(ctx, resource_handle, &client, creator),
      done);
  ctx->set_output(0, handle);

  if (!list_registered_methods_) {
    Tensor* method_output_t;
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->allocate_output(1, TensorShape({}), &method_output_t), done);
    method_output_t->scalar<tstring>()() = "";
    done();
    return;
  }

  auto* response = new ListResponse();
  client->ListAsync(response, [ctx, response, done](const Status& status) {
    // Take ownership right away so the response is freed on every exit path,
    // including the early return inside OP_REQUIRES_OK_ASYNC below.
    std::unique_ptr<ListResponse> owned_response(response);
    if (!status.ok()) {
      ctx->SetStatus(status);
      done();
      return;
    }
    Tensor* method_output_signatures_t;
    auto method_output_shape = TensorShape(
        {static_cast<int64_t>(owned_response->registered_methods_size())});
    OP_REQUIRES_OK_ASYNC(ctx,
                         ctx->allocate_output(1, method_output_shape,
                                              &method_output_signatures_t),
                         done);
    auto method_output_signatures =
        method_output_signatures_t->vec<tstring>();
    for (int i = 0; i < owned_response->registered_methods_size(); ++i) {
      method_output_signatures(i) =
          owned_response->registered_methods(i).SerializeAsString();
    }
    done();
  });
}
// RpcServerStartOp reads no attributes.
RpcServerStartOp::RpcServerStartOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {}
// Looks up the RpcServer resource in input 0 and starts its gRPC server.
void RpcServerStartOp::Compute(OpKernelContext* ctx) {
  const ResourceHandle& server_handle = HandleFromInput(ctx, 0);
  core::RefCountPtr<RpcServer> rpc_server;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, server_handle, &rpc_server));
  rpc_server->StartServer();
  ctx->SetStatus(Status::OK());
}
// Reads the function attr and parses the serialized StructuredValue
// `output_specs` / `input_specs` attrs. Fails with InvalidArgument (naming the
// offending attr) if either spec cannot be parsed.
RpcServerRegisterOp::RpcServerRegisterOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));

  std::string output_specs_string;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_specs", &output_specs_string));
  OP_REQUIRES(ctx, output_specs_.ParseFromString(output_specs_string),
              tensorflow::errors::InvalidArgument(
                  "Unable to parse StructuredValue output_spec string: ",
                  output_specs_string));

  std::string input_specs_string;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("input_specs", &input_specs_string));
  OP_REQUIRES(ctx, input_specs_.ParseFromString(input_specs_string),
              tensorflow::errors::InvalidArgument(
                  "Unable to parse StructuredValue input_spec string: ",
                  input_specs_string));
}
// Instantiates `func_` for this device and registers it under the
// `method_name` input with the RpcServer in input 0. The `captured_inputs`
// become the function's trailing arguments on every call.
void RpcServerRegisterOp::Compute(OpKernelContext* ctx) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES(ctx, lib != nullptr,
              errors::Internal("No function library is provided"));

  const Tensor* method_name;
  OP_REQUIRES_OK(ctx, ctx->input("method_name", &method_name));
  // scalar<>() CHECK-fails unless the tensor has exactly one element.
  OP_REQUIRES(ctx, method_name->NumElements() == 1,
              errors::InvalidArgument(
                  "method_name must have exactly one element, got shape ",
                  method_name->shape().DebugString()));
  std::string method = method_name->scalar<tstring>()();

  OpInputList captured_inputs;
  OP_REQUIRES_OK(ctx, ctx->input_list("captured_inputs", &captured_inputs));
  std::vector<Tensor> captured(captured_inputs.begin(), captured_inputs.end());

  FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
  instantiate_opts.target = ctx->device()->name();
  instantiate_opts.lib_def = lib->GetFunctionLibraryDefinition();
  // In case captured inputs are on different device.
  instantiate_opts.is_multi_device_function = true;

  const FunctionDef* fdef =
      lib->GetFunctionLibraryDefinition()->Find(func_.name());
  OP_REQUIRES(ctx, fdef != nullptr,
              errors::Internal("Failed to find function."));
  const int num_args = fdef->signature().input_arg_size();
  const int num_captured = static_cast<int>(captured.size());
  // Captured inputs occupy the last `num_captured` argument slots. More
  // captures than arguments would make the index arithmetic below negative.
  OP_REQUIRES(ctx, num_captured <= num_args,
              errors::InvalidArgument("Function ", func_.name(), " takes ",
                                      num_args, " arguments but ",
                                      num_captured,
                                      " captured inputs were provided."));
  const int num_non_captured_inputs = num_args - num_captured;
  for (int i = 0; i < num_non_captured_inputs; ++i) {
    instantiate_opts.input_devices.push_back(ctx->device()->name());
  }
  absl::flat_hash_map<string, std::vector<string>> composite_devices;
  for (int i = 0; i < num_captured; ++i) {
    if (captured[i].dtype() == DT_RESOURCE) {
      instantiate_opts.input_devices.push_back(GetFunctionResourceInputDevice(
          captured[i], num_non_captured_inputs + i, *fdef, &composite_devices));
    } else {
      instantiate_opts.input_devices.push_back(ctx->device()->name());
    }
  }
  // instantiate_opts keeps pointers into composite_devices, which stays
  // alive until after Instantiate() below.
  for (const auto& it : composite_devices) {
    instantiate_opts.composite_devices[it.first] = &it.second;
  }

  FunctionLibraryRuntime::Handle handle;
  OP_REQUIRES_OK(ctx, lib->Instantiate(func_.name(), AttrSlice(&func_.attr()),
                                       instantiate_opts, &handle));

  core::RefCountPtr<RpcServer> server;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &server));
  OP_REQUIRES_OK(ctx, server->Register(method, lib, handle, std::move(captured),
                                       input_specs_, output_specs_));
}
// RpcCallOp reads no attributes (Tin is inferred from the inputs).
RpcCallOp::RpcCallOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {}
// Starts an async call of `method_name` with `args` on the client in input 0.
// Output 0 is a handle to a freshly created RpcFutureResource; output 1 is a
// ResourceDeleter variant for it. A timeout_in_ms <= 0 falls back to the
// client's default timeout.
void RpcCallOp::Compute(OpKernelContext* ctx) {
  const Tensor* method_name;
  OP_REQUIRES_OK(ctx, ctx->input("method_name", &method_name));
  // scalar<>() CHECK-fails unless the tensor has exactly one element.
  OP_REQUIRES(ctx, method_name->NumElements() == 1,
              errors::InvalidArgument(
                  "method_name must have exactly one element, got shape ",
                  method_name->shape().DebugString()));
  std::string method = method_name->scalar<tstring>()();

  const Tensor* timeout;
  OP_REQUIRES_OK(ctx, ctx->input("timeout_in_ms", &timeout));
  OP_REQUIRES(ctx, timeout->NumElements() == 1,
              errors::InvalidArgument(
                  "timeout_in_ms must have exactly one element, got shape ",
                  timeout->shape().DebugString()));
  auto timeout_in_ms = timeout->scalar<int64_t>()();

  OpInputList arguments;
  OP_REQUIRES_OK(ctx, ctx->input_list("args", &arguments));
  std::vector<Tensor> args(arguments.begin(), arguments.end());

  core::RefCountPtr<RpcClient> client;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client));

  // A random name gives every call its own future resource.
  ResourceHandle resource_handle = MakeResourceHandle<RpcFutureResource>(
      ctx, "rpc_future_resource", absl::StrFormat("%d", random::New64()));
  AllocatorAttributes attr;
  attr.set_on_host(true);
  Tensor handle;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
  handle.scalar<ResourceHandle>()() = resource_handle;

  // Create resource
  auto creator = [](RpcFutureResource** resource) {
    *resource = new RpcFutureResource();
    return Status::OK();
  };
  core::RefCountPtr<RpcFutureResource> future_resource;
  OP_REQUIRES_OK(ctx, LookupOrCreateResource<RpcFutureResource>(
                          ctx, resource_handle, &future_resource, creator));

  Tensor deleter_t;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(DT_VARIANT, TensorShape({}), &deleter_t, attr));
  deleter_t.scalar<Variant>()() =
      ResourceDeleter(resource_handle, ctx->resource_manager());
  ctx->set_output(0, handle);
  ctx->set_output(1, deleter_t);

  CallResponse* response = future_resource->get_response();
  // The completion callback takes over this reference and drops it only
  // after publishing the result, which keeps `response` alive for the RPC.
  auto* future_resource_ptr = future_resource.release();
  client->CallAsync(
      method, args, response,
      [future_resource_ptr](const Status& status) {
        future_resource_ptr->set_status(status);
        future_resource_ptr->OperationFinished();
        future_resource_ptr->Unref();
      },
      timeout_in_ms);
}
// RpcCheckStatusOp reads no attributes.
RpcCheckStatusOp::RpcCheckStatusOp(OpKernelConstruction* ctx)
    : AsyncOpKernel(ctx) {
}
// Waits (asynchronously) for the call behind the future handle in input 0.
// Output 0 is the call's error code and output 1 its error message.
void RpcCheckStatusOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  core::RefCountPtr<RpcFutureResource> future_resource;
  auto handle = HandleFromInput(ctx, 0);
  Status lookup_status = LookupResource(ctx, handle, &future_resource);
  if (!lookup_status.ok()) {
    if (errors::IsNotFound(lookup_status)) {
      lookup_status = tensorflow::errors::NotFound(
          "Future resource no longer exists. Please make sure "
          "resource is not already deleted.");
    }
    // Every failure must finish the op here. future_resource is null, so
    // falling through would dereference it.
    ctx->SetStatus(lookup_status);
    done();
    return;
  }

  future_resource->AddDoneCallback(
      [ctx, done](const Status& status, const CallResponse& /*response*/) {
        Tensor error_code(DT_INT64, TensorShape({}));
        Tensor error_message(DT_STRING, TensorShape({}));
        error_code.scalar<int64_t>()() = static_cast<int64_t>(status.code());
        error_message.scalar<tstring>()() = status.error_message();
        ctx->set_output(0, error_code);
        ctx->set_output(1, error_message);
        done();
      });
}
// RpcGetValueOp reads no attributes in its constructor (Tout is checked via
// the op's output types).
RpcGetValueOp::RpcGetValueOp(OpKernelConstruction* ctx)
    : AsyncOpKernel(ctx) {}
// Waits (asynchronously) for the call behind the future handle in input 0,
// then outputs the response tensors. It fails with the RPC's status if the
// call failed, and with InvalidArgument/Internal if the response does not
// match the declared outputs.
void RpcGetValueOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  core::RefCountPtr<RpcFutureResource> future_resource;
  auto handle = HandleFromInput(ctx, 0);
  Status lookup_status = LookupResource(ctx, handle, &future_resource);
  if (!lookup_status.ok()) {
    if (errors::IsNotFound(lookup_status)) {
      lookup_status = tensorflow::errors::NotFound(
          "Future resource no longer exists. Please ensure "
          "resource is not already deleted.");
    }
    // Every failure must finish the op here. future_resource is null, so
    // falling through would dereference it.
    ctx->SetStatus(lookup_status);
    done();
    return;
  }

  future_resource->AddDoneCallback([ctx, done](const Status& status,
                                               const CallResponse& response) {
    if (!status.ok()) {
      ctx->SetStatus(status);
      done();
      return;
    }
    const int num_received = response.output_tensors_size();
    if (ctx->num_outputs() != num_received) {
      ctx->SetStatus(tensorflow::errors::InvalidArgument(absl::StrCat(
          "Incorrect number of output types specified. Expected ",
          ctx->num_outputs(), " but received ", num_received, ".")));
      done();
      return;
    }
    for (int i = 0; i < num_received; ++i) {
      Tensor t;
      if (!t.FromProto(response.output_tensors(i))) {
        ctx->SetStatus(tensorflow::errors::Internal(
            "Invalid Tensor Proto response returned."));
        done();
        return;
      }
      // Reject responses whose dtype disagrees with Tout instead of setting
      // a mistyped output.
      if (t.dtype() != ctx->expected_output_dtype(i)) {
        ctx->SetStatus(tensorflow::errors::InvalidArgument(
            "Output ", i, " has type ", DataTypeString(t.dtype()),
            " but expected ", DataTypeString(ctx->expected_output_dtype(i)),
            "."));
        done();
        return;
      }
      ctx->set_output(i, std::move(t));
    }
    done();
  });
}
// Outputs a handle to the RpcServer resource for `server_address`, creating
// it on first use (kernel: RpcServerOp).
REGISTER_OP("RpcServer")
.Input("server_address: string")
.Output("server: resource")
.SetIsStateful();
// Recreates the client resource named shared_name + server_address.
// `method_specs` is a scalar "" unless list_registered_methods is true. In
// that case it is a vector with one serialized registered-method proto per
// server method (kernel: RpcClientOp).
REGISTER_OP("RpcClient")
.Attr("shared_name: string = ''")
.Input("server_address: string")
.Attr("list_registered_methods: bool = false")
.Input("timeout_in_ms: int64") // 0 indicates no timeout.
// Positive value indicates specified
// timeout.
.Output("client: resource")
.Output("method_specs: string")
.SetIsStateful();
// Starts serving the methods registered on `server`.
REGISTER_OP("RpcServerStart").Input("server: resource").SetIsStateful();
// Registers `f` under `method_name`. `captured_inputs` are appended after the
// caller's arguments on every call. Registration fails with
// FailedPrecondition once the server has started.
REGISTER_OP("RpcServerRegister")
.Input("server: resource")
.Input("method_name: string")
.Input("captured_inputs: Tin")
.Attr("Tin: list(type) >=0 = []")
.Attr("f: func")
.Attr("input_specs: string = ''")
.Attr("output_specs: string")
.SetIsStateful();
// Deletes a future resource produced by RpcCall. `deleter` is not read by the
// kernel; taking it as an input guarantees the deleter has not yet run.
REGISTER_OP("DeleteRpcFutureResource")
.Input("handle: resource")
.Input("deleter: variant")
.SetShapeFn(shape_inference::NoOutputs);
// Starts an async call of `method_name` with `args` through `client`.
// `future` is read by RpcCheckStatus / RpcGetValue, and `deleter` wraps a
// ResourceDeleter for that future. timeout_in_ms <= 0 uses the client's
// default timeout.
REGISTER_OP("RpcCall")
.Input("client: resource")
.Input("method_name: string")
.Input("args: Tin")
.Input("timeout_in_ms: int64")
.Attr("Tin: list(type) >= 0")
.Output("future: resource")
.Output("deleter: variant")
.SetIsStateful();
// Completes when the call behind `status_or` finishes and outputs its status
// code and message.
REGISTER_OP("RpcCheckStatus")
.Input("status_or: resource")
.Output("error_code: int64")
.Output("error: string")
.SetIsStateful();
// Completes when the call behind `status_or` finishes and outputs the
// response tensors. It fails with the call's status if the call failed.
REGISTER_OP("RpcGetValue")
.Input("status_or: resource")
.Attr("Tout: list(type) >= 0")
.Output("output: Tout")
.SetIsStateful();
// Server-side kernels (all CPU).
REGISTER_KERNEL_BUILDER(Name("RpcServer").Device(DEVICE_CPU), RpcServerOp);
REGISTER_KERNEL_BUILDER(Name("RpcServerRegister").Device(DEVICE_CPU),
                        RpcServerRegisterOp);
REGISTER_KERNEL_BUILDER(Name("RpcServerStart").Device(DEVICE_CPU),
                        RpcServerStartOp);

// Client-side kernels (all CPU).
REGISTER_KERNEL_BUILDER(Name("RpcClient").Device(DEVICE_CPU), RpcClientOp);
REGISTER_KERNEL_BUILDER(Name("RpcCall").Device(DEVICE_CPU), RpcCallOp);
REGISTER_KERNEL_BUILDER(Name("RpcCheckStatus").Device(DEVICE_CPU),
                        RpcCheckStatusOp);
REGISTER_KERNEL_BUILDER(Name("RpcGetValue").Device(DEVICE_CPU), RpcGetValueOp);
REGISTER_KERNEL_BUILDER(Name("DeleteRpcFutureResource").Device(DEVICE_CPU),
                        DeleteRpcFutureResourceOp);

// Presumably lets RpcServerRegister's captured (resource) inputs stay on
// their own devices rather than being colocated with the op. Compute
// instantiates the function as multi-device for the same reason.
REGISTER_INPUT_COLOCATION_EXEMPTION("RpcServerRegister");
} // namespace rpc
} // namespace tensorflow