| /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/common_runtime/eager/execute.h" |
| |
| #include <cstddef> |
| #include <vector> |
| |
| // clang-format off |
| // Required for IS_MOBILE_PLATFORM |
| #include "tensorflow/core/common_runtime/eager/eager_operation.h" |
| #include "tensorflow/core/framework/cancellation.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/tensor_shape.h" |
| #include "tensorflow/core/lib/core/refcount.h" |
| #include "tensorflow/core/platform/platform.h" |
| // clang-format on |
| |
| #include "absl/container/inlined_vector.h" |
| #include "absl/strings/match.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/types/optional.h" |
| #include "tensorflow/c/tf_tensor_internal.h" |
| #include "tensorflow/compiler/jit/defs.h" |
| #include "tensorflow/core/common_runtime/device.h" |
| #include "tensorflow/core/common_runtime/device_set.h" |
| #include "tensorflow/core/common_runtime/eager/context.h" |
| #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" |
| #include "tensorflow/core/common_runtime/eager/execute_node.h" |
| #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" |
| #include "tensorflow/core/common_runtime/eager/tensor_handle.h" |
| #include "tensorflow/core/common_runtime/colocation_graph.h" |
| #include "tensorflow/core/framework/dataset.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/logging.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/tensor_reference.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/protobuf/error_codes.pb.h" |
| #include "tensorflow/core/util/device_name_utils.h" |
| #if !defined(IS_MOBILE_PLATFORM) |
| #include "tensorflow/core/distributed_runtime/eager/eager_client.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_copy_node.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" |
| #include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h" |
| #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" |
| #endif // IS_MOBILE_PLATFORM |
| #include "tensorflow/core/framework/step_stats.pb.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/gtl/flatset.h" |
| #include "tensorflow/core/lib/random/random.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/util/ptr_util.h" |
| #include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h" |
| |
| namespace tensorflow { |
| |
| namespace { |
| |
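| // Returns the name of the given device, or "<unspecified>" if the device is |
| // null. |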
| const string& DeviceNameOrUnspecified(Device* device) { |
| static string* unspecified_string = new string("<unspecified>"); |
| return (device == nullptr) ? *unspecified_string : device->name(); |
| } |
| |
| const string& DeviceNameOrUnspecified(VariantDevice device) { |
| if (VariantDeviceIsCustom(device)) { |
| return absl::get<CustomDevice*>(device)->name(); |
| } else { |
| return DeviceNameOrUnspecified(absl::get<Device*>(device)); |
| } |
| } |
| |
| // Returns whether a kernel should be cached. |
| bool KernelCacheEnabled(const OpDef& op_def) { |
| if (data::DatasetOpKernel::IsDatasetOp(&op_def)) { |
| return false; |
| } |
| // TODO(b/162540360): Revisit a way to mark kernels as uncachable once we have |
| // 5+ kernels to exclude. |
| return true; |
| } |
| |
| // This function expects *handle to point to an existing tensor handle that is |
| // currently on "handle_device", but where the operation expects that input to |
| // reside on "expected_input_device". The function will arrange for this |
| // transfer to happen, returning OK on success and storing a new handle to the |
| // equivalent tensor on the correct device in "*result". If an error is |
| // encountered, it returns a non-OK status and sets "*result" to nullptr. |
| // |
| // `op_device` is passed in explicitly because `op->device()` might be |
| // unset and we might have selected some specific device to run this op on. |
| Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op, |
| Device* op_device, |
| TensorHandle* handle, // op->Inputs()[i] |
| int i, Device* handle_device, |
| Device* expected_input_device, |
| TensorHandle** result) { |
| // Should only be called when these don't match |
| DCHECK(expected_input_device != handle_device); |
| *result = nullptr; |
| const string& op_device_name = DeviceNameOrUnspecified(op_device); |
| |
| switch (ctx->GetDevicePlacementPolicy()) { |
| case DEVICE_PLACEMENT_SILENT_FOR_INT32: |
| // TODO(xpan): See if we could bubble python related error up |
| // to python level. |
| if (handle->dtype == DT_INT32) { |
| // Note: enabling silent copies of int32 tensors to match behavior |
| // of graph mode. |
| break; |
| } |
| TF_FALLTHROUGH_INTENDED; |
| case DEVICE_PLACEMENT_EXPLICIT: |
| // tf.identity is allowed to copy, as indicated in the error message |
| // below. |
| if (op->Name() == "Identity" || op->Name() == "IdentityN") { |
| break; |
| } |
| return errors::InvalidArgument( |
| "Tensors on conflicting devices:" |
| " cannot compute ", |
| op->Name(), " as input #", i, " was expected to be on ", |
| expected_input_device->name(), " but is actually on ", |
| handle_device->name(), " (operation running on ", op_device_name, ")", |
| " Tensors can be copied explicitly using:" |
| " `with tf.device(device_name): x = tf.identity(x)`" |
| " or transparently copied by using" |
| " tf.config.experimental.set_device_policy('silent')." |
| " Copying tensors between devices may slow down your model"); |
| case DEVICE_PLACEMENT_WARN: |
| LOG(WARNING) << "before computing " << op->Name() << " input #" << i |
| << " was expected to be on " << expected_input_device->name() |
| << " but is actually on " << handle_device->name() |
| << " (operation running on " << op_device_name |
| << "). This triggers a copy which can be a performance " |
| "bottleneck."; |
| break; |
| case DEVICE_PLACEMENT_SILENT: // Do nothing. |
| break; |
| } |
| // We are only here if the policy is warn or silent copies, so we should |
| // trigger a copy. |
| TensorHandle* result_handle = nullptr; |
| profiler::TraceMe activity( |
| [&] { |
| return absl::StrCat("_Send input ", i, " from ", handle_device->name(), |
| " to ", expected_input_device->name()); |
| }, |
| profiler::TraceMeLevel::kInfo); |
| Status status = |
| EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device, |
| /* mirror= */ true, &result_handle); |
| activity.Stop(); |
| if (!status.ok()) { |
| return Status( |
| status.code(), |
| absl::StrCat("Failed copying input tensor from ", handle_device->name(), |
| " to ", expected_input_device->name(), " in order to run ", |
| op->Name(), ": ", status.error_message())); |
| } |
| |
| *result = result_handle; |
| |
| return Status::OK(); |
| } |
| |
| // `op_device_name` is the name of the device on which the op will run, if |
| // any. For functions run via the function library runtime, the device can |
| // be unspecified. |
| Status ValidateInputTypeAndPlacement( |
| EagerContext* ctx, EagerOperation* op, |
| const core::RefCountPtr<KernelAndDevice>& kernel) { |
| profiler::TraceMe activity("ValidateInputTypeAndPlacement", |
| profiler::TraceMeLevel::kInfo); |
| const int n_inputs = op->Inputs().size(); |
| if (kernel->num_inputs() != n_inputs) { |
| return errors::InvalidArgument("expected ", kernel->num_inputs(), |
| " inputs, got ", n_inputs); |
| } |
| const bool skip_remote_copy = |
| ctx->LazyCopyFunctionRemoteInputs() && kernel->IsFunction(); |
| if (n_inputs > 0) { |
| const DataType* input_types = &kernel->input_dtypes()[0]; |
| TensorHandle* const* handles = &op->Inputs()[0]; |
| for (int i = 0; i < n_inputs; ++i) { |
| TensorHandle* handle = handles[i]; |
| Device* expected_device = kernel->InputDevice(i); |
| if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) { |
| // Extract a handle on the op device from a packed input. |
| // This happens when a function is marked for XLA compilation. |
| // MaybePackInputTensor guarantees that a primitive op has no packed |
| // input at this point. |
| for (int j = 0; j < handle->NumPackedHandles(); ++j) { |
| TensorHandle* h = nullptr; |
| TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(j, &h)); |
| if ((h->op_device() != nullptr) && |
| (h->op_device()->name() == op->DeviceName())) { |
| op->UpdateInput(i, h); |
| handle = h; |
| break; |
| } |
| } |
| } |
| auto handle_device_variant = handle->DeviceOrHostCPU(*ctx); |
| if (VariantDeviceIsCustom(handle_device_variant)) { |
| return errors::Unimplemented( |
| "Custom devices and remote execution are not yet supported " |
| "together."); |
| } |
| Device* handle_device = absl::get<Device*>(handle_device_variant); |
| const bool maybe_copy = |
| !skip_remote_copy || handle->Type() != TensorHandle::REMOTE; |
| // If the input is already on the right device, then nothing to do. |
| if (expected_device != handle_device && maybe_copy) { |
| TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(ctx, op, kernel->device(), |
| handle, i, handle_device, |
| expected_device, &handle)); |
| op->UpdateInput(i, handle); |
| // Unref handle since it has a ref as an input now |
| handle->Unref(); |
| } |
| if (handle->dtype != input_types[i]) { |
| return errors::InvalidArgument( |
| "cannot compute ", op->Name(), " as input #", i, "(zero-based)", |
| " was expected to be a ", DataTypeString(input_types[i]), |
| " tensor but is a ", DataTypeString(handle->dtype), " tensor"); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
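| // Computes the output data types of `op` from its attributes and its OpDef |
| // (or, if `op` is a function, from the function's signature). |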
| Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) { |
| const auto& node_def = op->MutableAttrs()->BuildNodeDef(); |
| const OpDef* op_def = nullptr; |
| |
| const FunctionDef* function_def = |
| op->EagerContext().FuncLibDef()->Find(op->Name()); |
| if (function_def != nullptr) { |
| op_def = &(function_def->signature()); |
| } else { |
| TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def)); |
| } |
| |
| TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes)); |
| |
| return Status::OK(); |
| } |
| |
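| // Helpers for combining fingerprints; used below to build kernel cache keys. |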
| inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, |
| const tensorflow::Fprint128& b) { |
| return {tensorflow::FingerprintCat64(a.low64, b.low64), |
| tensorflow::FingerprintCat64(a.high64, b.high64)}; |
| } |
| |
| inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, |
| const int64 b) { |
| auto x = tensorflow::FingerprintCat64(a.low64, b); |
| return {x, tensorflow::FingerprintCat64(a.high64, x)}; |
| } |
| |
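| // Determines the device to record for an input handle when building a |
| // function's input device list and kernel cache key: remote/packed handles |
| // keep their own device, DT_RESOURCE inputs use the device stored in the |
| // resource handle, and tensors kept in host memory map to the host CPU. |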
| Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle, |
| Device** result) { |
| if (TF_PREDICT_FALSE(VariantDeviceIsCustom(tensor_handle->device()))) { |
| return errors::Unimplemented( |
| "The kernel cache does not work with custom devices."); |
| } |
| Device* cpu_device = ctx.HostCPU(); |
| string device_name; |
| if (tensor_handle->Type() != TensorHandle::LOCAL) { |
| Device* device = absl::get<Device*>(tensor_handle->device()); |
| device_name = device != nullptr ? device->name() : cpu_device->name(); |
| *result = (device == nullptr ? cpu_device : device); |
| } else if (tensor_handle->dtype == DT_RESOURCE) { |
| // Use the resource's actual device because it is the device that will |
| // influence partitioning the multi-device function. |
| const Tensor* tensor; |
| // TODO(fishx): Avoid blocking here. |
| TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor)); |
| const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0); |
| device_name = handle.device(); |
| |
| Device* input_device; |
| TF_RETURN_IF_ERROR( |
| ctx.FindDeviceFromName(device_name.c_str(), &input_device)); |
| *result = input_device; |
| } else { |
| Device* device = absl::get<Device*>(tensor_handle->device()); |
| const bool is_tpu = device != nullptr && device->device_type() == "TPU"; |
| // int32 return values can be placed on TPUs. |
| const bool use_host_memory = |
| is_tpu ? MTypeFromDTypeIntsOnDevice(tensor_handle->dtype) |
| : MTypeFromDType(tensor_handle->dtype); |
| if (use_host_memory) { |
| *result = cpu_device; |
| } else { |
| device_name = device != nullptr ? device->name() : cpu_device->name(); |
| *result = (device == nullptr ? cpu_device : device); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // Appends a TensorShape object to Fprint128 hash. |
| // For best performance, we would like to avoid dynamic memory allocation in |
| // this function. |
| // If "shape" has unknown rank, we attach "?" to hashed content; otherwise we |
| // attach every dim size to hashed content. |
| void AppendTensorShapeToFingerprint(const PartialTensorShape& shape, |
| Fprint128* fingerprint) { |
| if (shape.unknown_rank()) { |
| char c = '?'; |
| *fingerprint = FingerprintCat128(*fingerprint, c); |
| } else { |
| for (int i = 0; i < shape.dims(); i++) { |
| int64 dim = shape.dim_size(i); |
| *fingerprint = FingerprintCat128(*fingerprint, dim); |
| } |
| } |
| } |
| |
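| // Looks up the boolean attribute `attr_name` for `op`, checking the op's own |
| // attributes first and falling back to the function definition's attributes. |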
| Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, |
| const char* attr_name, bool* value) { |
| Status status = op->Attrs().Get(attr_name, value); |
| if (status.ok()) { |
| DVLOG(2) << "Caller explicitly specifies " |
| << (attr_name ? "=true " : "=false, ") << op->DebugString(); |
| return Status::OK(); |
| } |
| |
| const FunctionDef* function_def = |
| ctx.pflr()->GetFunctionLibraryDefinition()->Find(op->Name()); |
| if (function_def == nullptr) { |
| return errors::NotFound("Failed to find function '", op->Name(), "'"); |
| } |
| |
| status = GetNodeAttr(AttrSlice(&function_def->attr()), attr_name, value); |
| if (status.ok()) { |
| DVLOG(2) << "Function definition explicitly specifies " |
| << (attr_name ? "=true" : "=false"); |
| return Status::OK(); |
| } |
| return status; |
| } |
| |
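| // Determines whether `op` must be compiled with XLA. Only functions are |
| // eligible; components of multi-device functions are excluded; an explicit |
| // kXlaMustCompileAttr takes precedence; otherwise ops placed on TPU or XLA |
| // devices are compiled by default. |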
| Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, |
| bool* compile_with_xla) { |
| if (!op->is_function()) { |
| *compile_with_xla = false; |
| return Status::OK(); |
| } |
| |
| if (op->remote_func_params().has_value() && |
| op->remote_func_params().value().step_id.has_value()) { |
| // If the op is a component of a multi-device function, don't compile it |
| // with XLA. |
| *compile_with_xla = false; |
| return Status::OK(); |
| } |
| |
| Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla); |
| if (status.ok()) { |
| return Status::OK(); |
| } |
| |
| // No explicit requests. Compile for XLA devices by default. |
| if (op->GetDeviceParsedName().type == "TPU" || |
| op->GetDeviceParsedName().type == "XLA_GPU" || |
| op->GetDeviceParsedName().type == "XLA_CPU") { |
| DVLOG(2) << "Compiling " << op->Name() |
| << " with XLA because it is running on an XLA device " |
| << op->GetDeviceParsedName().type; |
| *compile_with_xla = true; |
| } else { |
| *compile_with_xla = false; |
| } |
| |
| return Status::OK(); |
| } |
| |
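| // Looks up a cached kernel for `op`, keyed on the op's attributes, device, |
| // soft-placement policy, and (for functions) input devices and resource |
| // dtypes/shapes. If no cached kernel is found, creates, initializes, and |
| // (when allowed) caches a new one. Also sets *num_retvals to the kernel's |
| // number of outputs. |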
| Status GetOrCreateKernelAndDevice( |
| EagerOperation* op, TensorHandle** retvals, int* num_retvals, |
| core::RefCountPtr<KernelAndDevice>* out_kernel) { |
| EagerContext& ctx = op->EagerContext(); |
| Device* device = absl::get<Device*>(op->Device()); |
| |
| Fprint128 cache_key = op->MutableAttrs()->CacheKey(op->DeviceName()); |
| // Include soft placement policy in cache key since the placement strategy |
| // can change and thus affect which kernel is picked. |
| cache_key = FingerprintCat128(cache_key, ctx.AllowSoftPlacement()); |
| |
| std::vector<Device*> input_dev_ptrs; |
| absl::flat_hash_map<string, const std::vector<string>*> composite_devices; |
| std::unordered_map<int, DtypeAndPartialTensorShape> |
| input_resource_variable_dtypes_and_shapes; |
| // We can eliminate some overhead by running simple functions using regular |
| // CallOp kernel. However, it is tricky to figure out which functions should |
| // be run using CallOp. Also, currently CallOp runs neither optimization |
| // passes (needed for TPU/XLA) nor grappler. |
| // Here are some cases where a function should be run in multi-device mode: |
| // - Function takes at least two resources on different devices. |
| // - Function takes a resource on deviceA and a body op explicitly placed |
| // on deviceB. |
| // - Function has a colocation constraint. |
| // - Function has an explicit device annotation (which might not be using |
| // full canonical device name) different from op_device. Note that false |
| // positives are ok. |
| // - Function has a node or a (node) attribute that can potentially make |
| // the function multi-device after a rewrite pass (e.g. various XLA/TPU |
| // special nodes and attributes) |
| if (op->is_function()) { |
| profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey", |
| profiler::TraceMeLevel::kInfo); |
| input_dev_ptrs.reserve(op->Inputs().size()); |
| // When LazyCopyFunctionRemoteInputs is disabled, all inputs need to be on |
| // local devices, since we execute a remote function through the worker |
| // service, which doesn't accept remote inputs. |
| for (int i = 0, end = op->Inputs().size(); i < end; i++) { |
| TensorHandle* input = op->Inputs()[i]; |
| if (!ctx.LazyCopyFunctionRemoteInputs() && |
| input->Type() == TensorHandle::REMOTE) { |
| TensorHandle* handle = nullptr; |
| TF_RETURN_IF_ERROR( |
| EagerCopyToDevice(input, &ctx, &op->Executor(), |
| device == nullptr ? ctx.HostCPU() : device, |
| /*mirror=*/true, &handle)); |
| op->UpdateInput(i, handle); |
| // Unref handle since it has a ref as an input now |
| handle->Unref(); |
| input = handle; |
| } |
| |
| // Get device for this input, and add it to 'cache_key'. |
| Device* input_device; |
| TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device)); |
| input_dev_ptrs.push_back(input_device); |
| CompositeDevice* composite_device = nullptr; |
| if (ctx.FindCompositeDeviceFromName(input_device->name(), |
| &composite_device) |
| .ok()) { |
| composite_devices[input_device->name()] = |
| composite_device->underlying_devices(); |
| } |
| cache_key = |
| FingerprintCat128(cache_key, Fingerprint128(input_device->name())); |
| |
| // If input is a ResourceHandle, get its resource handle dtypes and shapes |
| // and add them to 'cache_key'. |
| if (input->dtype == DT_RESOURCE) { |
| // We only care about data type and shape for resource variable inputs. |
| // But we have no way to tell if input is resource variable (other than |
| // looking it up in ResourceMgr, which is slow). So we just get |
| // resource_dtypes_and_shapes for all DT_RESOURCE inputs. If |
| // resource_dtypes_and_shapes is not empty, take the first element. |
| std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes; |
| TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes( |
| &resource_dtypes_and_shapes)); |
| if (!resource_dtypes_and_shapes.empty()) { |
| const DtypeAndPartialTensorShape& dtype_and_shape = |
| resource_dtypes_and_shapes.at(0); |
| input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape; |
| |
| // Add _Arg index, dtype and shape to "cache_key". |
| cache_key = FingerprintCat128(cache_key, i); |
| DataType dtype = dtype_and_shape.dtype; |
| cache_key = FingerprintCat128(cache_key, dtype); |
| AppendTensorShapeToFingerprint(dtype_and_shape.shape, &cache_key); |
| } |
| } |
| } |
| } |
| |
| core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key); |
| if (kernel == nullptr) { |
| DVLOG(2) << "Creating new kernel for " << op->Name() << " on device " |
| << DeviceNameOrUnspecified(op->Device()); |
| bool run_function_with_flr = false; |
| bool function_outputs_on_op_device = false; |
| if (op->is_function()) { |
| bool compile_with_xla; |
| TF_RETURN_IF_ERROR(MustCompileWithXLA(op, ctx, &compile_with_xla)); |
| if (compile_with_xla) { |
| // Note that it is not ideal, but currently correct, to set this |
| // attribute after computing the kernel cache key above. |
| // Note: If the attribute is already set to true, this is a noop. |
| op->MutableAttrs()->Set(kXlaMustCompileAttr, true); |
| } else { |
| run_function_with_flr = true; |
| } |
| GetFuncAttr(op, ctx, kOutputsOnOpDevice, &function_outputs_on_op_device) |
| .IgnoreError(); |
| } |
| |
| const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); |
| if (device == nullptr) { |
| TF_RETURN_IF_ERROR( |
| ctx.SelectDevice(op->GetDeviceParsedName(), ndef, &device)); |
| |
| DVLOG(1) << "Placer place op [" << op->Name() |
| << "] on device: " << device->name(); |
| DVLOG(4) << "Available kernels for " << op->Name() << "are " |
| << KernelsRegisteredForOp(op->Name()); |
| op->SetDevice(device); |
| } |
| |
| FunctionLibraryRuntime* flr = |
| device == nullptr ? nullptr : ctx.func_lib(device); |
| if (device != nullptr && flr == nullptr) { |
| return errors::Unavailable( |
| "Unable to find a FunctionLibraryRuntime corresponding to device ", |
| device->name()); |
| } |
| auto runner = (flr != nullptr && flr->runner() != nullptr) ? flr->runner() |
| : ctx.runner(); |
| GraphCollector* graph_collector = nullptr; |
| if (ctx.ShouldStoreGraphs()) { |
| graph_collector = ctx.GetGraphCollector(); |
| } |
| // Treat the function as multi_device only when we are not compiling |
| // it wholly with XLA. When compiling wholly with XLA, flr->CreateKernel |
| // will create an XlaLaunchOp kernel to compile and run the function. |
| if (run_function_with_flr) { |
| // Multi-device functions don't use the rendezvous from eager context. |
| // If we use that rendezvous, multiple concurrent calls to the same |
| // function will likely result in collisions. However, this also means |
| // that we don't support legitimate sending/receiving across function |
| // boundary. |
| DVLOG(2) << "Running " << ndef.op() << " using multi-device function. " |
| << "Full node_def=" << ndef.DebugString(); |
| std::function<int64()> get_op_id = nullptr; |
| #if !defined(IS_MOBILE_PLATFORM) |
| if (ctx.LazyCopyFunctionRemoteInputs()) { |
| get_op_id = [&ctx]() { return ctx.RemoteMgr()->NextOpId(); }; |
| } |
| #endif // IS_MOBILE_PLATFORM |
| kernel.reset(new KernelAndDeviceFunc( |
| flr, ctx.pflr(), std::move(input_dev_ptrs), |
| std::move(composite_devices), |
| std::move(input_resource_variable_dtypes_and_shapes), runner, |
| ctx.GetCollectiveExecutorHandle(), ctx.HostCPU(), op->Name(), |
| function_outputs_on_op_device, |
| [&ctx](const int64 step_id) { return ctx.CreateRendezvous(step_id); }, |
| get_op_id)); |
| } else { |
| DVLOG(2) << "Running " << ndef.op() << " using op kernel. " |
| << ". Full node_def=" << ndef.DebugString(); |
| kernel.reset(new KernelAndDeviceOp( |
| ctx.GetRendezvous(), ctx.LogMemory(), flr, runner, |
| ctx.GetCollectiveExecutorHandle(), ctx.HostCPU())); |
| } |
| |
| TF_RETURN_IF_ERROR(kernel->Init( |
| {ctx.LogDevicePlacement(), ctx.LazyCopyFunctionRemoteInputs()}, ndef, |
| graph_collector)); |
| |
| if (op->is_function()) { |
| ctx.AddKernelToCache(cache_key, kernel.get()); |
| } else { |
| // Exclude tf.data op kernels from being cached. The reason for this is |
| // that tf.data op kernels that accept a user-defined function will have a |
| // unique cache key every time they are executed (because the user-defined |
| // function is traced every time). Caching such kernels provides no |
| // benefit and in some cases results in linear memory growth for user |
| // programs that build input pipeline graphs in a loop. |
| const OpDef* op_def; |
| TF_RETURN_IF_ERROR(OpDefForOp(op->Name().data(), &op_def)); |
| if (KernelCacheEnabled(*op_def)) { |
| ctx.AddKernelToCache(cache_key, kernel.get()); |
| } |
| } |
| } |
| |
| int num_outputs = kernel->num_outputs(); |
| if (num_outputs > *num_retvals) { |
| return errors::InvalidArgument("Expecting ", num_outputs, |
| " outputs, but *num_retvals is ", |
| *num_retvals); |
| } |
| *num_retvals = num_outputs; |
| |
| kernel->Ref(); // Ownership of reference is passed to out_kernel. |
| out_kernel->reset(kernel.get()); |
| return Status::OK(); |
| } |
| |
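| // Creates a remote TensorHandle for output `output_num` of an op that runs |
| // across processes: an unshaped remote handle on the master, or a lazy |
| // remote handle on a worker. Not supported on mobile platforms. |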
| Status CreateUnshapedOutput( |
| const KernelAndDevice& kernel, const int output_num, Device* output_device, |
| const DataType& output_dtype, |
| const absl::optional<EagerRemoteFunctionParams>& remote_func_params, |
| EagerContext* ctx, TensorHandle** output) { |
| #if defined(IS_MOBILE_PLATFORM) |
| return errors::Unimplemented( |
| "Remote outputs are not available on mobile devices."); |
| #else // !IS_MOBILE_PLATFORM |
| int64 op_id; |
| if (remote_func_params.has_value()) { |
| op_id = remote_func_params.value().op_id; |
| } else { |
| return errors::InvalidArgument( |
| "Unable to find a remote op id for a remote output of ", kernel.name()); |
| } |
| string remote_task; |
| if (!DeviceNameUtils::GetTaskName(output_device->parsed_name(), |
| &remote_task)) { |
| return errors::InvalidArgument( |
| "Unable to find remote task corresponding to device ", |
| output_device->name()); |
| } |
| if (ctx->RemoteMgr()->IsMaster()) { |
| *output = TensorHandle::CreateUnshapedRemoteHandle( |
| op_id, output_num, remote_task, output_dtype, output_device, ctx); |
| } else { |
| *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num, |
| output_dtype, output_device, |
| /*is_ready=*/false, ctx); |
| } |
| return Status::OK(); |
| #endif // !IS_MOBILE_PLATFORM |
| } |
| |
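| // Creates the output handles for `op` and either enqueues an |
| // AsyncExecuteNode on the op's executor (async mode) or runs an ExecuteNode |
| // synchronously. Assigns a remote op id first if the kernel is cross-process |
| // and no remote function params are set. |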
| Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel, |
| EagerOperation* op, TensorHandle** retvals) { |
| EagerExecutor& executor = op->Executor(); |
| EagerContext& ctx = op->EagerContext(); |
| GraphCollector* graph_collector = nullptr; |
| if (ctx.ShouldStoreGraphs()) { |
| graph_collector = ctx.GetGraphCollector(); |
| } |
| const int num_outputs = kernel->num_outputs(); |
| absl::optional<EagerRemoteFunctionParams> remote_func_params = |
| op->remote_func_params(); |
| if (kernel->IsCrossProcess() && !remote_func_params.has_value()) { |
| // Create an eager op id for a cross-process function if one does not exist. |
| #if defined(IS_MOBILE_PLATFORM) |
| return errors::Unimplemented( |
| "Cross-process functions are not supported on mobile devices."); |
| #else // !IS_MOBILE_PLATFORM |
| const int64 op_id = ctx.RemoteMgr()->NextOpId(); |
| remote_func_params = |
| EagerRemoteFunctionParams{op_id, /*step_id=*/absl::nullopt}; |
| #endif // !IS_MOBILE_PLATFORM |
| } |
| if (executor.Async()) { |
| const DataTypeVector& output_dtypes = kernel->output_dtypes(); |
| for (int i = 0, end = num_outputs; i < end; ++i) { |
| Device* output_device = ctx.CanonicalDevice(kernel->OutputDevice(i)); |
| if (output_device == nullptr || output_device->IsLocal()) { |
| retvals[i] = TensorHandle::CreateEmptyLocalHandle( |
| /* d= */ output_device, /* op_device= */ kernel->device(), |
| /* resource_device= */ kernel->OutputResourceDevice(i), |
| output_dtypes[i], &ctx); |
| } else { |
| TF_RETURN_IF_ERROR( |
| CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i], |
| remote_func_params, &ctx, &retvals[i])); |
| } |
| } |
| auto node = absl::make_unique<AsyncExecuteNode>( |
| &ctx, op->Inputs(), remote_func_params, std::move(kernel), |
| graph_collector, op->GetCancellationManager(), |
| absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace()); |
| // Release the inputs from the eager operation since the AsyncExecuteNode |
| // would have taken ownership. This allows the inputs to be forwarded if |
| // possible. |
| op->Clear(); |
| // For async mode, execution order will make sure that all |
| // input handles are ready before executing them. |
| // TODO(b/137118203): Consider executing "cheap" kernels inline for |
| // performance. |
| return executor.AddOrExecute(std::move(node)); |
| } else { |
| for (int i = 0, end = num_outputs; i < end; ++i) { |
| retvals[i] = nullptr; |
| } |
| ExecuteNode node(&ctx, op->Inputs(), remote_func_params, kernel, |
| graph_collector, op->GetCancellationManager(), |
| {retvals, static_cast<size_t>(num_outputs)}); |
| Status s = executor.SyncExecute(&node); |
| // We release the inputs AFTER executing the operation in sync mode since |
| // ExecuteNode does not increment the reference count and thus does not have |
| // ownership of the inputs while executing. |
| op->Clear(); |
| return s; |
| } |
| } |
| |
| // There are a lot of references to devices in this function and around. |
| // Here is what they mean: |
| // EagerOperation::Device(): The device on which the user requested the op |
| // be executed, except if we had to change the device due to resource inputs |
| // or CPU pinning. If the user did not request a device, the op does not |
| // take resources, and we did not pin it to CPU, the device can be nullptr. |
| // KernelAndDevice::Device(): The first time we see an op (combined with |
| // its attributes), we need to create a KernelAndDevice object for it. |
| // If op->Device() is a nullptr, we select a device for the op when |
| // creating the KernelAndDevice. A concrete device will always be selected |
| // here except when `op` is a function to be executed using function library |
| // runtime. In this case, we don't select a device because running |
| // a function with an explicitly requested device behaves differently than |
| // running it without one. |
| Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, |
| int* num_retvals) { |
| ScopedMemoryDebugAnnotation op_annotation( |
| op->op_name(), op->remote_func_params().has_value() |
| ? op->remote_func_params().value().step_id.value_or(0) |
| : 0); |
| profiler::TraceMe activity( |
| [&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); }, |
| profiler::TraceMeLevel::kInfo); |
| EagerContext& ctx = op->EagerContext(); |
| auto& executor = op->Executor(); |
| TF_RETURN_IF_ERROR(executor.status()); |
| |
| core::RefCountPtr<KernelAndDevice> kernel; |
| auto status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); |
| |
| // Run all the registered rewrite passes after placement, regardless of |
| // whether the placement was successful. The passes can either create new ops |
| // (without placement) or update some fields of the input op. |
| std::unique_ptr<tensorflow::EagerOperation> out_op; |
| TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite( |
| EagerOpRewriteRegistry::POST_PLACEMENT, op, &out_op)); |
| if (out_op) { |
| op = out_op.get(); |
| // If the out op doesn't have a device, either because it is a new op or |
| // the op wasn't placed successfully, then we do the placement again. |
| if (op->Device() == kVariantDeviceNull) { |
| status = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); |
| } |
| } |
| if (!status.ok()) return status; |
| |
| int num_outputs = kernel->num_outputs(); |
| TF_RETURN_IF_ERROR(ValidateInputTypeAndPlacement(&ctx, op, kernel)); |
| |
| if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) { |
| string msg = strings::StrCat("Executing op ", op->Name(), " in device ", |
| kernel->device()->name()); |
| if (!logging::LogToListeners(msg)) { |
| LOG(INFO) << msg; |
| } |
| } |
| |
| Status s = AddOrExecuteNode(std::move(kernel), op, retvals); |
| // If the operation failed, we need to Unref any outputs that were |
| // allocated. |
| if (!s.ok()) { |
| for (int i = 0, end = num_outputs; i < end; ++i) { |
| if (retvals[i] != nullptr) { |
| retvals[i]->Unref(); |
| } |
| } |
| } |
| |
| return s; |
| } |
| |
| // Run a Pack op to pack the tensors pointed to by a packed input TensorHandle |
| // if the op is a primitive op. |
| Status MaybePackInputTensor(EagerOperation* op) { |
| if (op->is_function()) { |
| // Functions could take packed TensorHandles as inputs. |
| return Status::OK(); |
| } |
| EagerContext& ctx = op->EagerContext(); |
| for (int i = 0; i < op->Inputs().size(); ++i) { |
| TensorHandle* handle = op->Inputs()[i]; |
| if (handle->Type() == TensorHandle::PACKED) { |
| EagerOperation pack_op(&ctx); |
| TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr, |
| /*remote=*/false, /*executor=*/nullptr)); |
| pack_op.MutableAttrs()->Set("N", handle->NumPackedHandles()); |
| pack_op.MutableAttrs()->Set("T", handle->dtype); |
| for (int i = 0; i < handle->NumPackedHandles(); ++i) { |
| tensorflow::TensorHandle* h = nullptr; |
| TF_RETURN_IF_ERROR(handle->ExtractPackedHandle(i, &h)); |
| TF_RETURN_IF_ERROR(pack_op.AddInput(h)); |
| } |
| int num_retvals = 1; |
| absl::FixedArray<tensorflow::TensorHandle*> retvals(num_retvals); |
| TF_RETURN_IF_ERROR( |
| EagerLocalExecute(&pack_op, retvals.data(), &num_retvals)); |
| tensorflow::TensorHandle* ret = retvals.at(0); |
| op->UpdateInput(i, ret); |
| ret->Unref(); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| #if !defined(IS_MOBILE_PLATFORM) |
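| // Fills in the remote Operation proto from the eager op: op id, name, |
| // non-default attributes, device, and whether it is a function. |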
| void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) { |
| EagerContext& ctx = op->EagerContext(); |
| |
| remote_op->set_id(ctx.RemoteMgr()->NextOpId()); |
| remote_op->set_name(op->Name()); |
| |
| op->Attrs().FillAttrValueMapWithoutDefaults(remote_op->mutable_attrs()); |
| remote_op->set_device(absl::get<Device*>(op->Device())->name()); |
| remote_op->set_is_function(op->is_function()); |
| } |
| |
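| // For a remote VarHandleOp, records the variable's dtype and shape (taken |
| // from the op's attributes) on the output TensorHandle so they are available |
| // when the handle is used as a remote function input. |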
| Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op, |
| const DataTypeVector& output_dtypes, |
| TensorHandle** retvals) { |
| if (remote_op.name() == "VarHandleOp") { |
| if (output_dtypes.size() != 1) { |
| return errors::Internal("VarHandleOp should only have one output."); |
| } |
| if (output_dtypes[0] != DT_RESOURCE) { |
| return errors::Internal( |
| "The output of VarHandleOp should be a DT_RESOURCE."); |
| } |
| AttrSlice attr_slice = AttrSlice(&remote_op.attrs()); |
| const AttrValue* dtype; |
| TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype)); |
| const AttrValue* shape; |
| TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape)); |
| retvals[0]->SetResourceHandleDtypeAndShape( |
| {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}); |
| } |
| return Status::OK(); |
| } |
| |
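| // Executes `op` on a remote worker: copies or serializes its inputs, builds |
| // an EnqueueRequest containing the op, creates unshaped remote handles for |
| // the outputs, and enqueues a RemoteExecuteNode on the op's executor. |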
| Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, |
| int* num_retvals) { |
| EagerContext& ctx = op->EagerContext(); |
| |
| // TODO(fishx): Remove the following code when lazy tensor copy is ready. |
| if (op->Device() == kVariantDeviceNull) { |
| tensorflow::Device* device = nullptr; |
| string device_name = op->DeviceName(); |
| TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(device_name.c_str(), &device)); |
| op->SetDevice(device); |
| } |
| |
| core::RefCountPtr<eager::EagerClient> eager_client; |
| uint64 context_id = ctx.GetContextId(); |
| TF_RETURN_IF_ERROR(ctx.GetClient(op->GetDeviceParsedName(), &eager_client)); |
| string remote_task; |
| if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) { |
| return errors::InvalidArgument( |
| "Unable to find remote task corresponding to device ", |
| VariantDeviceName(op->Device())); |
| } |
| |
| std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest); |
| request->set_context_id(context_id); |
| |
| eager::Operation* remote_op = request->add_queue()->mutable_operation(); |
| |
| tensorflow::Device* op_device = absl::get<Device*>(op->Device()); |
| { |
| profiler::TraceMe activity("CopyInputToExpectedDevice", |
| profiler::TraceMeLevel::kInfo); |
| const bool eagerly_copy_function_remote_inputs = |
| !ctx.LazyCopyFunctionRemoteInputs() || !op->is_function(); |
| for (int i = 0, end = op->Inputs().size(); i < end; i++) { |
| tensorflow::TensorHandle* input = op->Inputs()[i]; |
| tensorflow::Device* input_device = absl::get<Device*>(input->device()); |
| tensorflow::Device* input_device_or_cpu = |
| absl::get<Device*>(input->DeviceOrHostCPU(ctx)); |
| const string* input_device_name = &input_device_or_cpu->name(); |
| bool serialize_resource_dtype_and_shape = false; |
| if (op_device != input_device && |
| // If the expected and actual devices are on the same task, don't |
| // explicitly copy, and instead depend on the copy to happen locally |
| // when the op is executed on the device. |
| !ctx.OnSameTask(op_device, input_device)) { |
| if (eagerly_copy_function_remote_inputs || |
| input_device_or_cpu->IsLocal()) { |
| tensorflow::Device* remote_cpu_device; |
| TF_RETURN_IF_ERROR( |
| ctx.CPUDeviceOnTask(op_device, &remote_cpu_device)); |
| // TODO(b/110044833): It's possible the same tensor gets copied to the |
| // remote device repeatedly. |
| // Always copy to the remote CPU so that the actual device can be |
| // correctly determined after the kernel is selected/instantiated, |
| // since the op might have its inputs on host memory. |
| TensorHandle* handle = op->Inputs()[i]; |
| Device* handle_device = |
| absl::get<Device*>(handle->DeviceOrHostCPU(ctx)); |
| // If the input is already on the right device, then nothing to do. |
| if (remote_cpu_device != handle_device) { |
| TF_RETURN_IF_ERROR(CopyInputToExpectedDevice( |
| &ctx, op, op_device, handle, i, handle_device, |
| remote_cpu_device, &handle)); |
| op->UpdateInput(i, handle); |
| input = handle; |
| input_device = remote_cpu_device; |
| input_device_name = &remote_cpu_device->name(); |
| // Unref handle since it has a ref as an input now |
| handle->Unref(); |
| } |
| } else { |
| serialize_resource_dtype_and_shape = |
| (input->dtype == DT_RESOURCE) && |
| (!input->HasResourceShapeMirror(op_device, |
| ctx.GetContextViewId())); |
| } |
| } |
| auto* input_handle = remote_op->add_op_inputs()->mutable_remote_handle(); |
| // For a multi-device function, a remote RunComponentFunction request is |
| // not sent through StreamingEnqueueAsync. It could arrive at a remote |
| // worker before a remote execution request which produces an input of the |
| // component function. So we wait until the remote input is ready before |
| // serializing it. |
| const bool wait_until_ready = op->is_function(); |
| TF_RETURN_IF_ERROR(ctx.RemoteMgr()->SerializeRemoteTensorHandle( |
| input, wait_until_ready, input_handle, input_device, |
| *input_device_name, serialize_resource_dtype_and_shape)); |
| if (!input_handle->resource_dtypes_and_shapes().empty()) { |
| TF_RETURN_IF_ERROR( |
| input->AddResourceShapeMirror(op_device, input_handle->op_id(), |
| input_handle->output_num(), &ctx)); |
| } |
| } |
| } |
| |
| PrepareRemoteOp(remote_op, op); |
| |
| DataTypeVector output_dtypes; |
| TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes)); |
| |
| const size_t num_outputs = output_dtypes.size(); |
| if (num_outputs != *num_retvals) { |
| return errors::InvalidArgument( |
| "num_retvals does not match expected output dtypes"); |
| } |
| *num_retvals = num_outputs; |
| |
| const tensorflow::uint64 id = remote_op->id(); |
| for (size_t i = 0; i < num_outputs; ++i) { |
| // TODO(nareshmodi): Change the callback to instead add the decref to a |
| // list of pending decrefs that we can send as a batch with the next |
| // execute. |
| |
| // The device_ and resource_device_ of this TensorHandle might be |
| // incorrect. For multi-device functions, we don't know the output device |
| // until the function is instantiated on a remote worker. Luckily, we don't |
| // need to know the correct remote device here. We just need to know that it |
| // is remote. If we need to copy this tensor to this process or run any ops |
| // which take this tensor as an input, we block until the correct device is |
| // set. |
| const bool unknown_device = op->is_function(); |
| retvals[i] = TensorHandle::CreateUnshapedRemoteHandle( |
| id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device); |
| } |
| |
| if (ctx.LazyCopyFunctionRemoteInputs()) { |
| // Store the data type and shape of a remote resource variable on the |
| // corresponding remote TensorHandle (output of 'VarHandleOp'). |
| // If the variable is an input of a remote function, the function may need |
| // the type and shape during function instantiation. When |
| // LazyCopyFunctionRemoteInputs is enabled, we no longer copy the resource |
| // handle (which contains the type and shape) of the variable to the default |
| // function device. Instead, we store the type and shape on the eager master |
| // and send them to the default function device along with the |
| // EnqueueRequest. |
| TF_RETURN_IF_ERROR( |
| StoreResourceDtypesAndShapes(*remote_op, output_dtypes, retvals)); |
| } |
| |
| auto& executor = op->Executor(); |
| DVLOG(4) << "Execute remote eager op: " << op->Name() |
| << " (is async?: " << executor.Async() << ")."; |
| |
| std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode( |
| &op->EagerContext(), std::move(request), op_device, |
| ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(), |
| op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(), |
| op->Inputs(), {retvals, num_outputs})); |
| |
| if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) { |
| string msg = strings::StrCat( |
| "Executing op ", op->Name(), " on task ", |
| DeviceNameUtils::ParsedNameToString(op->GetDeviceParsedName())); |
| if (!logging::LogToListeners(msg)) { |
| LOG(INFO) << msg; |
| } |
| } |
| |
| Status s = executor.AddOrExecute(std::move(node)); |
| // If the operation failed, we need to Unref any outputs that were |
| // allocated. |
| if (!s.ok()) { |
| for (size_t i = 0; i < num_outputs; ++i) { |
| retvals[i]->Unref(); |
| } |
| } |
| |
| return s; |
| } |
| #endif // IS_MOBILE_PLATFORM |
| |
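| // Wraps the kernel's outputs in TensorHandles. If retvals[i] was not |
| // pre-created, a new local or remote handle is created; otherwise the |
| // existing handle's tensor (or remote shape) is set after checking that its |
| // devices match the kernel's output devices. |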
| Status GetKernelOutputs( |
| std::vector<EagerKernelRet>* outputs, int num_outputs, |
| TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel, |
| const absl::optional<EagerRemoteFunctionParams>& remote_func_params) { |
| for (int i = 0, end = num_outputs; i < end; ++i) { |
| if (retvals[i] == nullptr) { |
| EagerKernelRet& ret = (*outputs)[i]; |
| Device* output_device = ctx->CanonicalDevice(kernel->OutputDevice(i)); |
| if (ret.index() == 0) { |
| retvals[i] = TensorHandle::CreateLocalHandle( |
| std::move(absl::get<Tensor>(ret)), |
| /* d= */ output_device, |
| /* op_device= */ kernel->device(), |
| /* resource_device= */ kernel->OutputResourceDevice(i), ctx); |
| } else { |
| const DataTypeVector& output_dtypes = kernel->output_dtypes(); |
| TF_RETURN_IF_ERROR( |
| CreateUnshapedOutput(*kernel, i, output_device, output_dtypes[i], |
| remote_func_params, ctx, &retvals[i])); |
| #if !defined(IS_MOBILE_PLATFORM) |
| TF_RETURN_IF_ERROR( |
| retvals[i]->SetRemoteShape(absl::get<TensorShape>(ret), |
| output_device, ctx->GetContextViewId())); |
| #endif // IS_MOBILE_PLATFORM |
| } |
| } else { |
| if (!kernel->IsFunction() && |
| TF_PREDICT_FALSE(kernel->device() != retvals[i]->op_device())) { |
| return errors::Internal( |
| "Kernel output tensor handle has a different op device than the " |
| "kernel. This should never happen."); |
| } |
| if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) != |
| absl::get<Device*>(retvals[i]->device()))) { |
| return errors::Internal( |
| "Kernel output tensor handle locates on a different device than " |
| "the specified kernel output device. This should never happen."); |
| } |
| |
| EagerKernelRet& ret = (*outputs)[i]; |
| if (ret.index() == 0) { |
| TF_RETURN_IF_ERROR(retvals[i]->SetTensor( |
| std::move(absl::get<Tensor>(ret)), |
| ctx->CanonicalDevice(kernel->OutputDevice(i)))); |
| } else { |
| #if defined(IS_MOBILE_PLATFORM) |
| return errors::Unimplemented( |
| "Remote outputs are not available on mobile devices."); |
| #else // !IS_MOBILE_PLATFORM |
| TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape( |
| absl::get<TensorShape>(ret), |
| absl::get<Device*>(retvals[i]->device()), ctx->GetContextViewId())); |
| #endif // !IS_MOBILE_PLATFORM |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
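| // Copies the graphs accumulated by the context's GraphCollector into the |
| // context's RunMetadata proto and clears the collector. |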
| void CollectGraphs(EagerContext* ctx) { |
| mutex_lock ml(*ctx->MetadataMu()); |
| |
| GraphCollector* collector = ctx->GetGraphCollector(); |
| mutex_lock mll(collector->mu); |
| |
| // Adding to partition graphs for backward compatibility. |
| for (const auto& graph : collector->partitioned_graphs) { |
| *ctx->RunMetadataProto()->add_partition_graphs() = graph; |
| } |
| |
| if (collector->dirty) { |
| auto* function_graphs = ctx->RunMetadataProto()->add_function_graphs(); |
| *function_graphs->mutable_post_optimization_graph() = |
| collector->optimized_graph; |
| *function_graphs->mutable_pre_optimization_graph() = collector->raw_graph; |
| for (const auto& graph : collector->partitioned_graphs) { |
| *function_graphs->add_partition_graphs() = graph; |
| } |
| } |
| |
| collector->ClearGraphs(); |
| } |
| } // namespace |
| |
| Status EagerExecute(EagerOperation* op, TensorHandle** retvals, |
| int* num_retvals) { |
| profiler::TraceMe activity( |
| [&] { return absl::StrCat("EagerExecute: ", op->Name()); }, |
| profiler::TraceMeLevel::kInfo); |
| |
| if (!op->Executor().Async()) { |
| // In sync mode, always clear error to maintain the same behavior as before. |
| // TODO(b/141004939): Remove this. |
| op->Executor().ClearError(); |
| } |
| |
| std::unique_ptr<tensorflow::EagerOperation> out_op; |
| TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite( |
| EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op)); |
| |
| if (op->IsLocal()) { |
| if (out_op) { |
| op = out_op.get(); |
| } |
| TF_RETURN_IF_ERROR(MaybePackInputTensor(op)); |
| return EagerLocalExecute(op, retvals, num_retvals); |
| } |
| |
| #if defined(IS_MOBILE_PLATFORM) |
| return errors::Unimplemented( |
| "Eager's remote execution is not available on mobile devices."); |
| #else // !IS_MOBILE_PLATFORM |
| if (out_op) { |
| op = out_op.get(); |
| } |
| return EagerRemoteExecute(op, retvals, num_retvals); |
| #endif // !IS_MOBILE_PLATFORM |
| } |
| |
| // TODO(gjn): Consider moving into ExecuteNode class |
| Status EagerKernelExecute( |
| EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs, |
| const absl::optional<EagerRemoteFunctionParams>& remote_func_params, |
| const core::RefCountPtr<KernelAndDevice>& kernel, |
| GraphCollector* graph_collector, CancellationManager* cancellation_manager, |
| absl::Span<TensorHandle*> retvals) { |
| profiler::TraceMe activity("EagerKernelExecute", |
| profiler::TraceMeLevel::kInfo); |
| std::vector<EagerKernelRet> outputs(1); |
| |
| ExecuteNodeArgs inputs(op_inputs.size()); |
| TF_RETURN_IF_ERROR(inputs.Init(ctx, op_inputs, kernel)); |
| // TODO(apassos) figure out how to record stats for ops which are a part of |
| // functions. |
| // TODO(b/111859745): When we support recovering from kernel/device errors, we |
| // would need to call XlaDevice::EnsureDeviceContextOk() before using an XLA |
| // device. We don't call it now because it is an unneeded overhead (it |
| // acquires a lock) and we can't recover from errors anyway. |
| ScopedStepContainer* container = ctx->StepContainer(); |
| TF_RETURN_IF_ERROR(kernel->Run(container, inputs, &outputs, |
| cancellation_manager, remote_func_params)); |
| if (graph_collector != nullptr) { |
| CollectGraphs(ctx); |
| } |
| |
| if (TF_PREDICT_FALSE(retvals.size() != outputs.size())) { |
| return errors::Internal( |
| "EagerKernelExecute returns a list of ", outputs.size(), |
| " tensors but ", retvals.size(), |
| " is expected. This should never " |
| "happen. Please file a bug with the TensorFlow team."); |
| } |
| return GetKernelOutputs(&outputs, retvals.size(), retvals.data(), ctx, |
| kernel.get(), remote_func_params); |
| } |
| |
| namespace { |
| |
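| // Copies `h` to the local device `dstd`, or adds a local mirror on `dstd` to |
| // `h` when `mirror` is true, by running a CopyToDeviceNode either |
| // synchronously or through the executor. |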
| Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, |
| EagerExecutor* executor, Device* dstd, |
| bool mirror, TensorHandle** result) { |
| TF_RETURN_IF_ERROR(executor->status()); |
| Device* d = ctx->CanonicalDevice(dstd); |
| if (mirror && h->HasLocalMirror(d)) { |
| h->Ref(); |
| *result = h; |
| return Status::OK(); |
| } |
| |
| bool async = executor->Async(); |
| if (mirror) { |
| h->Ref(); |
| *result = h; |
| |
| if (h->HasLocalMirror(d)) { |
| return Status::OK(); |
| } |
| |
| // We don't bother adding an empty local mirror in sync mode since we'll be |
| // executing the operation directly and be calling AddLocalMirror. A |
| // reference count is still needed which will be removed if the operation |
| // fails. |
| if (async) { |
| Status s = h->AddEmptyLocalMirror(d); |
| if (!s.ok()) { |
| // If a mirror was added since we called HasLocalMirror then just return |
| // since another thread has already added the mirror. |
| if (s.code() == error::Code::ALREADY_EXISTS) { |
| return Status::OK(); |
| } |
| |
| // Remove the previously added reference count since adding the mirror |
| // failed. |
| h->Unref(); |
| *result = nullptr; |
| return s; |
| } |
| } |
| } else { |
| *result = TensorHandle::CreateEmptyLocalHandle( |
| d, dstd, h->resource_device(), h->dtype, ctx); |
| } |
| |
| Status s; |
| if (async) { |
| // Note that `h` may not be currently ready. However execution order will |
| // make sure that `h` is ready before the copy is actually done. |
| std::unique_ptr<EagerNode> node( |
| new CopyToDeviceNode(h, *result, d, *ctx, async, mirror)); |
| s = executor->AddOrExecute(std::move(node)); |
| } else { |
| CopyToDeviceNode node(h, *result, d, *ctx, async, mirror); |
| s = executor->SyncExecute(&node); |
| } |
| |
| // If the operation failed, we need to Unref any outputs that were |
| // allocated. |
| if (!s.ok()) { |
| (*result)->Unref(); |
| } |
| |
| return s; |
| } |
| |
| } // namespace |
| |
| Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, |
| EagerExecutor* executor, Device* device, bool mirror, |
| TensorHandle** result) { |
| TF_RETURN_IF_ERROR(h->WaitUnknownDevice()); |
| auto send_device = h->DeviceOrHostCPU(*ctx); |
| if (VariantDeviceIsCustom(send_device)) { |
| return errors::Unimplemented( |
| "Copying a TensorHandle from a custom device is not supported."); |
| } |
| bool sender_is_local = absl::get<Device*>(send_device)->IsLocal(); |
| |
| bool receiver_is_local = device->IsLocal(); |
| |
| if (!executor->Async()) { |
| // In sync mode, always clear error to maintain the same behavior as before. |
| // TODO(b/141004939): Remove this. |
| executor->ClearError(); |
| } |
| |
| if (sender_is_local && receiver_is_local) { |
| return LocalEagerCopyToDevice(h, ctx, executor, device, mirror, result); |
| } else { |
| #if defined(IS_MOBILE_PLATFORM) |
| return errors::Unimplemented( |
| "Eager's remote execution is not available on mobile devices."); |
| #else // !IS_MOBILE_PLATFORM |
| uint64 recv_op_id = 0; |
| if (receiver_is_local) { |
| Device* d = ctx->CanonicalDevice(device); |
| // TODO(gjn): Need to add support for async execution. Note if receiver |
| // is local, we need to first add support in TensorHandle to wait on local |
| // mirrors. |
| if (mirror) { |
| h->Ref(); |
| *result = h; |
| |
| if (h->HasLocalMirror(d)) { |
| return Status::OK(); |
| } |
| |
| Status s = h->AddEmptyLocalMirror(d); |
| if (!s.ok()) { |
| // If a mirror was added since we called HasLocalMirror then just |
| // return since another thread has already added the mirror. |
| if (s.code() == error::Code::ALREADY_EXISTS) { |
| return Status::OK(); |
| } |
| |
| // Remove the previously added reference count since adding the mirror |
| // failed. |
| h->Unref(); |
| *result = nullptr; |
| return s; |
| } |
| } else { |
| *result = TensorHandle::CreateEmptyLocalHandle( |
| /* d= */ d, /* op_device= */ device, |
| /*resource_device=*/nullptr, h->dtype, ctx); |
| } |
| } else { |
| if (mirror) { |
| if (h->HasRemoteMirror(device, ctx->GetContextViewId())) { |
| h->Ref(); |
| *result = h; |
| return Status::OK(); |
| } |
| } |
| string remote_task; |
| if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) { |
| return errors::InvalidArgument( |
| "Unable to find remote task corresponding to device ", |
| device->name()); |
| } |
| recv_op_id = ctx->RemoteMgr()->NextOpId(); |
| if (mirror) { |
| TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0, |
| remote_task, ctx)); |
| h->Ref(); |
| *result = h; |
| } else { |
| *result = TensorHandle::CreateUnshapedRemoteHandle( |
| recv_op_id, 0, remote_task, h->dtype, device, ctx); |
| } |
| } |
| |
| auto node = std::make_unique<eager::RemoteCopyNode>( |
| ctx, executor, h, result[0], device, recv_op_id); |
| Status s = executor->AddOrExecute(std::move(node)); |
| if (!s.ok()) { |
| result[0]->Unref(); |
| } |
| return s; |
| #endif // !IS_MOBILE_PLATFORM |
| } |
| } |
| |
| namespace { |
| // Low-level utility function to execute the kernel specified by `kernel` on |
| // `kernel->device()`, with the provided inputs as `op_inputs` in the 'ctx'. |
| // Unlike `EagerKernelExecute`, which ties up the thread until the |
| // underlying function finishes executing, this function does not block the |
| // thread and can return before the function execution finishes. The provided |
| // `StatusCallback` will be triggered after function execution with its status. |
| void EagerKernelExecuteAsync( |
| EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& op_inputs, |
| const absl::optional<EagerRemoteFunctionParams>& remote_func_params, |
| const core::RefCountPtr<KernelAndDevice> kernel, |
| GraphCollector* graph_collector, CancellationManager* cancellation_manager, |
| TensorHandle** retvals, int num_outputs, StatusCallback done) { |
| auto inputs = std::make_shared<ExecuteNodeArgs>(op_inputs.size()); |
| auto outputs = std::make_shared<std::vector<EagerKernelRet>>(1); |
| |
| Status s = inputs->Init(ctx, op_inputs, kernel); |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| |
| kernel->Ref(); // Ownership of reference is transferred to the callback |
| kernel->RunAsync( |
| ctx->StepContainer(), *inputs, outputs.get(), cancellation_manager, |
| remote_func_params, |
| [retvals, inputs, outputs, num_outputs, ctx, graph_collector, |
| remote_func_params, kernel_raw = kernel.get(), |
| done = std::move(done)](const Status& s) { |
| auto wrapped_done = [&](const Status& s) { |
| kernel_raw->Unref(); |
| done(s); |
| }; |
| if (!s.ok()) { |
| wrapped_done(s); |
| return; |
| } |
| if (graph_collector != nullptr) { |
| CollectGraphs(ctx); |
| } |
| DCHECK_EQ(num_outputs, outputs->size()); |
| wrapped_done(GetKernelOutputs(outputs.get(), num_outputs, retvals, ctx, |
| kernel_raw, remote_func_params)); |
| }); |
| } |
| } // namespace |
| |
| // Low-level utility to run the eager operation on local devices. Unlike |
| // `EagerLocalExecute`, which blocks and waits for the op execution to finish, |
| // this method does not block the thread and can return before the eager |
| // operation execution finishes. The provided `StatusCallback` will be |
| // triggered after execution with its status. |
| void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, |
| int* num_retvals, StatusCallback done) { |
| if (VariantDeviceIsCustom(op->Device())) { |
| done(errors::Unimplemented( |
| "Custom device is not supported in EagerLocalExecuteAsync.")); |
| return; |
| } |
| if (!op->IsLocal()) { |
| done(errors::InvalidArgument( |
| "Remote execution is not supported in async EagerLocalExecuteAsync")); |
| return; |
| } |
| |
| ScopedMemoryDebugAnnotation op_annotation( |
| op->op_name(), op->remote_func_params().has_value() |
| ? op->remote_func_params().value().step_id.value_or(0) |
| : 0); |
| profiler::TraceMe activity( |
| [&] { return absl::StrCat("EagerLocalExecuteAsync: ", op->Name()); }, |
| profiler::TraceMeLevel::kInfo); |
| EagerContext& ctx = op->EagerContext(); |
| |
| core::RefCountPtr<KernelAndDevice> kernel; |
| Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| |
| int num_outputs = kernel->num_outputs(); |
| s = ValidateInputTypeAndPlacement(&ctx, op, kernel); |
| if (!s.ok()) { |
| done(s); |
| return; |
| } |
| |
| if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) { |
| string msg = strings::StrCat("Executing op ", op->Name(), " in device ", |
| kernel->device()->name()); |
| if (!logging::LogToListeners(msg)) { |
| LOG(INFO) << msg; |
| } |
| } |
| |
| GraphCollector* graph_collector = nullptr; |
| if (ctx.ShouldStoreGraphs()) { |
| graph_collector = ctx.GetGraphCollector(); |
| } |
| |
| for (int i = 0, end = num_outputs; i < end; ++i) { |
| const DataTypeVector& output_dtypes = kernel->output_dtypes(); |
| retvals[i] = TensorHandle::CreateEmptyLocalHandle( |
| /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)), |
| /* op_device= */ kernel->device(), |
| /* resource_device= */ kernel->OutputResourceDevice(i), |
| output_dtypes[i], &ctx); |
| } |
| |
| EagerKernelExecuteAsync( |
| &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel), |
| graph_collector, op->GetCancellationManager(), retvals, num_outputs, |
| [op, num_outputs, retvals, done = std::move(done)](const Status& s) { |
| op->Clear(); |
| // If the operation failed, we need to Unref any outputs that were |
| // allocated. |
| if (!s.ok()) { |
| for (int i = 0, end = num_outputs; i < end; ++i) { |
| if (retvals[i] != nullptr) { |
| retvals[i]->Unref(); |
| } |
| } |
| } |
| done(s); |
| }); |
| } |
| } // namespace tensorflow |