/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api.h"
#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>
// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/container/fixed_array.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/platform.h" // NOLINT
#include "tensorflow/core/util/device_name_utils.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#endif // !IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
using tensorflow::int64;
using tensorflow::string;
namespace {
const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
if (op->inference_ctx) {
return op->inference_ctx->op_def;
}
const tensorflow::OpDef* op_def;
status->status =
tensorflow::OpDefForOp(op->operation.Name().c_str(), &op_def);
return op_def;
}
bool IsCPU(const tensorflow::Device* d) {
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
string DeviceName(const tensorflow::Device* d) {
return (d == nullptr) ? "cpu:0" : d->name();
}
#if !defined(IS_MOBILE_PLATFORM)
tensorflow::Status AddRemoteDevicesToMgr(
const std::vector<string>& added_remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::mutex remote_devices_mu;
int num_added_workers = added_remote_workers.size();
tensorflow::BlockingCounter counter(num_added_workers);
std::vector<tensorflow::Status> statuses(num_added_workers);
for (int i = 0; i < num_added_workers; i++) {
tensorflow::NewRemoteDevices(
tensorflow::Env::Default(), worker_cache, added_remote_workers[i],
[i, &statuses, &counter, &remote_devices, &remote_devices_mu](
const tensorflow::Status& s,
std::vector<tensorflow::Device*>* devices) {
statuses[i] = s;
if (s.ok()) {
tensorflow::mutex_lock l(remote_devices_mu);
for (tensorflow::Device* d : *devices) {
remote_devices.emplace_back(d);
}
}
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_added_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices)));
return tensorflow::Status::OK();
}
tensorflow::Status GetAllRemoteDevices(
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DynamicDeviceMgr>* device_mgr) {
auto remote_device_mgr = absl::make_unique<tensorflow::DynamicDeviceMgr>();
TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache,
remote_device_mgr.get()));
*device_mgr = std::move(remote_device_mgr);
return tensorflow::Status::OK();
}
tensorflow::Status RemoveRemoteDevicesFromMgr(
const std::vector<string>& removed_remote_workers,
tensorflow::DynamicDeviceMgr* remote_device_mgr) {
const std::vector<tensorflow::Device*> remote_devices =
(remote_device_mgr->ListDevices());
std::vector<tensorflow::Device*> devices_to_remove;
for (tensorflow::Device* d : remote_devices) {
for (const string& remote_worker : removed_remote_workers) {
if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker,
d->name())) {
devices_to_remove.emplace_back(d);
break;
}
}
}
TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove));
return tensorflow::Status::OK();
}
tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server,
const string& local_worker,
std::vector<string>* remote_workers) {
tensorflow::GrpcServer* grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(server);
if (grpc_server == nullptr) {
return tensorflow::errors::Internal(
"Currently, TFE_NewContext only supports tensorflow::GrpcServer.");
}
grpc_server->master_env()->worker_cache->ListWorkers(remote_workers);
remote_workers->erase(
std::remove(remote_workers->begin(), remote_workers->end(), local_worker),
remote_workers->end());
return tensorflow::Status::OK();
}
void DifferentiateWorkerLists(const std::vector<string>* current_list,
const std::vector<string>* new_list,
std::vector<string>* added,
std::vector<string>* removed,
std::vector<string>* existing) {
// Compute the equivalent of STL set_difference and set_intersection with a
// single traversal of both lists. As with std::set_difference, the input
// lists (`current_list` and `new_list`) must be sorted before calling this
// function.
added->resize(new_list->size());
removed->resize(current_list->size());
existing->resize(current_list->size());
std::vector<string>::const_iterator curr_it = current_list->begin();
std::vector<string>::const_iterator new_it = new_list->begin();
std::vector<string>::iterator added_it = added->begin();
std::vector<string>::iterator removed_it = removed->begin();
std::vector<string>::iterator existing_it = existing->begin();
while (curr_it != current_list->end() && new_it != new_list->end()) {
if (*curr_it < *new_it) {
*removed_it++ = *curr_it++;
} else if (*curr_it > *new_it) {
*added_it++ = *new_it++;
} else {
*existing_it++ = *curr_it++;
new_it++;
}
}
removed_it = std::copy(curr_it, current_list->end(), removed_it);
added_it = std::copy(new_it, new_list->end(), added_it);
added->resize(added_it - added->begin());
removed->resize(removed_it - removed->begin());
existing->resize(existing_it - existing->begin());
}
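// Worked example (illustrative values only): with the sorted inputs
//   current_list = {"/job:worker/task:1", "/job:worker/task:2"}
//   new_list     = {"/job:worker/task:2", "/job:worker/task:3"}
// a single pass produces
//   added    = {"/job:worker/task:3"}
//   removed  = {"/job:worker/task:1"}
//   existing = {"/job:worker/task:2"}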
tensorflow::Status CreateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
int keep_alive_secs, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
statuses[i] = tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
counter.DecrementCount();
continue;
}
tensorflow::eager::EagerClient* eager_client = nullptr;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
"Cannot find a client for the given target: ", remote_worker);
}
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
tensorflow::eager::CreateContextRequest request(base_request);
tensorflow::eager::CreateContextResponse* response =
new tensorflow::eager::CreateContextResponse();
request.set_context_id(context_id);
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
request.set_keep_alive_secs(keep_alive_secs);
eager_client->CreateContextAsync(
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return tensorflow::Status::OK();
}
tensorflow::Status UpdateRemoteContexts(
const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
const tensorflow::eager::CreateContextRequest& base_request) {
int num_remote_workers = remote_workers.size();
tensorflow::BlockingCounter counter(num_remote_workers);
std::vector<tensorflow::Status> statuses(num_remote_workers);
for (int i = 0; i < num_remote_workers; i++) {
const string& remote_worker = remote_workers[i];
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
statuses[i] = tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
counter.DecrementCount();
continue;
}
tensorflow::eager::EagerClient* eager_client = nullptr;
statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client);
if (eager_client == nullptr) {
statuses[i] = tensorflow::errors::Internal(
"Cannot find a client for the given target: ", remote_worker);
}
if (!statuses[i].ok()) {
counter.DecrementCount();
continue;
}
tensorflow::eager::UpdateContextRequest request;
auto* response = new tensorflow::eager::UpdateContextResponse();
*request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
for (const auto& da : base_request.cluster_device_attributes()) {
*request.add_cluster_device_attributes() = da;
}
request.set_context_id(context_id);
request.set_context_view_id(context_view_id);
eager_client->UpdateContextAsync(
&request, response,
[i, &statuses, &counter, response](const tensorflow::Status& s) {
statuses[i] = s;
delete response;
counter.DecrementCount();
});
}
counter.Wait();
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
}
return tensorflow::Status::OK();
}
tensorflow::Status UpdateTFE_ContextWithServerDef(
int keep_alive_secs, const tensorflow::ServerDef& server_def,
TFE_Context* ctx, bool reset_context) {
// We don't use the TF_RETURN_IF_ERROR macro directly here, since that
// destroys the server object (which currently CHECK-fails) and we would miss
// the error. Instead, we log the error and then return, so that the user can
// see the error message.
#define LOG_AND_RETURN_IF_ERROR(...) \
do { \
const ::tensorflow::Status _status = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_status.ok())) { \
LOG(ERROR) << _status.error_message(); \
return _status; \
} \
} while (0)
string worker_name =
tensorflow::strings::StrCat("/job:", server_def.job_name(),
"/replica:0/task:", server_def.task_index());
// List of current remote workers before updating server_def. Unused if
// resetting the server_def.
std::vector<string> curr_remote_workers;
// List of updated remote workers.
std::vector<string> remote_workers;
// New server created for new server_def. Unused if updating server_def.
std::unique_ptr<tensorflow::ServerInterface> new_server;
tensorflow::GrpcServer* grpc_server;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
} else {
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(
ctx->context->GetServer(), worker_name, &curr_remote_workers));
// No need to check the cast here, since `ListRemoteWorkers` already checks
// if the server is a GRPC server or not.
grpc_server =
dynamic_cast<tensorflow::GrpcServer*>(ctx->context->GetServer());
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
}
tensorflow::uint64 context_id = ctx->context->GetContextId();
if (reset_context) {
context_id = tensorflow::EagerContext::NewContextId();
// Make master eager context accessible by local eager service, which might
// receive send tensor requests from remote workers.
LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService(
context_id, ctx->context));
}
// When updating an existing context, populate the following lists with:
// * added_workers: set(remote_workers) - set(curr_remote_workers)
// * removed_workers: set(curr_remote_workers) - set(remote_workers)
// * existing_workers: set(curr_remote_workers) intersect set(remote_workers)
std::vector<string> added_workers;
std::vector<string> removed_workers;
std::vector<string> existing_workers;
std::unique_ptr<tensorflow::DynamicDeviceMgr> remote_device_mgr;
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
remote_workers, grpc_server->master_env()->worker_cache,
&remote_device_mgr));
} else {
ctx->context->ClearCaches();
grpc_server->worker_env()->rendezvous_mgr->Cleanup(context_id);
remote_device_mgr = ctx->context->ReleaseRemoteDeviceMgr();
if (remote_device_mgr == nullptr) {
LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument(
"Updating context with an invalid set of remote devices."));
}
std::sort(curr_remote_workers.begin(), curr_remote_workers.end());
std::sort(remote_workers.begin(), remote_workers.end());
DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
&added_workers, &removed_workers,
&existing_workers);
if (!added_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : added_workers) {
VLOG(1) << "Updating cluster with added worker " << w;
}
}
LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
added_workers, grpc_server->master_env()->worker_cache,
remote_device_mgr.get()));
}
if (!removed_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : removed_workers) {
VLOG(1) << "Updating cluster with removed worker " << w;
}
}
LOG_AND_RETURN_IF_ERROR(
RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr.get()));
}
}
std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes);
std::vector<tensorflow::DeviceAttributes> local_device_attributes;
grpc_server->worker_env()->device_mgr->ListDeviceAttributes(
&local_device_attributes);
// This request makes sure that we can create rendezvous properly between
// local and remote contexts.
tensorflow::eager::CreateContextRequest base_request;
for (const auto& da : cluster_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
base_request.mutable_server_def()
->mutable_default_session_config()
->MergeFrom(server_def.default_session_config());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(
grpc_server->master_env()->worker_cache->GetEagerClientCache(
&remote_eager_workers));
// Initialize remote eager workers.
// TODO(b/138847548) Create remote eager contexts in async mode by default.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(
CreateRemoteContexts(remote_workers, context_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(), base_request));
} else {
LOG_AND_RETURN_IF_ERROR(
CreateRemoteContexts(added_workers, context_id, keep_alive_secs,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(), base_request));
if (!existing_workers.empty()) {
if (VLOG_IS_ON(1)) {
for (const string& w : existing_workers) {
VLOG(1) << "Updating cluster with existing worker " << w;
}
}
LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
existing_workers, context_id, ctx->context->GetContextViewId() + 1,
server_def, remote_eager_workers.get(),
ctx->context->Executor().Async(), base_request));
}
}
tensorflow::RemoteRendezvous* r =
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
// Initialize remote tensor communication based on worker session.
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
auto* device_mgr = grpc_server->worker_env()->device_mgr;
auto remote_mgr = absl::make_unique<tensorflow::eager::RemoteMgr>(
/*is_master=*/true, ctx->context);
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster(
std::move(new_server), grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(remote_device_mgr),
remote_workers, context_id, r, device_mgr, keep_alive_secs,
worker_session->cluster_flr(), std::move(remote_mgr)));
} else {
LOG_AND_RETURN_IF_ERROR(ctx->context->UpdateRemoteMaster(
grpc_server->worker_env(), worker_session,
std::move(remote_eager_workers), std::move(remote_device_mgr),
added_workers, removed_workers, context_id, r, device_mgr,
keep_alive_secs, worker_session->cluster_flr()));
}
// NOTE: We start the server after all other initialization, because the
// GrpcServer cannot be destroyed after it is started.
LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
#undef LOG_AND_RETURN_IF_ERROR
return tensorflow::Status::OK();
}
#endif // !IS_MOBILE_PLATFORM
tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
TFE_TensorHandle* input) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
// Some clients that still set their input attributes manually add an input
// list to their op by calling `TFE_OpAddInput` for each of its elements
// instead of calling `TFE_OpAddInputList`. When this happens, we cannot
// detect the end of such a list and thus lose track of the input arguments
// in the op definition. To guarantee backward compatibility with those
// clients, disable automatic inference in this case.
op->inference_ctx.reset(nullptr);
return tensorflow::Status::OK();
}
const std::string& type_attr = input_def.type_attr();
if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
ictx->attrs.insert(type_attr);
}
return tensorflow::Status::OK();
}
void OpInferSingleTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
ictx->attrs.insert(input_def.number_attr());
}
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.type_attr(),
inputs[0]->handle->dtype);
ictx->attrs.insert(input_def.type_attr());
}
}
void OpInferMixedTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs, int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
std::unique_ptr<tensorflow::DataType[]> dtypes(
new tensorflow::DataType[num_inputs]);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
op->operation.MutableAttrs()->Set(
input_def.type_list_attr(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
num_inputs));
ictx->attrs.insert(input_def.type_list_attr());
}
}
tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.type_list_attr().empty()) {
OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
return tensorflow::Status::OK();
}
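// Illustration of the inference helpers above (not executed here): for an op
// such as AddN, whose input is declared as "inputs: N * T", calling
// TFE_OpAddInputList with three float handles lets OpInferInputListAttrs
// infer N = 3 and T = DT_FLOAT, so callers no longer need to set those attrs
// manually.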
} // namespace
extern "C" {
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
size_t proto_len, TF_Status* status) {
TF_SetConfig(&options->session_options, proto, proto_len, status);
}
void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
unsigned char enable) {
options->async = enable;
}
void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
options->device_placement_policy = policy;
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::StaticDeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, device_mgr.release(),
/*device_mgr_owned*/ true, r,
tensorflow::GetDefaultCustomKernelCreator());
}
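// A minimal lifecycle sketch for this API (comments only; error handling is
// elided and the names are illustrative):
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_ContextOptionsSetAsync(opts, 0);  // synchronous execution
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);  // safe once the context exists
//   if (TF_GetCode(status) == TF_OK) {
//     // ... build and run ops against `ctx` ...
//     TFE_DeleteContext(ctx);
//   }
//   TF_DeleteStatus(status);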
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
TF_Session* sess, TF_Status* status) {
const tensorflow::DeviceMgr* device_mgr = nullptr;
status->status = sess->session->LocalDeviceManager(&device_mgr);
if (!status->status.ok()) return nullptr;
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context(opts->session_options.options,
opts->device_placement_policy, opts->mirroring_policy,
opts->async, device_mgr, /*device_mgr_owned*/ false, r,
tensorflow::GetDefaultCustomKernelCreator());
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
TF_DeviceList* list = new TF_DeviceList;
ctx->context->local_device_mgr()->ListDeviceAttributes(&list->response);
if (ctx->context->remote_device_mgr()) {
ctx->context->remote_device_mgr()->ListDeviceAttributes(&list->response);
}
return list;
}
void TFE_ContextClearCaches(TFE_Context* ctx) { ctx->context->ClearCaches(); }
// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::errors::Unimplemented(
"TFE_ContextSetServerDef not supported on mobile");
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
return;
}
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/true);
#endif // !IS_MOBILE_PLATFORM
}
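// Sketch of the expected calling pattern (the ServerDef contents are assumed
// to be filled in by the caller): serialize a tensorflow::ServerDef and pass
// the raw bytes.
//
//   tensorflow::ServerDef server_def = ...;  // cluster, job_name, task_index
//   std::string serialized = server_def.SerializeAsString();
//   TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/600, serialized.data(),
//                           serialized.size(), status);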
TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
int keep_alive_secs,
const void* proto,
size_t proto_len,
TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM)
status->status = tensorflow::errors::Unimplemented(
"TFE_ContextUpdateServerDef not supported on mobile");
#else // !defined(IS_MOBILE_PLATFORM)
tensorflow::ServerDef server_def;
if (!server_def.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Invalid tensorflow.ServerDef protocol buffer");
return;
} else if (ctx->context->GetContextId() ==
tensorflow::EagerContext::kInvalidContextId) {
status->status = tensorflow::errors::InvalidArgument(
"Trying to update a context with an invalid context id.");
return;
}
// TODO(haoyuzhang): Check server_def compatibility before the update
status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def,
ctx, /*reset_context=*/false);
#endif // !IS_MOBILE_PLATFORM
}
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
ctx->context->SetThreadLocalDevicePlacementPolicy(
static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
}
// Note: this function looks up a thread-local policy, so it should be called
// from the appropriate client thread. In particular, in async mode, it may
// not be safe to call this function from the async EagerExecutor threads.
extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
TFE_Context* ctx) {
return static_cast<TFE_ContextDevicePlacementPolicy>(
ctx->context->GetDevicePlacementPolicy());
}
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
if (h == nullptr) return;
VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
<< h->handle;
if (h->handle) {
h->handle->Unref();
}
delete h;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
return static_cast<TF_DataType>(h->handle->dtype);
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
int result;
status->status = h->handle->NumDims(&result);
return result;
}
int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
tensorflow::int64 result;
status->status = h->handle->NumElements(&result);
return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return -1;
}
tensorflow::int64 result;
status->status = h->handle->Dim(dim_index, &result);
return result;
}
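// Shape introspection sketch (assumes `h` is a valid handle and `status` a
// fresh TF_Status):
//
//   int num_dims = TFE_TensorHandleNumDims(h, status);
//   for (int i = 0; TF_GetCode(status) == TF_OK && i < num_dims; ++i) {
//     int64_t dim = TFE_TensorHandleDim(h, i, status);
//     VLOG(1) << "dim[" << i << "] = " << dim;
//   }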
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
h->handle->Ref();
return new TFE_TensorHandle(h->handle);
}
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::TensorHandle* handle = h->handle;
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle->IsRemote()) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), &handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
return nullptr;
}
status->status = h_cpu->Tensor(&t);
if (!status->status.ok()) {
h_cpu->Unref();
return nullptr;
}
TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
h_cpu->Unref();
return retval;
} else {
tensorflow::Tensor tensor;
if (IsCPU(handle->device())) {
const tensorflow::Tensor* src = nullptr;
status->status = handle->Tensor(&src);
if (!status->status.ok()) return nullptr;
tensor = *src;
} else {
tensorflow::EagerContext* ctx = handle->Context();
CHECK_NE(ctx, nullptr);
status->status = h->handle->CopyToDevice(ctx, ctx->HostCPU(), &tensor);
if (!status->status.ok()) return nullptr;
}
return tensorflow::TF_TensorFromTensor(tensor, status);
}
}
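// Reading values back through Resolve (sketch; assumes `h` holds floats):
//
//   TF_Tensor* t = TFE_TensorHandleResolve(h, status);
//   if (TF_GetCode(status) == TF_OK) {
//     const float* data = static_cast<const float*>(TF_TensorData(t));
//     // ... consume TF_TensorElementCount(t) elements via `data` ...
//     TF_DeleteTensor(t);
//   }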
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
TF_Status* status) {
return NewOrResetOp(ctx, op_or_function_name, status,
/* op_to_reset= */ nullptr);
}
void TFE_DeleteOp(TFE_Op* op) { delete op; }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
status->status = op->operation.SetDeviceName(device_name);
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device = (op->operation.Device() == nullptr)
? op->operation.EagerContext()->HostCPU()
: op->operation.Device();
return device->name().c_str();
}
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
op->operation.SetUseXla(enable);
#ifndef TENSORFLOW_EAGER_USE_XLA
LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
"built with XLA support.";
#endif // TENSORFLOW_EAGER_USE_XLA
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
op->operation.AddInput(input->handle);
if (op->inference_ctx) {
status->status = OpInferSingleInputAttrs(op, input);
}
}
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
for (int i = 0; i < num_inputs; ++i) {
op->operation.AddInput(inputs[i]->handle);
}
if (op->inference_ctx) {
status->status = OpInferInputListAttrs(op, inputs, num_inputs);
}
}
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
unsigned char* is_list, TF_Status* status) {
TF_AttrType ret = TF_ATTR_INT;
status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
attr_name, &ret, is_list);
return ret;
}
TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
const char* op_or_function_name,
const char* attr_name, unsigned char* is_list,
TF_Status* status) {
TF_AttrType ret;
TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
if (status->status.ok()) {
ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
} else {
ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
}
TFE_DeleteOp(op);
return ret;
}
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
size_t length) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::StringPiece(static_cast<const char*>(value), length));
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
}
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
op->operation.MutableAttrs()->Set(attr_name, value);
}
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
op->operation.MutableAttrs()->Set(attr_name, value != 0);
}
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
op->operation.MutableAttrs()->Set(attr_name,
static_cast<tensorflow::DataType>(value));
}
void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
const int num_dims, TF_Status* out_status) {
if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
tensorflow::TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
op->operation.MutableAttrs()->Set(attr_name, proto);
}
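// Shape-attr sketch: a dim size of -1 marks an unknown dimension, while
// num_dims == -1 marks unknown rank (the attr name "shape" is illustrative):
//
//   int64_t dims[] = {-1, 784};
//   TFE_OpSetAttrShape(op, "shape", dims, 2, status);      // rank 2: [?, 784]
//   TFE_OpSetAttrShape(op, "shape", nullptr, -1, status);  // unknown rank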
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
const TFE_Op* value) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(value->operation.Name());
value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
const char* data, size_t length) {
tensorflow::AttrValue attr_value;
tensorflow::NameAttrList* func = attr_value.mutable_func();
func->set_name(data, length);
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
tensorflow::Tensor t;
status->status = TF_TensorToTensor(tensor, &t);
if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
std::vector<tensorflow::StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
lengths[i]);
}
op->operation.MutableAttrs()->Set(attr_name, v);
}
void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
const float* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
}
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
}
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
const TF_DataType* values, int num_values) {
op->operation.MutableAttrs()->Set(
attr_name,
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}
void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
const unsigned char* values, int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, TF_Status* out_status) {
std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
new tensorflow::TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Value specified for `", attr_name, "` has ", num_dims_i,
" dimensions which is over the limit of ",
tensorflow::TensorShape::MaxDimensions(), ".")
.c_str());
return;
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
proto.get(), num_values));
}
void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
const TFE_Op** value, int num_values) {
std::unique_ptr<tensorflow::NameAttrList[]> funcs(
new tensorflow::NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
funcs[i].set_name(value[i]->operation.Name());
value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
}
op->operation.MutableAttrs()->Set(
attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
funcs.get(), num_values));
}
TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op,
const char* input_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, &name_ranges, nullptr);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(input_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument("Input '", input_name,
"' not found");
return -1;
}
return iter->second.second - iter->second.first;
}
TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
const char* output_name,
TF_Status* status) {
const tensorflow::OpDef* op_def = GetOpDef(op, status);
if (!status->status.ok()) {
return -1;
}
tensorflow::AttrValueMap attrs;
op->operation.Attrs().FillAttrValueMap(&attrs);
tensorflow::NameRangeMap name_ranges;
status->status = tensorflow::NameRangesForNode(
tensorflow::AttrSlice(&attrs), *op_def, nullptr, &name_ranges);
if (!status->status.ok()) {
return -1;
}
auto iter = name_ranges.find(output_name);
if (iter == name_ranges.end()) {
status->status = tensorflow::errors::InvalidArgument(
"Output '", output_name, "' not found");
return -1;
}
return iter->second.second - iter->second.first;
}
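// Sketch of querying declared lengths once the relevant attrs are set. For
// example, Split declares "output: num_split * T", so (with `split_op`
// standing in for a configured TFE_Op):
//
//   TFE_OpSetAttrInt(split_op, "num_split", 3);
//   int n = TFE_OpGetOutputLength(split_op, "output", status);  // n == 3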
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
VLOG(1) << "Calling TFE_Execute() on op " << op;
absl::FixedArray<tensorflow::TensorHandle*> handle_retvals(*num_retvals);
status->status = tensorflow::EagerExecute(&op->operation,
handle_retvals.data(), num_retvals);
if (!status->status.ok()) {
return;
}
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
}
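// End-to-end execution sketch (hedged; mirrors the usual C API test pattern,
// with `ctx`, `m`, and `status` assumed to exist): run MatMul on one handle.
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, m, status);  // `m`: a square float matrix handle
//   TFE_OpAddInput(matmul, m, status);
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
//   TFE_DeleteOp(matmul);
//   // On success retvals[0] holds m x m; release it with
//   // TFE_DeleteTensorHandle when done.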
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
tensorflow::TensorHandle* handle = nullptr;
tensorflow::Device* device;
status->status = ctx->context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
&ctx->context->Executor(),
device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
}
return nullptr;
}
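// Cross-device copy sketch (assumes a GPU device is present; the device
// string is the usual fully qualified name):
//
//   TFE_TensorHandle* h_gpu = TFE_TensorHandleCopyToDevice(
//       h, ctx, "/job:localhost/replica:0/task:0/device:GPU:0", status);
//   if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(h_gpu);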
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
const char* serialized_function_def, size_t size,
TF_Status* status) {
tensorflow::FunctionDef function_def;
if (!function_def.ParseFromArray(serialized_function_def, size)) {
status->status =
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
status->status = ctx->context->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
status->status = ctx->context->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
status->status = ctx->context->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
return ctx->context->FindFunctionDef(name) != nullptr;
}
void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
}
void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
}
} // extern "C"
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return TFE_TensorHandle::CreateLocalHandle(t, status);
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
TF_Status* status) {
status->status = ctx->context->Executor().WaitForAllPendingNodes();
if (!status->status.ok()) return;
tensorflow::mutex_lock ml(*ctx->context->MetadataMu());
status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf);
ctx->context->ClearRunMetadata();
}
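// Run-metadata round trip (sketch): enable collection, execute some ops, then
// export the accumulated proto into a caller-owned buffer.
//
//   TFE_ContextEnableRunMetadata(ctx);
//   // ... execute ops ...
//   TF_Buffer* buf = TF_NewBuffer();
//   TFE_ContextExportRunMetadata(ctx, buf, status);
//   // buf->data / buf->length now hold a serialized tensorflow::RunMetadata.
//   TF_DeleteBuffer(buf);
//   TFE_ContextDisableRunMetadata(ctx);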
namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
for (const auto& attr : func.attr()) {
if (!status->status.ok()) return nullptr;
SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
if (!status->status.ok()) return nullptr;
}
return func_op;
}
} // namespace
void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context->StartStep(); }
void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context->EndStep(); }
namespace tensorflow {
void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
const char* attr_name, TF_Status* status) {
switch (default_value.value_case()) {
case tensorflow::AttrValue::kS: {
const string& v = default_value.s();
TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
break;
}
case tensorflow::AttrValue::kI:
TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
break;
case tensorflow::AttrValue::kF:
TFE_OpSetAttrFloat(op, attr_name, default_value.f());
break;
case tensorflow::AttrValue::kB:
TFE_OpSetAttrBool(op, attr_name, default_value.b());
break;
case tensorflow::AttrValue::kType:
TFE_OpSetAttrType(op, attr_name,
static_cast<TF_DataType>(default_value.type()));
break;
case tensorflow::AttrValue::kShape: {
const auto& tensor_shape = default_value.shape();
if (tensor_shape.unknown_rank()) {
TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
} else {
const auto num_dims = tensor_shape.dim_size();
std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
for (int i = 0; i < num_dims; ++i) {
dims[i] = tensor_shape.dim(i).size();
}
TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
}
} break;
case tensorflow::AttrValue::kFunc: {
const auto func_op = GetFunc(ctx, default_value.func(), status);
if (!status->status.ok()) return;
// TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
// require a TFE_Op* and just convert it internally to a NameAttrList, so
// consider adding an overload to the C API to make this case easier.
TFE_OpSetAttrFunction(op, attr_name, func_op);
} break;
case tensorflow::AttrValue::kList:
TF_FALLTHROUGH_INTENDED;
case tensorflow::AttrValue::kTensor:
TF_FALLTHROUGH_INTENDED;
case tensorflow::AttrValue::kPlaceholder:
TF_FALLTHROUGH_INTENDED;
case tensorflow::AttrValue::VALUE_NOT_SET:
TF_SetStatus(
status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat("Unable to set attr from default value: ",
default_value.DebugString())
.data());
}
}
} // namespace tensorflow