Temporarily revert the WaitReady change for handles with unknown devices.
PiperOrigin-RevId: 326949430
Change-Id: I7170402785dcd86e1e16dc6f8391e6586b84a2ae
diff --git a/tensorflow/c/eager/c_api_remote_function_test.cc b/tensorflow/c/eager/c_api_remote_function_test.cc
index 52488e6..a9bbd5b 100644
--- a/tensorflow/c/eager/c_api_remote_function_test.cc
+++ b/tensorflow/c/eager/c_api_remote_function_test.cc
@@ -30,13 +30,12 @@
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
-// TODO(b/164506563): Re-enable after the fix.
-TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesFuncRemoteOutputs) {
+TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
-TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
+TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index adf1b55..620685e 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -539,14 +539,13 @@
}
Status TensorHandle::WaitUnknownDevice() const {
- // TODO(b/164506563): uncomment this when b/164506563 is fixed.
- // if (unknown_device_) {
- // TF_RETURN_IF_ERROR(absl::visit(
- // [](auto& data) {
- // return data.WaitReady("TensorHandle::UnknownDevice");
- // },
- // data_));
- // }
+ if (unknown_device_) {
+ TF_RETURN_IF_ERROR(absl::visit(
+ [](auto& data) {
+ return data.WaitReady("TensorHandle::UnknownDevice");
+ },
+ data_));
+ }
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
index 91c0503..e2bc73b 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc
@@ -49,13 +49,6 @@
}
VLOG(3) << "Issuing: " << rpc_description;
- for (auto handle : inputs_) {
- handle->Ref();
- }
- for (auto handle : retvals) {
- handle->Ref();
- }
-
CancellationManager* cm = cancellation_manager_;
CancellationToken token = 0;
auto call_opts = std::make_shared<CallOptions>();
@@ -64,11 +57,22 @@
const bool already_cancelled = !cm->RegisterCallback(
token, [call_opts, response, done]() { call_opts->StartCancel(); });
if (already_cancelled) {
- done(errors::Cancelled("RemoteExecuteNode::RunAsync"));
+ Status s = errors::Cancelled("RemoteExecuteNode::RunAsync");
+ for (size_t i = 0; i < retvals.size(); ++i) {
+ retvals[i]->PoisonRemote(s, device, context_view_id_);
+ }
+ done(s);
return;
}
}
+ for (auto handle : inputs_) {
+ handle->Ref();
+ }
+ for (auto handle : retvals) {
+ handle->Ref();
+ }
+
eager_client_->StreamingEnqueueAsync(
call_opts.get(), request_.get(), response.get(),
[inputs, retvals, call_opts, response, device,