/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#include "tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
namespace {
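
// Best-effort helper: asks the remote task that owns the tensor identified by
// (op_id, output_num) to drop its reference on it, by enqueueing a
// handle_to_decref item for the given context_id. If the context id has
// changed (the cluster was updated underneath us) or the target task is
// unreachable, the request is silently dropped.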
void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
uint64 context_id, uint64 op_id, int output_num,
bool ready) {
if (ctx->GetContextId() != context_id) {
// This means that this tensor was pointing to a remote device, which
// has been changed out from under us. Simply return since there is
// nothing we can do.
return;
}
core::RefCountPtr<eager::EagerClient> eager_client;
Status status = ctx->GetClient(remote_task, &eager_client);
if (!status.ok()) {
LOG_EVERY_N_SEC(INFO, 60)
<< "Unable to destroy remote tensor handle because the target "
<< remote_task << " is no longer available.";
return;
}
std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(context_id);
auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
handle_to_decref->set_op_id(op_id);
handle_to_decref->set_output_num(output_num);
VLOG(3) << "Sending request to delete " << request->DebugString();
std::unique_ptr<EagerNode> node(
absl::make_unique<eager::DestroyTensorHandleNode>(
std::move(request), std::move(eager_client), ready));
auto& executor = ctx->Executor();
if (executor.Async()) {
Status status = executor.AddOrExecute(std::move(node));
if (!status.ok()) {
LOG_EVERY_N_SEC(WARNING, 60)
<< "Unable to destroy remote tensor handles. If you are "
"running a tf.function, it usually indicates some op in "
"the graph gets an error: "
<< status.error_message();
}
} else {
// This thread may still hold tensorflow::StreamingRPCState::mu_. We need
// to send out the destroy request in a new thread to avoid deadlock.
auto* released_node = node.release();
(*ctx->runner())([ctx, released_node] {
Status status =
ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
if (!status.ok()) {
LOG_EVERY_N_SEC(WARNING, 60)
<< "Unable to destroy remote tensor handles. If you are "
"running a tf.function, it usually indicates some op in "
"the graph gets an error: "
<< status.error_message();
}
});
}
}
} // namespace
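
// Handle data that does not manage the remote tensor's lifetime: ctx_ stays
// null, so no destroy request is sent when this object is deleted. `is_ready`
// indicates whether the shape is already known.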
RemoteTensorHandleData::RemoteTensorHandleData(int64_t op_id, int output_num,
uint64 context_view_id,
bool is_ready)
: is_ready_(is_ready),
op_id_(op_id),
output_num_(output_num),
context_view_id_(context_view_id),
ctx_(nullptr) {
DCHECK(op_id_ >= 0 && output_num_ >= 0)
<< "Op ID and output num should be >= 0. Op ID: " << op_id
<< ", Output num: " << output_num;
}
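
// Handle data for a remote tensor whose shape is not yet known; it becomes
// ready once SetShape()/SetShapeAndRemoteTask() is called or the handle is
// poisoned. Takes a reference on `ctx` so a destroy request can be sent to
// `remote_task` when this object is deleted.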
RemoteTensorHandleData::RemoteTensorHandleData(int64_t op_id, int output_num,
const string& remote_task,
EagerContext* ctx)
: is_ready_(false),
op_id_(op_id),
output_num_(output_num),
remote_task_(remote_task),
context_id_(ctx->GetContextId()),
context_view_id_(ctx->GetContextViewId()),
ctx_(ctx) {
DCHECK(op_id_ >= 0 && output_num_ >= 0)
<< "Op ID and output num should be >= 0. Op ID: " << op_id
<< ", Output num: " << output_num;
ctx_->Ref();
}
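
// Asks the remote task to release its reference on the tensor and drops the
// reference on the eager context taken in the constructor.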
RemoteTensorHandleData::~RemoteTensorHandleData() {
if (ctx_) {
DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
output_num_, /*ready=*/true);
ctx_->Unref();
}
}
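
// The shape accessors below block until the handle is ready (the shape has
// been set or the handle has been poisoned) before reading `shape_`.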
Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
TF_RETURN_IF_ERROR(WaitReady("Shape"));
tf_shared_lock l(mu_);
*shape = shape_;
return OkStatus();
}
Status RemoteTensorHandleData::NumDims(int* num_dims) const {
TF_RETURN_IF_ERROR(WaitReady("NumDims"));
tf_shared_lock l(mu_);
*num_dims = shape_.dims();
return OkStatus();
}
Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const {
TF_RETURN_IF_ERROR(WaitReady("Dim"));
tf_shared_lock l(mu_);
*dim = shape_.dim_size(dim_index);
return OkStatus();
}
Status RemoteTensorHandleData::NumElements(int64_t* num_elements) const {
TF_RETURN_IF_ERROR(WaitReady("NumElements"));
tf_shared_lock l(mu_);
*num_elements = shape_.num_elements();
return OkStatus();
}
bool RemoteTensorHandleData::IsReady() const {
tf_shared_lock l(mu_);
return is_ready_;
}
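
// Records `status` as the poison error and marks the handle ready, so pending
// and future WaitReady() calls return the error instead of blocking.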
void RemoteTensorHandleData::Poison(Status status) {
mutex_lock l(mu_);
is_poisoned_ = status;
is_ready_ = true;
}
Status RemoteTensorHandleData::IsPoisoned() const {
tf_shared_lock l(mu_);
return is_poisoned_;
}
Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
return SetShapeAndRemoteTask(shape, /*remote_task=*/"");
}
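
// Stores the shape (and optionally the owning task) and marks the handle
// ready, waking any waiters. Propagates the poison error if the handle was
// poisoned first, and fails if the shape was already set.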
Status RemoteTensorHandleData::SetShapeAndRemoteTask(
const TensorShape& shape, const string& remote_task) {
// If `is_ready_` is set previously due to poisoning, return the original
// error that poisoned this tensor.
TF_RETURN_IF_ERROR(IsPoisoned());
mutex_lock l(mu_);
if (is_ready_) {
return errors::Internal("SetShape is only called on non-ready handles.");
}
shape_ = shape;
if (!remote_task.empty()) {
remote_task_ = remote_task;
}
is_poisoned_ = OkStatus();
is_ready_ = true;
return OkStatus();
}
string RemoteTensorHandleData::DebugString() const {
return strings::StrCat("RemoteTensorHandleData:", " op_id: ", op_id_,
" output_num: ", output_num_);
}
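
// Returns the op id and output number that identify this tensor on the remote
// task, optionally blocking until the handle is ready.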
Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_until_ready,
int64_t* op_id,
int32* output_num) const {
if (wait_until_ready) {
TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
}
*op_id = op_id_;
*output_num = output_num_;
return OkStatus();
}
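
// Blocks until the shape has been set or the handle has been poisoned, then
// returns the poison status (OkStatus() for a healthy handle). `caller` is
// used only for logging and profiler trace annotations.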
Status RemoteTensorHandleData::WaitReady(const char* caller) const {
tf_shared_lock l(mu_);
if (!is_ready_) {
profiler::TraceMe activity(
[caller] { return absl::StrCat(caller, " WaitReady"); },
profiler::TraceMeLevel::kInfo);
DVLOG(3) << "WaitReady: " << caller << " " << this;
mu_.Await(Condition(&is_ready_));
}
return is_poisoned_;
}
} // namespace tensorflow