/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"

#include <memory>

#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/error_payloads.h"

namespace tensorflow {

namespace {
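// Attaches a payload to `error` marking the eager RemoteMgr as its source, so
// that callers receiving the status can attribute the failure to this
// component.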
Status WithErrorSourcePayload(Status error) {
  core::platform::ErrorSourceProto error_source_proto;
  error_source_proto.set_error_source(
      core::platform::ErrorSourceProto::EAGER_REMOTE_MGR);
  error.SetPayload(tensorflow::kErrorSource,
                   error_source_proto.SerializeAsString());
  return error;
}
}  // namespace

namespace eager {

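// Registers the tensor handles produced by a remote operation: output i of
// the operation is stored in remote_tensor_handle_map_ under the key
// (operation_id, i).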
void RemoteMgr::AddOperationOutputs(
    const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
    int64_t operation_id) {
  mutex_lock l(remote_tensor_handle_mu_);
  for (int i = 0, end = handles.size(); i < end; i++) {
    // TODO(nareshmodi): Correctly handle operation_id not being unique.
    remote_tensor_handle_map_.emplace(
        RemoteTensorHandleInternal(operation_id, i), handles[i]);
  }
}

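// Registers a single output of a remote operation under the key
// (operation_id, output_num).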
void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle,
                                   int64_t operation_id, int32_t output_num) {
  mutex_lock l(remote_tensor_handle_mu_);
  remote_tensor_handle_map_.emplace(
      RemoteTensorHandleInternal(operation_id, output_num), handle);
}

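// Looks up `remote_handle` in remote_tensor_handle_map_. The caller is
// expected to hold remote_tensor_handle_mu_ (shared or exclusive).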
Status RemoteMgr::GetTensorHandleImpl(
    const RemoteTensorHandleInternal& remote_handle,
    tensorflow::TensorHandle** handle) {
  auto iter = remote_tensor_handle_map_.find(remote_handle);
  if (iter == remote_tensor_handle_map_.end()) {
    // TODO(b/217820532): Fix the tensor deallocation order issue.
    return WithErrorSourcePayload(errors::InvalidArgument(
        "Unable to find the relevant tensor remote_handle: Op ID: ",
        remote_handle.op_id, ", Output num: ", remote_handle.output_num,
        ". One possible cause is that the tensor was accessed after "
        "deallocation in a distributed worker setup. Try setting "
        "`os.environ['TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE']='False'` in "
        "your client to disable async streaming behavior to see if it fixes "
        "the problem."));
  }

  *handle = iter->second;

  return OkStatus();
}

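// Thread-safe wrapper around GetTensorHandleImpl that takes a shared lock on
// remote_tensor_handle_mu_.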
Status RemoteMgr::GetTensorHandle(
    const RemoteTensorHandleInternal& remote_handle,
    tensorflow::TensorHandle** handle) {
  tf_shared_lock l(remote_tensor_handle_mu_);
  return GetTensorHandleImpl(remote_handle, handle);
}

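// Retrieves the cached dtypes and shapes for a mirrored resource handle,
// as recorded earlier by DeserializeRemoteTensorHandle.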
Status RemoteMgr::GetMirroredResourceShape(
    const RemoteTensorHandleInternal& remote_handle,
    std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes) {
  tf_shared_lock l(mirrored_resource_shape_mu_);
  auto iter = mirrored_resource_shape_map_.find(remote_handle);
  if (iter == mirrored_resource_shape_map_.end()) {
    // TODO(b/217820532): Fix the tensor deallocation order issue.
    return WithErrorSourcePayload(errors::InvalidArgument(
        "Unable to find the relevant tensor remote_handle: Op ID: ",
        remote_handle.op_id, ", Output num: ", remote_handle.output_num,
        ". One possible cause is that the tensor was accessed after "
        "deallocation in a distributed worker setup. Try setting "
        "`os.environ['TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE']='False'` in "
        "your client to disable async streaming behavior to see if it fixes "
        "the problem."));
  }

  *dtypes_and_shapes = iter->second;

  return OkStatus();
}

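// Recovers the (op_id, output_num) address of `handle` and verifies that the
// handle registered under that key is `handle` itself. Like
// GetTensorHandleImpl, this expects the caller to hold
// remote_tensor_handle_mu_.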
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
                                        const bool wait_until_ready,
                                        int64_t* op_id, int32_t* output_num) {
  TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
                                           op_id, output_num));
  tensorflow::TensorHandle* h;
  TF_RETURN_IF_ERROR(
      GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
  if (handle != h) {
    return WithErrorSourcePayload(errors::Internal(
        "Found two different tensor handles with the same op_id: ", *op_id,
        " and output_num: ", *output_num));
  }
  return OkStatus();
}

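// Removes `remote_handle` from whichever map contains it, releasing the
// reference held on a registered tensor handle. Fails if the handle is in
// neither map.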
Status RemoteMgr::DeleteTensorHandle(
    const RemoteTensorHandleInternal& remote_handle) {
  {
    mutex_lock l(remote_tensor_handle_mu_);
    auto iter = remote_tensor_handle_map_.find(remote_handle);
    if (iter != remote_tensor_handle_map_.end()) {
      iter->second->Unref();
      remote_tensor_handle_map_.erase(iter);
      return OkStatus();
    }
  }
  {
    mutex_lock l(mirrored_resource_shape_mu_);
    auto iter = mirrored_resource_shape_map_.find(remote_handle);
    if (iter != mirrored_resource_shape_map_.end()) {
      mirrored_resource_shape_map_.erase(iter);
      return OkStatus();
    }
  }
  return WithErrorSourcePayload(errors::InvalidArgument(
      "Unable to find the relevant tensor remote_handle: Op ID: ",
      remote_handle.op_id, ", Output num: ", remote_handle.output_num));
}

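// Fills `out` with a wire-format description of `in`: its (op_id, output_num)
// address, op device, target device name, and dtype. If the handle has no
// remote address on `device`, falls back to the locally registered mapping.
// Optionally serializes the dtypes and shapes carried by a resource handle.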
Status RemoteMgr::SerializeRemoteTensorHandle(
    TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out,
    Device* device, const string& device_name,
    const bool serialize_resource_dtype_and_shape) {
  int64_t op_id;
  int32_t output_num;
  if (!in->RemoteAddress(device, wait_until_ready, &op_id, &output_num).ok()) {
    tf_shared_lock l(remote_tensor_handle_mu_);
    TF_RETURN_IF_ERROR(
        GetRemoteTensorHandle(in, wait_until_ready, &op_id, &output_num));
  }
  out->Clear();
  out->set_op_id(op_id);
  out->set_output_num(output_num);
  out->set_op_device(in->op_device() ? in->op_device()->name() : "");
  out->set_device(device_name);
  out->set_dtype(in->dtype);
  if (serialize_resource_dtype_and_shape) {
    std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
    TF_RETURN_IF_ERROR(
        in->GetResourceHandleDtypesAndShapes(&resource_dtypes_and_shapes));
    for (const auto& dtype_and_shape : resource_dtypes_and_shapes) {
      ResourceDtypeAndShape* dtype_and_shape_proto =
          out->add_resource_dtypes_and_shapes();
      dtype_and_shape_proto->set_dtype(dtype_and_shape.dtype);
      dtype_and_shape.shape.AsProto(dtype_and_shape_proto->mutable_shape());
    }
  }
  return OkStatus();
}

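// Resolves a wire-format RemoteTensorHandle into a TensorHandle. If the
// handle's device is managed by the local device manager, returns the locally
// registered handle with an extra reference; otherwise creates a lazy remote
// handle and caches any resource dtype/shape information it carries.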
Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
                                                TensorHandle** out) {
  Device* device;
  if (parent_->local_device_mgr()->LookupDevice(in.op_device(), &device).ok() ||
      parent_->local_device_mgr()->LookupDevice(in.device(), &device).ok()) {
    TF_RETURN_IF_ERROR(GetTensorHandle(RemoteTensorHandleInternal(in), out));
    (*out)->Ref();
  } else {
    // Create a remote TensorHandle for remote tensors which have not been
    // copied to the local worker yet (e.g. remote function inputs).
    const string& device_name =
        in.op_device().empty() ? in.device() : in.op_device();
    TF_RETURN_IF_ERROR(
        parent_->FindDeviceFromName(device_name.c_str(), &device));
    *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
                                                in.dtype(), device,
                                                /*is_ready=*/true, parent_);
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
    if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                  &dtypes_and_shapes)
             .ok()) {
      for (const auto& dtype_and_shape_proto :
           in.resource_dtypes_and_shapes()) {
        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
            dtype_and_shape_proto.dtype(),
            TensorShape(dtype_and_shape_proto.shape())});
      }
      mutex_lock l(mirrored_resource_shape_mu_);
      mirrored_resource_shape_map_.emplace(
          RemoteTensorHandleInternal(in.op_id(), in.output_num()),
          dtypes_and_shapes);
    }
    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
  }

  return OkStatus();
}

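// Returns the executor associated with `stream_id`, creating a new async
// executor on first use.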
EagerExecutor& RemoteMgr::GetOrCreateExecutorForStream(uint64 stream_id) {
  mutex_lock l(executor_map_mu_);
  auto it = executor_map_.find(stream_id);
  if (it == executor_map_.end()) {
    auto it_and_bool = executor_map_.emplace(
        std::piecewise_construct, std::forward_as_tuple(stream_id),
        std::forward_as_tuple(/*async=*/true));
    DCHECK(it_and_bool.second);
    it = it_and_bool.first;
  }
  return it->second;
}

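// Shuts down and removes the executor associated with `stream_id`, if any.
// Shutdown errors are logged rather than propagated.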
void RemoteMgr::DeleteExecutorForStream(uint64 stream_id) {
  mutex_lock l(executor_map_mu_);
  auto it = executor_map_.find(stream_id);
  if (it == executor_map_.end()) {
    return;
  }
  Status s = it->second.ShutDown();
  if (!s.ok()) {
    LOG(ERROR) << "EagerExecutor shutdown failed with error: "
               << s.error_message();
  }
  executor_map_.erase(it);
}

}  // namespace eager
}  // namespace tensorflow