| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/common_runtime/direct_session.h" |
| |
| #include <atomic> |
| #include <string> |
| #include <vector> |
| |
| #include "tensorflow/core/common_runtime/collective_executor_mgr.h" |
| #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" |
| #include "tensorflow/core/common_runtime/constant_folding.h" |
| #include "tensorflow/core/common_runtime/debugger_state_interface.h" |
| #include "tensorflow/core/common_runtime/device_factory.h" |
| #include "tensorflow/core/common_runtime/device_resolver_local.h" |
| #include "tensorflow/core/common_runtime/executor.h" |
| #include "tensorflow/core/common_runtime/executor_factory.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/common_runtime/graph_optimizer.h" |
| #include "tensorflow/core/common_runtime/memory_types.h" |
| #include "tensorflow/core/common_runtime/optimization_registry.h" |
| #include "tensorflow/core/common_runtime/process_util.h" |
| #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" |
| #include "tensorflow/core/common_runtime/step_stats_collector.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/graph.pb_text.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/framework/graph_def_util.h" |
| #include "tensorflow/core/framework/log_memory.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/versions.pb.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/graph/graph_partition.h" |
| #include "tensorflow/core/graph/subgraph.h" |
| #include "tensorflow/core/graph/tensor_id.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/core/notification.h" |
| #include "tensorflow/core/lib/core/refcount.h" |
| #include "tensorflow/core/lib/core/status.h" |
| #include "tensorflow/core/lib/core/threadpool.h" |
| #include "tensorflow/core/lib/gtl/array_slice.h" |
| #include "tensorflow/core/lib/gtl/stl_util.h" |
| #include "tensorflow/core/lib/monitoring/counter.h" |
| #include "tensorflow/core/lib/strings/numbers.h" |
| #include "tensorflow/core/lib/strings/str_util.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/platform/byte_order.h" |
| #include "tensorflow/core/platform/device_tracer.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/util/device_name_utils.h" |
| #include "tensorflow/core/util/env_var.h" |
| |
| namespace tensorflow { |
| |
| namespace { |
| |
| auto* direct_session_runs = monitoring::Counter<0>::New( |
| "/tensorflow/core/direct_session_runs", |
| "The number of times DirectSession::Run() has been called."); |
| |
| Status NewThreadPoolFromThreadPoolOptions( |
| const SessionOptions& options, |
| const ThreadPoolOptionProto& thread_pool_options, int pool_number, |
| thread::ThreadPool** pool, bool* owned) { |
| int32 num_threads = thread_pool_options.num_threads(); |
| if (num_threads == 0) { |
| num_threads = NumInterOpThreadsFromSessionOptions(options); |
| } |
| const string& name = thread_pool_options.global_name(); |
| if (name.empty()) { |
| // Session-local threadpool. |
| VLOG(1) << "Direct session inter op parallelism threads for pool " |
| << pool_number << ": " << num_threads; |
| *pool = new thread::ThreadPool( |
| options.env, strings::StrCat("Compute", pool_number), num_threads); |
| *owned = true; |
| return Status::OK(); |
| } |
| |
| // Global, named threadpool. |
| typedef std::pair<int32, thread::ThreadPool*> MapValue; |
| static std::map<string, MapValue>* global_pool_map = |
| new std::map<string, MapValue>; |
| static mutex* mu = new mutex(); |
| mutex_lock l(*mu); |
| MapValue* mvalue = &(*global_pool_map)[name]; |
| if (mvalue->second == nullptr) { |
| mvalue->first = thread_pool_options.num_threads(); |
| mvalue->second = new thread::ThreadPool( |
| options.env, strings::StrCat("Compute", pool_number), num_threads); |
| } else { |
| if (mvalue->first != thread_pool_options.num_threads()) { |
| return errors::InvalidArgument( |
| "Pool ", name, |
| " configured previously with num_threads=", mvalue->first, |
| "; cannot re-configure with num_threads=", |
| thread_pool_options.num_threads()); |
| } |
| } |
| *owned = false; |
| *pool = mvalue->second; |
| return Status::OK(); |
| } |
| |
| thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) { |
| static thread::ThreadPool* const thread_pool = |
| NewThreadPoolFromSessionOptions(options); |
| return thread_pool; |
| } |
| |
| // TODO(vrv): Figure out how to unify the many different functions |
| // that generate RendezvousKey, since many of them have to be |
| // consistent with each other. |
| string GetRendezvousKey(const string& tensor_name, |
| const DeviceAttributes& device_info, |
| const FrameAndIter& frame_iter) { |
| return strings::StrCat(device_info.name(), ";", |
| strings::FpToString(device_info.incarnation()), ";", |
| device_info.name(), ";", tensor_name, ";", |
| frame_iter.frame_id, ":", frame_iter.iter_id); |
| } |
| |
| } // namespace |
| |
| class DirectSessionFactory : public SessionFactory { |
| public: |
| DirectSessionFactory() {} |
| |
| bool AcceptsOptions(const SessionOptions& options) override { |
| return options.target.empty(); |
| } |
| |
| Status NewSession(const SessionOptions& options, |
| Session** out_session) override { |
| // Must do this before the CPU allocator is created. |
| if (options.config.graph_options().build_cost_model() > 0) { |
| EnableCPUAllocatorFullStats(true); |
| } |
| std::vector<Device*> devices; |
| TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( |
| options, "/job:localhost/replica:0/task:0", &devices)); |
| |
| DirectSession* session = |
| new DirectSession(options, new DeviceMgr(devices), this); |
| { |
| mutex_lock l(sessions_lock_); |
| sessions_.push_back(session); |
| } |
| *out_session = session; |
| return Status::OK(); |
| } |
| |
| Status Reset(const SessionOptions& options, |
| const std::vector<string>& containers) override { |
| std::vector<DirectSession*> sessions_to_reset; |
| { |
| mutex_lock l(sessions_lock_); |
// We create a copy to ensure that we don't have a deadlock when
// session->Close calls DirectSessionFactory::Deregister, which
// acquires sessions_lock_.
| std::swap(sessions_to_reset, sessions_); |
| } |
| Status s; |
| for (auto session : sessions_to_reset) { |
| s.Update(session->Reset(containers)); |
| } |
| // TODO(suharshs): Change the Reset behavior of all SessionFactories so that |
| // it doesn't close the sessions? |
| for (auto session : sessions_to_reset) { |
| s.Update(session->Close()); |
| } |
| return s; |
| } |
| |
| void Deregister(const DirectSession* session) { |
| mutex_lock l(sessions_lock_); |
| sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session), |
| sessions_.end()); |
| } |
| |
| private: |
| mutex sessions_lock_; |
| std::vector<DirectSession*> sessions_ GUARDED_BY(sessions_lock_); |
| }; |
| |
| class DirectSessionRegistrar { |
| public: |
| DirectSessionRegistrar() { |
| SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory()); |
| } |
| }; |
| static DirectSessionRegistrar registrar; |
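// For reference, client code typically reaches DirectSessionFactory through
// the generic Session API; a minimal sketch (an empty `target` selects this
// factory, per AcceptsOptions() above):
//
//   SessionOptions options;  // options.target is empty by default
//   std::unique_ptr<Session> session(NewSession(options));
//   GraphDef graph_def;      // populated by the caller
//   TF_CHECK_OK(session->Create(graph_def));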
| |
| std::atomic_int_fast64_t DirectSession::step_id_counter_(1); |
| |
| // NOTE: On Android with a single device, there is never |
| // a risk of an OpKernel blocking indefinitely: |
| // |
| // 1) No operations do I/O that depends on other simultaneous kernels, |
| // |
| // 2) Recv nodes always complete immediately: The inputs are sent into |
| // the local rendezvous before we start the executor, so the |
| // corresponding recvs will not block. |
| // |
| // Based on these assumptions, we can use the same thread pool for |
| // both "non-blocking" and "blocking" OpKernels on Android. |
| // |
| // This may change down the road when we add support for multiple |
| // devices that run concurrently, in which case we will need to |
| // revisit this decision. |
| void DirectSession::SchedClosure(thread::ThreadPool* pool, |
| std::function<void()> c) { |
| // TODO(sanjay): Get rid of __ANDROID__ path |
| #ifdef __ANDROID__ |
| // On Android, there is no implementation of ThreadPool that takes |
| // std::function, only Closure, which we cannot easily convert. |
| // |
| // Instead, we just run the function in-line, which is currently |
| // safe given the reasoning above. |
| c(); |
| #else |
| if (pool != nullptr) { |
| pool->Schedule(std::move(c)); |
| } else { |
| c(); |
| } |
| #endif // __ANDROID__ |
| } |
| |
| DirectSession::DirectSession(const SessionOptions& options, |
| const DeviceMgr* device_mgr, |
| DirectSessionFactory* const factory) |
| : options_(options), |
| device_mgr_(device_mgr), |
| factory_(factory), |
| cancellation_manager_(new CancellationManager()), |
| operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) { |
| const int thread_pool_size = |
| options_.config.session_inter_op_thread_pool_size(); |
| if (thread_pool_size > 0) { |
| for (int i = 0; i < thread_pool_size; ++i) { |
| thread::ThreadPool* pool = nullptr; |
| bool owned = false; |
| init_error_.Update(NewThreadPoolFromThreadPoolOptions( |
| options_, options_.config.session_inter_op_thread_pool(i), i, &pool, |
| &owned)); |
| thread_pools_.emplace_back(pool, owned); |
| } |
| } else if (options_.config.use_per_session_threads()) { |
| thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_), |
| true /* owned */); |
| } else { |
| thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */); |
| } |
| // The default value of sync_on_finish will be flipped soon and this |
| // environment variable will be removed as well. |
| const Status status = |
| ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_); |
| if (!status.ok()) { |
| LOG(ERROR) << status.error_message(); |
| } |
| // NOTE(mrry): We do not need to use a unique string for the session |
| // handle, because DirectSession owns its devices. This may change |
| // in future versions. |
| session_handle_ = "direct"; |
| int devices_added = 0; |
| if (options.config.log_device_placement()) { |
| const string mapping_str = device_mgr_->DeviceMappingString(); |
| if (mapping_str.empty()) { |
| printf("Device mapping: no known devices.\n"); |
| } else { |
| printf("Device mapping:\n%s", mapping_str.c_str()); |
| } |
| LOG(INFO) << "Device mapping:\n" << mapping_str; |
| } |
| for (auto d : device_mgr_->ListDevices()) { |
| devices_.push_back(d); |
| device_set_.AddDevice(d); |
| d->op_segment()->AddHold(session_handle_); |
| |
| // The first device added is special: it is the 'client device' (a |
| // CPU device) from which we feed and fetch Tensors. |
| if (devices_added == 0) { |
| device_set_.set_client_device(d); |
| } |
| ++devices_added; |
| } |
| } |
| |
| DirectSession::~DirectSession() { |
| if (!closed_) Close().IgnoreError(); |
| for (auto& it : partial_runs_) { |
| it.second.reset(nullptr); |
| } |
| for (auto& it : executors_) { |
| it.second.reset(); |
| } |
| callables_.clear(); |
| for (auto d : device_mgr_->ListDevices()) { |
| d->op_segment()->RemoveHold(session_handle_); |
| } |
| for (auto d : device_mgr_->ListDevices()) { |
| d->ClearResourceMgr(); |
| } |
| functions_.clear(); |
| delete cancellation_manager_; |
| for (const auto& p_and_owned : thread_pools_) { |
| if (p_and_owned.second) delete p_and_owned.first; |
| } |
| |
| execution_state_.reset(nullptr); |
| flib_def_.reset(nullptr); |
| } |
| |
| Status DirectSession::MaybeInitializeExecutionState( |
| const GraphDef& graph, bool* out_already_initialized) { |
| // If already initialized, do nothing. |
| if (flib_def_ && execution_state_) { |
| *out_already_initialized = true; |
| return Status::OK(); |
| } |
| // Set up the per-session execution state. |
| // NOTE(mrry): The function library created here will be used for |
| // all subsequent extensions of the graph. |
| flib_def_.reset( |
| new FunctionLibraryDefinition(OpRegistry::Global(), graph.library())); |
| GraphExecutionStateOptions options; |
| options.device_set = &device_set_; |
| options.session_options = &options_; |
| // TODO(mrry,suharshs): We explicitly copy `graph` so that |
| // `MakeForBaseGraph()` can take ownership of its |
| // contents. Previously this happened implicitly in calls to the |
| // `GraphExecutionState`. Other sessions call |
| // `MakeForBaseGraph` in such a way that we can destructively read |
| // the passed-in `GraphDef`. In principle we could do the same here, |
| // with a wider refactoring; we might revise the direct session so |
| // that it copies the graph fewer times. |
| GraphDef temp(graph); |
| TF_RETURN_IF_ERROR( |
| GraphExecutionState::MakeForBaseGraph(&temp, options, &execution_state_)); |
| graph_created_ = true; |
| *out_already_initialized = false; |
| return Status::OK(); |
| } |
| |
| Status DirectSession::Create(const GraphDef& graph) { |
| TF_RETURN_IF_ERROR(init_error_); |
| if (graph.node_size() > 0) { |
| mutex_lock l(graph_def_lock_); |
| if (graph_created_) { |
| return errors::AlreadyExists( |
| "A Graph has already been created for this session."); |
| } |
| return ExtendLocked(graph); |
| } |
| return Status::OK(); |
| } |
| |
| Status DirectSession::Extend(const GraphDef& graph) { |
| TF_RETURN_IF_ERROR(CheckNotClosed()); |
| mutex_lock l(graph_def_lock_); |
| return ExtendLocked(graph); |
| } |
| |
| Status DirectSession::ExtendLocked(const GraphDef& graph) { |
| bool already_initialized; |
| // If this is the first call, we can initialize the execution state |
| // with `graph` and do not need to call `Extend()`. |
| TF_RETURN_IF_ERROR( |
| MaybeInitializeExecutionState(graph, &already_initialized)); |
| if (already_initialized) { |
| TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); |
| std::unique_ptr<GraphExecutionState> state; |
| TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); |
| execution_state_.swap(state); |
| } |
| return Status::OK(); |
| } |
| |
| Status DirectSession::Run(const NamedTensorList& inputs, |
| const std::vector<string>& output_names, |
| const std::vector<string>& target_nodes, |
| std::vector<Tensor>* outputs) { |
| RunMetadata run_metadata; |
| return Run(RunOptions(), inputs, output_names, target_nodes, outputs, |
| &run_metadata); |
| } |
| |
| Status DirectSession::CreateDebuggerState( |
| const CallableOptions& callable_options, int64 global_step, |
| int64 session_run_index, int64 executor_step_index, |
| std::unique_ptr<DebuggerStateInterface>* debugger_state) { |
| TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState( |
| callable_options.run_options().debug_options(), debugger_state)); |
| std::vector<string> input_names(callable_options.feed().begin(), |
| callable_options.feed().end()); |
| std::vector<string> output_names(callable_options.fetch().begin(), |
| callable_options.fetch().end()); |
| std::vector<string> target_names(callable_options.target().begin(), |
| callable_options.target().end()); |
| |
| TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( |
| global_step, session_run_index, executor_step_index, input_names, |
| output_names, target_names)); |
| return Status::OK(); |
| } |
| |
| Status DirectSession::DecorateAndPublishGraphForDebug( |
| const DebugOptions& debug_options, Graph* graph, Device* device) { |
| std::unique_ptr<DebugGraphDecoratorInterface> decorator; |
| TF_RETURN_IF_ERROR( |
| DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); |
| |
| TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); |
| TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name())); |
| return Status::OK(); |
| } |
| |
| Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, |
| CallFrameInterface* call_frame, |
| ExecutorsAndKeys* executors_and_keys, |
| RunMetadata* run_metadata) { |
| const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1); |
| |
| std::unique_ptr<DebuggerStateInterface> debugger_state; |
| if (!run_options.debug_options().debug_tensor_watch_opts().empty()) { |
| TF_RETURN_IF_ERROR( |
| CreateDebuggerState(executors_and_keys->callable_options, |
| run_options.debug_options().global_step(), step_id, |
| executor_step_count, &debugger_state)); |
| } |
| |
| // Create a run state and start execution. |
| RunState run_state(step_id, &devices_); |
| run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); |
| #ifndef __ANDROID__ |
| // Set up for collectives if ExecutorsAndKeys declares a key. |
| if (executors_and_keys->collective_graph_key != |
| BuildGraphOptions::kNoCollectiveGraphKey) { |
| if (run_options.experimental().collective_graph_key() != |
| BuildGraphOptions::kNoCollectiveGraphKey) { |
| // If a collective_graph_key was specified in run_options, ensure that it |
| // matches what came out of GraphExecutionState::BuildGraph(). |
| if (run_options.experimental().collective_graph_key() != |
| executors_and_keys->collective_graph_key) { |
| return errors::Internal( |
| "collective_graph_key in RunOptions ", |
| run_options.experimental().collective_graph_key(), |
| " should match collective_graph_key from optimized graph ", |
| executors_and_keys->collective_graph_key); |
| } |
| } |
| if (!collective_executor_mgr_) { |
| std::unique_ptr<DeviceResolverInterface> drl( |
| new DeviceResolverLocal(device_mgr_.get())); |
| std::unique_ptr<ParamResolverInterface> cprl( |
| new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(), |
| "/job:localhost/replica:0/task:0")); |
| collective_executor_mgr_.reset(new CollectiveExecutorMgr( |
| options_.config, device_mgr_.get(), std::move(drl), std::move(cprl))); |
| } |
| run_state.collective_executor.reset(new CollectiveExecutor::Handle( |
| collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/)); |
| } |
| #endif |
| |
| // Start parallel Executors. |
| const size_t num_executors = executors_and_keys->items.size(); |
| ExecutorBarrier* barrier = new ExecutorBarrier( |
| num_executors, run_state.rendez, [&run_state](const Status& ret) { |
| { |
| mutex_lock l(run_state.mu_); |
| run_state.status.Update(ret); |
| } |
| run_state.executors_done.Notify(); |
| }); |
| |
| Executor::Args args; |
| args.step_id = step_id; |
| args.call_frame = call_frame; |
| args.rendezvous = run_state.rendez; |
| args.collective_executor = |
| (run_state.collective_executor ? run_state.collective_executor->get() |
| : nullptr); |
| CancellationManager step_cancellation_manager; |
| args.cancellation_manager = &step_cancellation_manager; |
| args.session_state = &session_state_; |
| args.tensor_store = &run_state.tensor_store; |
| args.step_container = &run_state.step_container; |
| args.sync_on_finish = sync_on_finish_; |
| |
| const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE); |
| |
| bool update_cost_model = false; |
| if (options_.config.graph_options().build_cost_model() > 0) { |
| const int64 build_cost_model_every = |
| options_.config.graph_options().build_cost_model(); |
| const int64 build_cost_model_after = |
| options_.config.graph_options().build_cost_model_after(); |
| int64 measure_step_count = executor_step_count - build_cost_model_after; |
| if (measure_step_count >= 0) { |
| update_cost_model = |
| ((measure_step_count + 1) % build_cost_model_every == 0); |
| } |
| } |
| if (do_trace || update_cost_model || |
| run_options.report_tensor_allocations_upon_oom()) { |
| run_state.collector.reset( |
| new StepStatsCollector(run_metadata->mutable_step_stats())); |
| args.stats_collector = run_state.collector.get(); |
| } |
| |
| std::unique_ptr<DeviceTracer> tracer; |
| if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) { |
| tracer = CreateDeviceTracer(); |
// tracer may be nullptr on platforms without accelerators.
| if (tracer) { |
| Status s = tracer->Start(); |
| if (!s.ok()) { |
| run_state.executors_done.Notify(); |
| delete barrier; |
| return s; |
| } |
| } |
| } |
| |
| if (run_options.inter_op_thread_pool() < -1 || |
| run_options.inter_op_thread_pool() >= |
| static_cast<int32>(thread_pools_.size())) { |
| run_state.executors_done.Notify(); |
| delete barrier; |
| return errors::InvalidArgument("Invalid inter_op_thread_pool: ", |
| run_options.inter_op_thread_pool()); |
| } |
| |
| // Register this step with session's cancellation manager, so that |
| // `Session::Close()` will cancel the step. |
| const CancellationToken cancellation_token = |
| cancellation_manager_->get_cancellation_token(); |
| const bool already_cancelled = !cancellation_manager_->RegisterCallback( |
| cancellation_token, [&step_cancellation_manager]() { |
| step_cancellation_manager.StartCancel(); |
| }); |
| if (already_cancelled) { |
| // NOTE(mrry): If we don't explicitly notify |
| // `run_state.executors_done`, the RunState destructor would |
| // block on this notification. |
| run_state.executors_done.Notify(); |
| delete barrier; |
| return errors::Cancelled("Run call was cancelled"); |
| } |
| |
| thread::ThreadPool* pool = |
| run_options.inter_op_thread_pool() >= 0 |
| ? thread_pools_[run_options.inter_op_thread_pool()].first |
| : nullptr; |
| |
| if (pool == nullptr) { |
// We allow execution on the caller's thread only when a single executor
// is specified.
| if (executors_and_keys->items.size() > 1) { |
| pool = thread_pools_[0].first; |
| } else { |
| VLOG(1) << "Executing Session::Run() synchronously!"; |
| } |
| } |
| |
| Executor::Args::Runner default_runner = [this, |
| pool](Executor::Args::Closure c) { |
| SchedClosure(pool, std::move(c)); |
| }; |
| for (const auto& item : executors_and_keys->items) { |
| // TODO(zhengxq): support partial run. |
// TODO(zhengxq): if the device picks its own threadpool, we need to assign
// fewer threads to the main compute pool by default.
| thread::ThreadPool* device_thread_pool = |
| item.device->tensorflow_device_thread_pool(); |
| if (!device_thread_pool) { |
| args.runner = default_runner; |
| } else { |
| args.runner = [this, device_thread_pool](Executor::Args::Closure c) { |
| SchedClosure(device_thread_pool, std::move(c)); |
| }; |
| } |
| item.executor->RunAsync(args, barrier->Get()); |
| } |
| |
| WaitForNotification(&run_state, &step_cancellation_manager, |
| run_options.timeout_in_ms() > 0 |
| ? run_options.timeout_in_ms() |
| : operation_timeout_in_ms_); |
| |
| if (!cancellation_manager_->DeregisterCallback(cancellation_token)) { |
| // The step has been cancelled: make sure we don't attempt to receive the |
| // outputs as this would make it block forever. |
| mutex_lock l(run_state.mu_); |
| run_state.status.Update(errors::Cancelled("Run call was cancelled")); |
| } |
| |
| if (tracer) { |
| TF_RETURN_IF_ERROR(tracer->Stop()); |
| TF_RETURN_IF_ERROR(tracer->Collect(run_state.collector.get())); |
| } |
| |
| { |
| mutex_lock l(run_state.mu_); |
| TF_RETURN_IF_ERROR(run_state.status); |
| } |
| |
// Save the output tensors of this run that we choose to keep.
| if (!run_state.tensor_store.empty()) { |
| TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors( |
| {executors_and_keys->callable_options.fetch().begin(), |
| executors_and_keys->callable_options.fetch().end()}, |
| &session_state_)); |
| } |
| |
| if (run_state.collector) { |
| run_state.collector->Finalize(); |
| } |
| |
| // Build and return the cost model as instructed. |
| if (update_cost_model) { |
| // Build the cost model |
| std::unordered_map<string, const Graph*> device_to_graph; |
| for (const PerPartitionExecutorsAndLib& partition : |
| executors_and_keys->items) { |
| const Graph* graph = partition.graph; |
| const string device = partition.flib->device()->name(); |
| device_to_graph[device] = graph; |
| } |
| |
| mutex_lock l(executor_lock_); |
| run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph); |
| |
// Annotate stats onto the cost graph.
| CostGraphDef* cost_graph = run_metadata->mutable_cost_graph(); |
| for (const auto& item : executors_and_keys->items) { |
| TF_RETURN_IF_ERROR( |
| cost_model_manager_.AddToCostGraphDef(item.graph, cost_graph)); |
| } |
| } |
| |
| // If requested via RunOptions, output the partition graphs. |
| if (run_options.output_partition_graphs()) { |
| protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = |
| run_metadata->mutable_partition_graphs(); |
| for (const PerPartitionExecutorsAndLib& exec_and_lib : |
| executors_and_keys->items) { |
| GraphDef* partition_graph_def = partition_graph_defs->Add(); |
| exec_and_lib.graph->ToGraphDef(partition_graph_def); |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status DirectSession::Run(const RunOptions& run_options, |
| const NamedTensorList& inputs, |
| const std::vector<string>& output_names, |
| const std::vector<string>& target_nodes, |
| std::vector<Tensor>* outputs, |
| RunMetadata* run_metadata) { |
| TF_RETURN_IF_ERROR(CheckNotClosed()); |
| TF_RETURN_IF_ERROR(CheckGraphCreated("Run()")); |
| direct_session_runs->GetCell()->IncrementBy(1); |
| |
// Extract the input names for this run of the session.
| std::vector<string> input_tensor_names; |
| input_tensor_names.reserve(inputs.size()); |
| for (const auto& it : inputs) { |
| input_tensor_names.push_back(it.first); |
| } |
| |
| // Check if we already have an executor for these arguments. |
| ExecutorsAndKeys* executors_and_keys; |
| RunStateArgs run_state_args(run_options.debug_options()); |
| run_state_args.collective_graph_key = |
| run_options.experimental().collective_graph_key(); |
| |
| TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, |
| target_nodes, &executors_and_keys, |
| &run_state_args)); |
| { |
| mutex_lock l(collective_graph_key_lock_); |
| collective_graph_key_ = executors_and_keys->collective_graph_key; |
| } |
| |
| // Configure a call frame for the step, which we use to feed and |
| // fetch values to and from the executors. |
| FunctionCallFrame call_frame(executors_and_keys->input_types, |
| executors_and_keys->output_types); |
| gtl::InlinedVector<Tensor, 4> feed_args(inputs.size()); |
| for (const auto& it : inputs) { |
| if (it.second.dtype() == DT_RESOURCE) { |
| Tensor tensor_from_handle; |
| TF_RETURN_IF_ERROR( |
| ResourceHandleToInputTensor(it.second, &tensor_from_handle)); |
| feed_args[executors_and_keys->input_name_to_index[it.first]] = |
| tensor_from_handle; |
| } else { |
| feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second; |
| } |
| } |
| const Status s = call_frame.SetArgs(feed_args); |
| if (errors::IsInternal(s)) { |
| return errors::InvalidArgument(s.error_message()); |
| } else if (!s.ok()) { |
| return s; |
| } |
| |
| const int64 step_id = step_id_counter_.fetch_add(1); |
| |
| if (LogMemory::IsEnabled()) { |
| LogMemory::RecordStep(step_id, run_state_args.handle); |
| } |
| |
| TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame, |
| executors_and_keys, run_metadata)); |
| |
| // Receive outputs. |
| if (outputs) { |
| std::vector<Tensor> sorted_outputs; |
| const Status s = call_frame.ConsumeRetvals( |
| &sorted_outputs, /* allow_dead_tensors = */ false); |
| if (errors::IsInternal(s)) { |
| return errors::InvalidArgument(s.error_message()); |
| } else if (!s.ok()) { |
| return s; |
| } |
| const bool unique_outputs = |
| output_names.size() == executors_and_keys->output_name_to_index.size(); |
| // first_indices[i] = j implies that j is the smallest value for which |
| // output_names[i] == output_names[j]. |
| std::vector<int> first_indices; |
| if (!unique_outputs) { |
| first_indices.resize(output_names.size()); |
| for (int i = 0; i < output_names.size(); ++i) { |
| for (int j = 0; j <= i; ++j) { |
| if (output_names[i] == output_names[j]) { |
| first_indices[i] = j; |
| break; |
| } |
| } |
| } |
| } |
| outputs->clear(); |
| outputs->reserve(sorted_outputs.size()); |
| for (int i = 0; i < output_names.size(); ++i) { |
| const string& output_name = output_names[i]; |
| if (first_indices.empty() || first_indices[i] == i) { |
| outputs->emplace_back( |
| std::move(sorted_outputs[executors_and_keys |
| ->output_name_to_index[output_name]])); |
| } else { |
| outputs->push_back((*outputs)[first_indices[i]]); |
| } |
| } |
| } |
| |
| return Status::OK(); |
| } |
| |
| Status DirectSession::PRunSetup(const std::vector<string>& input_names, |
| const std::vector<string>& output_names, |
| const std::vector<string>& target_nodes, |
| string* handle) { |
| TF_RETURN_IF_ERROR(CheckNotClosed()); |
| TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()")); |
| |
| // RunOptions is not available in PRunSetup, so use thread pool 0. |
| thread::ThreadPool* pool = thread_pools_[0].first; |
| |
| // Check if we already have an executor for these arguments. |
| ExecutorsAndKeys* executors_and_keys; |
| // TODO(cais): TFDBG support for partial runs. |
| DebugOptions debug_options; |
| RunStateArgs run_state_args(debug_options); |
| run_state_args.is_partial_run = true; |
| TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names, |
| target_nodes, &executors_and_keys, |
| &run_state_args)); |
| |
| // Create the run state and save it for future PRun calls. |
| Executor::Args args; |
| args.step_id = step_id_counter_.fetch_add(1); |
| RunState* run_state = |
| new RunState(input_names, output_names, args.step_id, &devices_); |
| run_state->rendez = new IntraProcessRendezvous(device_mgr_.get()); |
| { |
| mutex_lock l(executor_lock_); |
| if (!partial_runs_ |
| .emplace(run_state_args.handle, |
| std::unique_ptr<RunState>(run_state)) |
| .second) { |
| return errors::Internal("The handle '", run_state_args.handle, |
| "' created for this partial run is not unique."); |
| } |
| } |
| |
| // Start parallel Executors. |
| const size_t num_executors = executors_and_keys->items.size(); |
| ExecutorBarrier* barrier = new ExecutorBarrier( |
| num_executors, run_state->rendez, [run_state](const Status& ret) { |
| if (!ret.ok()) { |
| mutex_lock l(run_state->mu_); |
| run_state->status.Update(ret); |
| } |
| run_state->executors_done.Notify(); |
| }); |
| |
| args.rendezvous = run_state->rendez; |
| args.cancellation_manager = cancellation_manager_; |
// Note that collectives are not supported in partial runs, because
// RunOptions is not passed in, so we cannot know whether their use is
// intended.
| args.collective_executor = nullptr; |
| args.runner = [this, pool](Executor::Args::Closure c) { |
| SchedClosure(pool, std::move(c)); |
| }; |
| args.session_state = &session_state_; |
| args.tensor_store = &run_state->tensor_store; |
| args.step_container = &run_state->step_container; |
| if (LogMemory::IsEnabled()) { |
| LogMemory::RecordStep(args.step_id, run_state_args.handle); |
| } |
| args.sync_on_finish = sync_on_finish_; |
| |
| if (options_.config.graph_options().build_cost_model()) { |
| run_state->collector.reset(new StepStatsCollector(nullptr)); |
| args.stats_collector = run_state->collector.get(); |
| } |
| |
| for (auto& item : executors_and_keys->items) { |
| item.executor->RunAsync(args, barrier->Get()); |
| } |
| |
| *handle = run_state_args.handle; |
| return Status::OK(); |
| } |
| |
| Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, |
| const std::vector<string>& output_names, |
| std::vector<Tensor>* outputs) { |
| TF_RETURN_IF_ERROR(CheckNotClosed()); |
| std::vector<string> parts = str_util::Split(handle, ';'); |
| const string& key = parts[0]; |
| // Get the executors for this partial run. |
| ExecutorsAndKeys* executors_and_keys; |
| RunState* run_state; |
| { |
| mutex_lock l(executor_lock_); // could use reader lock |
| auto exc_it = executors_.find(key); |
| if (exc_it == executors_.end()) { |
| return errors::InvalidArgument( |
| "Must run 'setup' before performing partial runs!"); |
| } |
| executors_and_keys = exc_it->second.get(); |
| |
| auto prun_it = partial_runs_.find(handle); |
| if (prun_it == partial_runs_.end()) { |
| return errors::InvalidArgument( |
| "Must run 'setup' before performing partial runs!"); |
| } |
| run_state = prun_it->second.get(); |
| |
| // Make sure that this is a new set of feeds that are still pending. |
| for (const auto& input : inputs) { |
| auto it = run_state->pending_inputs.find(input.first); |
| if (it == run_state->pending_inputs.end()) { |
| return errors::InvalidArgument( |
| "The feed ", input.first, |
| " was not specified in partial_run_setup."); |
| } else if (it->second) { |
| return errors::InvalidArgument("The feed ", input.first, |
| " has already been fed."); |
| } |
| } |
| // Check that this is a new set of fetches that are still pending. |
| for (const auto& output : output_names) { |
| auto it = run_state->pending_outputs.find(output); |
| if (it == run_state->pending_outputs.end()) { |
| return errors::InvalidArgument( |
| "The fetch ", output, " was not specified in partial_run_setup."); |
| } else if (it->second) { |
| return errors::InvalidArgument("The fetch ", output, |
| " has already been fetched."); |
| } |
| } |
| } |
| |
| // Check that this new set of fetches can be computed from all the |
| // feeds we have supplied. |
| TF_RETURN_IF_ERROR( |
| CheckFetch(inputs, output_names, executors_and_keys, run_state)); |
| |
| // Send inputs. |
| Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez); |
| |
| // Receive outputs. |
| if (s.ok()) { |
| s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs); |
| } |
| |
// Save the output tensors of this run that we choose to keep.
| if (s.ok()) { |
| s = run_state->tensor_store.SaveTensors(output_names, &session_state_); |
| } |
| |
| { |
| mutex_lock l(executor_lock_); |
| // Delete the run state if there is an error or all fetches are done. |
| bool done = true; |
| if (s.ok()) { |
| { |
| mutex_lock l(run_state->mu_); |
| if (!run_state->status.ok()) { |
| LOG(WARNING) << "An error unrelated to this prun has been detected. " |
| << run_state->status; |
| } |
| } |
| for (const auto& input : inputs) { |
| auto it = run_state->pending_inputs.find(input.first); |
| it->second = true; |
| } |
| for (const auto& name : output_names) { |
| auto it = run_state->pending_outputs.find(name); |
| it->second = true; |
| } |
| done = run_state->PendingDone(); |
| } |
| if (done) { |
| WaitForNotification(run_state, cancellation_manager_, |
| operation_timeout_in_ms_); |
| partial_runs_.erase(handle); |
| } |
| } |
| |
| return s; |
| } |
| |
| Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor, |
| Tensor* retrieved_tensor) { |
| if (resource_tensor.dtype() != DT_RESOURCE) { |
| return errors::InvalidArgument(strings::StrCat( |
| "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ", |
| resource_tensor.dtype())); |
| } |
| |
| const ResourceHandle& resource_handle = |
| resource_tensor.scalar<ResourceHandle>()(); |
| |
| if (resource_handle.container() == |
| SessionState::kTensorHandleResourceTypeName) { |
| return session_state_.GetTensor(resource_handle.name(), retrieved_tensor); |
| } else { |
| return errors::InvalidArgument(strings::StrCat( |
| "Invalid resource type hash code: ", resource_handle.hash_code(), |
| "(name: ", resource_handle.name(), |
| " type: ", resource_handle.maybe_type_name(), |
| "). Perhaps a resource tensor was being provided as a feed? That is " |
| "not currently allowed. Please file an issue at " |
| "https://github.com/tensorflow/tensorflow/issues/new, ideally with a " |
| "short code snippet that leads to this error message.")); |
| } |
| } |
| |
| Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, |
| const ExecutorsAndKeys* executors_and_keys, |
| IntraProcessRendezvous* rendez) { |
| Status s; |
| Rendezvous::ParsedKey parsed; |
| // Insert the input tensors into the local rendezvous by their |
| // rendezvous key. |
| for (const auto& input : inputs) { |
| auto it = |
| executors_and_keys->input_name_to_rendezvous_key.find(input.first); |
| if (it == executors_and_keys->input_name_to_rendezvous_key.end()) { |
| return errors::Internal("'", input.first, "' is not a pre-defined feed."); |
| } |
| const string& input_key = it->second; |
| |
| s = Rendezvous::ParseKey(input_key, &parsed); |
| if (!s.ok()) { |
| rendez->StartAbort(s); |
| return s; |
| } |
| |
| if (input.second.dtype() == DT_RESOURCE) { |
| Tensor tensor_from_handle; |
| s = ResourceHandleToInputTensor(input.second, &tensor_from_handle); |
| if (s.ok()) { |
| s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false); |
| } |
| } else { |
| s = rendez->Send(parsed, Rendezvous::Args(), input.second, false); |
| } |
| |
| if (!s.ok()) { |
| rendez->StartAbort(s); |
| return s; |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status DirectSession::RecvPRunOutputs( |
| const std::vector<string>& output_names, |
| const ExecutorsAndKeys* executors_and_keys, RunState* run_state, |
| std::vector<Tensor>* outputs) { |
| Status s; |
| if (!output_names.empty()) { |
| outputs->resize(output_names.size()); |
| } |
| |
| Rendezvous::ParsedKey parsed; |
// Get the outputs from the rendezvous.
| for (size_t output_offset = 0; output_offset < output_names.size(); |
| ++output_offset) { |
| const string& output_name = output_names[output_offset]; |
| auto it = |
| executors_and_keys->output_name_to_rendezvous_key.find(output_name); |
| if (it == executors_and_keys->output_name_to_rendezvous_key.end()) { |
| return errors::Internal("'", output_name, |
| "' is not a pre-defined fetch."); |
| } |
| const string& output_key = it->second; |
| Tensor output_tensor; |
| bool is_dead; |
| IntraProcessRendezvous* rendez = run_state->rendez; |
| |
| s = Rendezvous::ParseKey(output_key, &parsed); |
| if (s.ok()) { |
| // Fetch data from the Rendezvous. |
| s = rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, &is_dead, |
| operation_timeout_in_ms_); |
| if (is_dead && s.ok()) { |
| s = errors::InvalidArgument("The tensor returned for ", output_name, |
| " was not valid."); |
| } |
| } |
| if (!s.ok()) { |
| rendez->StartAbort(s); |
| outputs->clear(); |
| return s; |
| } |
| |
| (*outputs)[output_offset] = output_tensor; |
| } |
| return Status::OK(); |
| } |
| |
| Status DirectSession::CheckFetch(const NamedTensorList& feeds, |
| const std::vector<string>& fetches, |
| const ExecutorsAndKeys* executors_and_keys, |
| const RunState* run_state) { |
| const Graph* graph = executors_and_keys->graph.get(); |
| const NameNodeMap* name_to_node = &executors_and_keys->name_to_node; |
| |
| // Build the set of pending feeds that we haven't seen. |
| std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; |
| { |
| mutex_lock l(executor_lock_); |
| for (const auto& input : run_state->pending_inputs) { |
| // Skip if the feed has already been fed. |
| if (input.second) continue; |
| TensorId id(ParseTensorName(input.first)); |
| auto it = name_to_node->find(id.first); |
| if (it == name_to_node->end()) { |
| return errors::NotFound("Feed ", input.first, ": not found"); |
| } |
| pending_feeds.insert(id); |
| } |
| } |
| for (const auto& it : feeds) { |
| TensorId id(ParseTensorName(it.first)); |
| pending_feeds.erase(id); |
| } |
| |
| // Initialize the stack with the fetch nodes. |
| std::vector<const Node*> stack; |
| for (const string& fetch : fetches) { |
| TensorId id(ParseTensorName(fetch)); |
| auto it = name_to_node->find(id.first); |
| if (it == name_to_node->end()) { |
| return errors::NotFound("Fetch ", fetch, ": not found"); |
| } |
| stack.push_back(it->second); |
| } |
| |
| // Any tensor needed for fetches can't be in pending_feeds. |
| std::vector<bool> visited(graph->num_node_ids(), false); |
| while (!stack.empty()) { |
| const Node* n = stack.back(); |
| stack.pop_back(); |
| |
| for (const Edge* in_edge : n->in_edges()) { |
| const Node* in_node = in_edge->src(); |
| if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) { |
| return errors::InvalidArgument("Fetch ", in_node->name(), ":", |
| in_edge->src_output(), |
| " can't be computed from the feeds" |
| " that have been fed so far."); |
| } |
| if (!visited[in_node->id()]) { |
| visited[in_node->id()] = true; |
| stack.push_back(in_node); |
| } |
| } |
| } |
| return Status::OK(); |
| } |
| |
| Status DirectSession::CreateExecutors( |
| const CallableOptions& callable_options, |
| std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys, |
| std::unique_ptr<FunctionInfo>* out_func_info, |
| RunStateArgs* run_state_args) { |
| BuildGraphOptions options; |
| options.callable_options = callable_options; |
| options.use_function_convention = !run_state_args->is_partial_run; |
| options.collective_graph_key = |
| callable_options.run_options().experimental().collective_graph_key(); |
| |
| std::unique_ptr<FunctionInfo> func_info(new FunctionInfo); |
| std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); |
| |
| ek->callable_options = callable_options; |
| |
| std::unordered_map<string, std::unique_ptr<Graph>> graphs; |
| TF_RETURN_IF_ERROR(CreateGraphs( |
| options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types, |
| &ek->output_types, &ek->collective_graph_key)); |
| |
| if (run_state_args->is_partial_run) { |
| ek->graph = std::move(run_state_args->graph); |
| std::unordered_set<StringPiece, StringPieceHasher> names; |
| for (const string& input : callable_options.feed()) { |
| TensorId id(ParseTensorName(input)); |
| names.emplace(id.first); |
| } |
| for (const string& output : callable_options.fetch()) { |
| TensorId id(ParseTensorName(output)); |
| names.emplace(id.first); |
| } |
| for (Node* n : ek->graph->nodes()) { |
| if (names.count(n->name()) > 0) { |
| ek->name_to_node.insert({n->name(), n}); |
| } |
| } |
| } |
| ek->items.reserve(graphs.size()); |
| const auto& optimizer_opts = |
| options_.config.graph_options().optimizer_options(); |
| |
| int graph_def_version; |
| { |
| mutex_lock l(graph_def_lock_); |
| graph_def_version = |
| execution_state_->original_graph_def().versions().producer(); |
| } |
| func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime( |
| device_mgr_.get(), options_.env, graph_def_version, |
| func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first)); |
| |
| GraphOptimizer optimizer(optimizer_opts); |
| for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) { |
| const string& partition_name = iter->first; |
| std::unique_ptr<Graph>& partition_graph = iter->second; |
| |
| Device* device; |
| TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device)); |
| |
| ek->items.resize(ek->items.size() + 1); |
| auto* item = &(ek->items.back()); |
| auto lib = func_info->proc_flr->GetFLR(partition_name); |
| if (lib == nullptr) { |
| return errors::Internal("Could not find device: ", partition_name); |
| } |
| item->flib = lib; |
| |
| LocalExecutorParams params; |
| params.device = device; |
| params.function_library = lib; |
| auto opseg = device->op_segment(); |
| params.create_kernel = [this, lib, opseg](const NodeDef& ndef, |
| OpKernel** kernel) { |
| // NOTE(mrry): We must not share function kernels (implemented |
| // using `CallOp`) between subgraphs, because `CallOp::handle_` |
| // is tied to a particular subgraph. Even if the function itself |
| // is stateful, the `CallOp` that invokes it is not. |
| if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { |
| return lib->CreateKernel(ndef, kernel); |
| } |
| auto create_fn = [lib, &ndef](OpKernel** kernel) { |
| return lib->CreateKernel(ndef, kernel); |
| }; |
| // Kernels created for subgraph nodes need to be cached. On |
| // cache miss, create_fn() is invoked to create a kernel based |
| // on the function library here + global op registry. |
| return opseg->FindOrCreate(session_handle_, ndef.name(), kernel, |
| create_fn); |
| }; |
| params.delete_kernel = [lib](OpKernel* kernel) { |
| if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) |
| delete kernel; |
| }; |
| |
| optimizer.Optimize(lib, options_.env, device, &partition_graph, |
| /*shape_map=*/nullptr); |
| |
| // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. |
| const DebugOptions& debug_options = |
| options.callable_options.run_options().debug_options(); |
| if (!debug_options.debug_tensor_watch_opts().empty()) { |
| TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug( |
| debug_options, partition_graph.get(), params.device)); |
| } |
| |
| TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()), |
| device->name(), |
| partition_graph.get())); |
// NewExecutor takes ownership of partition_graph.
| item->graph = partition_graph.get(); |
| item->executor = nullptr; |
| item->device = device; |
| auto executor_type = options_.config.experimental().executor_type(); |
| TF_RETURN_IF_ERROR(NewExecutor( |
| executor_type, params, std::move(partition_graph), &item->executor)); |
| } |
| |
| // Cache the mapping from input/output names to graph elements to |
| // avoid recomputing it every time. |
| if (!run_state_args->is_partial_run) { |
| // For regular `Run()`, we use the function calling convention, and so |
| // maintain a mapping from input/output names to |
| // argument/return-value ordinal index. |
| for (int i = 0; i < callable_options.feed().size(); ++i) { |
| const string& input = callable_options.feed(i); |
| ek->input_name_to_index[input] = i; |
| } |
| for (int i = 0; i < callable_options.fetch().size(); ++i) { |
| const string& output = callable_options.fetch(i); |
| ek->output_name_to_index[output] = i; |
| } |
| } else { |
| // For `PRun()`, we use the rendezvous calling convention, and so |
| // maintain a mapping from input/output names to rendezvous keys. |
| // |
| // We always use the first device as the device name portion of the |
| // key, even if we're feeding another graph. |
| for (int i = 0; i < callable_options.feed().size(); ++i) { |
| const string& input = callable_options.feed(i); |
| ek->input_name_to_rendezvous_key[input] = GetRendezvousKey( |
| input, device_set_.client_device()->attributes(), FrameAndIter(0, 0)); |
| } |
| for (int i = 0; i < callable_options.fetch().size(); ++i) { |
| const string& output = callable_options.fetch(i); |
| ek->output_name_to_rendezvous_key[output] = |
| GetRendezvousKey(output, device_set_.client_device()->attributes(), |
| FrameAndIter(0, 0)); |
| } |
| } |
| |
| *out_executors_and_keys = std::move(ek); |
| *out_func_info = std::move(func_info); |
| return Status::OK(); |
| } |
| |
| Status DirectSession::GetOrCreateExecutors( |
| gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, |
| gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys, |
| RunStateArgs* run_state_args) { |
| int64 handle_name_counter_value = -1; |
| if (LogMemory::IsEnabled() || run_state_args->is_partial_run) { |
| handle_name_counter_value = handle_name_counter_.fetch_add(1); |
| } |
| |
| string debug_tensor_watches_summary; |
| if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) { |
| debug_tensor_watches_summary = SummarizeDebugTensorWatches( |
| run_state_args->debug_options.debug_tensor_watch_opts()); |
| } |
| |
| // Fast lookup path, no sorting. |
| const string key = strings::StrCat( |
| str_util::Join(inputs, ","), "->", str_util::Join(outputs, ","), "/", |
| str_util::Join(target_nodes, ","), "/", run_state_args->is_partial_run, |
| "/", debug_tensor_watches_summary); |
| // Set the handle, if it's needed to log memory or for partial run. |
| if (handle_name_counter_value >= 0) { |
| run_state_args->handle = |
| strings::StrCat(key, ";", handle_name_counter_value); |
| } |
| |
| // See if we already have the executors for this run. |
| { |
| mutex_lock l(executor_lock_); // could use reader lock |
| auto it = executors_.find(key); |
| if (it != executors_.end()) { |
| *executors_and_keys = it->second.get(); |
| return Status::OK(); |
| } |
| } |
| |
| // Slow lookup path, the unsorted key missed the cache. |
| // Sort the inputs and outputs, and look up with the sorted key in case an |
| // earlier call used a different order of inputs and outputs. |
| // |
| // We could consider some other signature instead of sorting that |
| // preserves the same property to avoid the sort in the future. |
| std::vector<string> inputs_sorted(inputs.begin(), inputs.end()); |
| std::sort(inputs_sorted.begin(), inputs_sorted.end()); |
| std::vector<string> outputs_sorted(outputs.begin(), outputs.end()); |
| std::sort(outputs_sorted.begin(), outputs_sorted.end()); |
| std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end()); |
| std::sort(tn_sorted.begin(), tn_sorted.end()); |
| |
| const string sorted_key = strings::StrCat( |
| str_util::Join(inputs_sorted, ","), "->", |
| str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","), |
| "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary); |
// Set the handle, if it's needed to log memory or for partial run.
| if (handle_name_counter_value >= 0) { |
| run_state_args->handle = |
| strings::StrCat(sorted_key, ";", handle_name_counter_value); |
| } |
| |
| // See if we already have the executors for this run. |
| { |
| mutex_lock l(executor_lock_); |
| auto it = executors_.find(sorted_key); |
| if (it != executors_.end()) { |
| *executors_and_keys = it->second.get(); |
| // Insert this under the original key. |
| executors_.emplace(key, it->second); |
| return Status::OK(); |
| } |
| } |
| |
| // Nothing found, so create the executors and store in the cache. |
| // The executor_lock_ is intentionally released while executors are |
| // being created. |
| CallableOptions callable_options; |
| for (const string& input : inputs_sorted) { |
| callable_options.add_feed(input); |
| } |
| for (const string& output : outputs_sorted) { |
| callable_options.add_fetch(output); |
| } |
| for (const string& target : tn_sorted) { |
| callable_options.add_target(target); |
| } |
| *callable_options.mutable_run_options()->mutable_debug_options() = |
| run_state_args->debug_options; |
| callable_options.mutable_run_options() |
| ->mutable_experimental() |
| ->set_collective_graph_key(run_state_args->collective_graph_key); |
| std::unique_ptr<ExecutorsAndKeys> ek; |
| std::unique_ptr<FunctionInfo> func_info; |
| TF_RETURN_IF_ERROR( |
| CreateExecutors(callable_options, &ek, &func_info, run_state_args)); |
| |
| // Reacquire the lock, try to insert into the map. |
| mutex_lock l(executor_lock_); |
| functions_.push_back(std::move(func_info)); |
| |
| // Another thread may have created the entry before us, in which case we will |
| // reuse the already created one. |
| auto insert_result = executors_.emplace( |
| sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek))); |
| // Insert the value under the original key, so the fast path lookup will work |
| // if the user uses the same order of inputs, outputs, and targets again. |
| executors_.emplace(key, insert_result.first->second); |
| *executors_and_keys = insert_result.first->second.get(); |
| |
| return Status::OK(); |
| } |
| |
| Status DirectSession::CreateGraphs( |
| const BuildGraphOptions& subgraph_options, |
| std::unordered_map<string, std::unique_ptr<Graph>>* outputs, |
| std::unique_ptr<FunctionLibraryDefinition>* flib_def, |
| RunStateArgs* run_state_args, DataTypeVector* input_types, |
| DataTypeVector* output_types, int64* collective_graph_key) { |
| mutex_lock l(graph_def_lock_); |
| std::unique_ptr<ClientGraph> client_graph; |
| |
| std::unique_ptr<GraphExecutionState> temp_exec_state_holder; |
| GraphExecutionState* execution_state = nullptr; |
| if (options_.config.graph_options().place_pruned_graph()) { |
| // Because we are placing pruned graphs, we need to create a |
| // new GraphExecutionState for every new unseen graph, |
| // and then place it. |
| GraphExecutionStateOptions prune_options; |
| prune_options.device_set = &device_set_; |
| prune_options.session_options = &options_; |
| prune_options.stateful_placements = stateful_placements_; |
| TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph( |
| execution_state_->original_graph_def().library(), prune_options, |
| execution_state_->original_graph_def(), subgraph_options, |
| &temp_exec_state_holder, &client_graph)); |
| execution_state = temp_exec_state_holder.get(); |
| } else { |
| execution_state = execution_state_.get(); |
| TF_RETURN_IF_ERROR( |
| execution_state->BuildGraph(subgraph_options, &client_graph)); |
| } |
| *collective_graph_key = client_graph->collective_graph_key; |
| |
| if (subgraph_options.callable_options.feed_size() != |
| client_graph->feed_types.size()) { |
| return errors::Internal( |
| "Graph pruning failed: requested number of feed endpoints = ", |
| subgraph_options.callable_options.feed_size(), |
| " versus number of pruned feed endpoints = ", |
| client_graph->feed_types.size()); |
| } |
| if (subgraph_options.callable_options.fetch_size() != |
| client_graph->fetch_types.size()) { |
| return errors::Internal( |
| "Graph pruning failed: requested number of fetch endpoints = ", |
| subgraph_options.callable_options.fetch_size(), |
| " versus number of pruned fetch endpoints = ", |
| client_graph->fetch_types.size()); |
| } |
| |
| auto current_stateful_placements = execution_state->GetStatefulPlacements(); |
| // Update our current state based on the execution_state's |
| // placements. If there are any mismatches for a node, |
| // we should fail, as this should never happen. |
| for (auto placement_pair : current_stateful_placements) { |
| const string& node_name = placement_pair.first; |
| const string& placement = placement_pair.second; |
| auto iter = stateful_placements_.find(node_name); |
| if (iter == stateful_placements_.end()) { |
| stateful_placements_.insert(std::make_pair(node_name, placement)); |
| } else if (iter->second != placement) { |
| return errors::Internal( |
| "Stateful placement mismatch. " |
| "Current assignment of ", |
| node_name, " to ", iter->second, " does not match ", placement); |
| } |
| } |
| |
| stateful_placements_ = execution_state->GetStatefulPlacements(); |
| |
| // Remember the graph in run state if this is a partial run. |
| if (run_state_args->is_partial_run) { |
| run_state_args->graph.reset(new Graph(flib_def_.get())); |
| CopyGraph(*execution_state->full_graph(), run_state_args->graph.get()); |
| } |
| |
| // Partition the graph across devices. |
| PartitionOptions popts; |
| popts.node_to_loc = [](const Node* node) { |
| return node->assigned_device_name(); |
| }; |
| popts.new_name = [this](const string& prefix) { |
| return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1)); |
| }; |
| popts.get_incarnation = [](const string& name) { |
| // The direct session does not have changing incarnation numbers. |
| // Just return '1'. |
| return 1; |
| }; |
| popts.flib_def = &client_graph->graph.flib_def(); |
| popts.control_flow_added = false; |
| |
| std::unordered_map<string, GraphDef> partitions; |
| TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions)); |
| |
| std::vector<string> device_names; |
| for (auto device : devices_) { |
| // Extract the LocalName from the device. |
| device_names.push_back(DeviceNameUtils::LocalName(device->name())); |
| } |
| |
| // Check for valid partitions. |
| for (const auto& partition : partitions) { |
| const string local_partition_name = |
| DeviceNameUtils::LocalName(partition.first); |
| if (std::count(device_names.begin(), device_names.end(), |
| local_partition_name) == 0) { |
| return errors::InvalidArgument( |
| "Creating a partition for ", local_partition_name, |
| " which doesn't exist in the list of available devices. Available " |
| "devices: ", |
| str_util::Join(device_names, ",")); |
| } |
| } |
| |
| for (const auto& partition : partitions) { |
| std::unique_ptr<Graph> device_graph( |
| new Graph(client_graph->flib_def.get())); |
| GraphConstructorOptions device_opts; |
| // Partitioning inserts internal operations (e.g., send/recv), so |
| // they must be allowed when reconstructing each device subgraph. |
| device_opts.allow_internal_ops = true; |
| device_opts.expect_device_spec = true; |
| TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, partition.second, |
| device_graph.get())); |
| outputs->emplace(partition.first, std::move(device_graph)); |
| } |
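| |
| // At this point `*outputs` maps each full device name to the Graph that |
| // will run there; e.g., a graph split across one CPU and one GPU |
| // (hypothetical placement) yields two entries keyed by names such as |
| //   "/job:localhost/replica:0/task:0/device:CPU:0" |
| //   "/job:localhost/replica:0/task:0/device:GPU:0" |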
| |
| GraphOptimizationPassOptions optimization_options; |
| optimization_options.session_options = &options_; |
| optimization_options.flib_def = client_graph->flib_def.get(); |
| optimization_options.partition_graphs = outputs; |
| TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( |
| OptimizationPassRegistry::POST_PARTITIONING, optimization_options)); |
| |
| Status s; |
| for (auto& partition : *outputs) { |
| const string& partition_name = partition.first; |
| std::unique_ptr<Graph>* graph = &partition.second; |
| |
| VLOG(2) << "Created " << DebugString(graph->get()) << " for " |
| << partition_name; |
| |
| // Give the device an opportunity to rewrite its subgraph. |
| Device* d; |
| s = device_mgr_->LookupDevice(partition_name, &d); |
| if (!s.ok()) break; |
| s = d->MaybeRewriteGraph(graph); |
| if (!s.ok()) break; |
| } |
| *flib_def = std::move(client_graph->flib_def); |
| std::swap(*input_types, client_graph->feed_types); |
| std::swap(*output_types, client_graph->fetch_types); |
| return s; |
| } |
| |
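| // Example use of ListDevices() (illustrative sketch; assumes `session` is |
| // an open DirectSession owned by the caller): |
| // |
| //   std::vector<DeviceAttributes> devices; |
| //   TF_CHECK_OK(session->ListDevices(&devices)); |
| //   for (const DeviceAttributes& attrs : devices) { |
| //     LOG(INFO) << attrs.name() << " (" << attrs.device_type() << ")"; |
| //   } |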
| ::tensorflow::Status DirectSession::ListDevices( |
| std::vector<DeviceAttributes>* response) { |
| response->clear(); |
| response->reserve(devices_.size()); |
| for (Device* d : devices_) { |
| const DeviceAttributes& attrs = d->attributes(); |
| response->emplace_back(attrs); |
| } |
| return ::tensorflow::Status::OK(); |
| } |
| |
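| // Example use of Reset() (illustrative sketch; "train_state" is a |
| // hypothetical container name): |
| // |
| //   TF_CHECK_OK(session->Reset({"train_state"})); |
| // |
| // This clears the named resource containers on each device managed by the |
| // session. |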
| ::tensorflow::Status DirectSession::Reset( |
| const std::vector<string>& containers) { |
| device_mgr_->ClearContainers(containers); |
| return ::tensorflow::Status::OK(); |
| } |
| |
| ::tensorflow::Status DirectSession::Close() { |
| cancellation_manager_->StartCancel(); |
| { |
| mutex_lock l(closed_lock_); |
| if (closed_) return ::tensorflow::Status::OK(); |
| closed_ = true; |
| } |
| if (factory_ != nullptr) factory_->Deregister(this); |
| return ::tensorflow::Status::OK(); |
| } |
| |
| DirectSession::RunState::RunState( |
| const std::vector<string>& pending_input_names, |
| const std::vector<string>& pending_output_names, int64 step_id, |
| const std::vector<Device*>* devices) |
| : step_container(step_id, [devices, step_id](const string& name) { |
| for (auto d : *devices) { |
| // Per-step resource cleanup is best-effort; ignore any error. |
| d->resource_manager()->Cleanup(name).IgnoreError(); |
| ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr(); |
| if (sam) sam->Cleanup(step_id); |
| } |
| }) { |
| // Initially all the feeds and fetches are pending. |
| for (auto& name : pending_input_names) { |
| pending_inputs[name] = false; |
| } |
| for (auto& name : pending_output_names) { |
| pending_outputs[name] = false; |
| } |
| } |
| |
| DirectSession::RunState::RunState(int64 step_id, |
| const std::vector<Device*>* devices) |
| : RunState({}, {}, step_id, devices) {} |
| |
| DirectSession::RunState::~RunState() { |
| if (rendez != nullptr) { |
| if (!executors_done.HasBeenNotified()) { |
| rendez->StartAbort(errors::Cancelled("PRun cancellation")); |
| executors_done.WaitForNotification(); |
| } |
| rendez->Unref(); |
| } |
| } |
| |
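| // Returns true once every pending feed and fetch has been seen. For example |
| // (illustrative): a partial run with feeds {"a", "b"} and fetch {"c"} starts |
| // with all three entries false; each ensuing PRun() call flips the entries |
| // it consumes or produces to true, and the step can be finalized only once |
| // PendingDone() reports true. |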
| bool DirectSession::RunState::PendingDone() const { |
| for (const auto& it : pending_inputs) { |
| if (!it.second) return false; |
| } |
| for (const auto& it : pending_outputs) { |
| if (!it.second) return false; |
| } |
| return true; |
| } |
| |
| void DirectSession::WaitForNotification(RunState* run_state, |
| CancellationManager* cm, |
| int64 timeout_in_ms) { |
| const Status status = |
| WaitForNotification(&run_state->executors_done, timeout_in_ms); |
| if (!status.ok()) { |
| { |
| mutex_lock l(run_state->mu_); |
| run_state->status.Update(status); |
| } |
| cm->StartCancel(); |
| // We must wait for the executors to complete, because they have borrowed |
| // references to `cm` and other per-step state. After this notification, it |
| // is safe to clean up the step. |
| run_state->executors_done.WaitForNotification(); |
| } |
| } |
| |
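| // A non-positive `timeout_in_ms` blocks indefinitely; a positive timeout is |
| // converted to microseconds for the platform-level timed wait. Illustrative |
| // use (hypothetical caller): |
| // |
| //   Notification n; |
| //   // ... arrange for n.Notify() to be called elsewhere ... |
| //   Status s = WaitForNotification(&n, /*timeout_in_ms=*/5000); |
| //   // s is DEADLINE_EXCEEDED if `n` was not notified within 5 seconds. |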
| ::tensorflow::Status DirectSession::WaitForNotification( |
| Notification* notification, int64 timeout_in_ms) { |
| if (timeout_in_ms > 0) { |
| const int64 timeout_in_us = timeout_in_ms * 1000; |
| const bool notified = |
| WaitForNotificationWithTimeout(notification, timeout_in_us); |
| if (!notified) { |
| return Status(error::DEADLINE_EXCEEDED, |
| "Timed out waiting for notification"); |
| } |
| } else { |
| notification->WaitForNotification(); |
| } |
| return Status::OK(); |
| } |
| |
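| // Example callable lifecycle (illustrative sketch; the feed/fetch endpoint |
| // names and `x_tensor` are hypothetical): |
| // |
| //   CallableOptions opts; |
| //   opts.add_feed("x:0"); |
| //   opts.add_fetch("y:0"); |
| //   Session::CallableHandle handle; |
| //   TF_CHECK_OK(session->MakeCallable(opts, &handle)); |
| //   std::vector<Tensor> outputs; |
| //   TF_CHECK_OK(session->RunCallable(handle, {x_tensor}, &outputs, |
| //                                    /*run_metadata=*/nullptr)); |
| //   TF_CHECK_OK(session->ReleaseCallable(handle)); |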
| Status DirectSession::MakeCallable(const CallableOptions& callable_options, |
| CallableHandle* out_handle) { |
| TF_RETURN_IF_ERROR(CheckNotClosed()); |
| TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()")); |
| |
| std::unique_ptr<ExecutorsAndKeys> ek; |
| std::unique_ptr<FunctionInfo> func_info; |
| RunStateArgs run_state_args(callable_options.run_options().debug_options()); |
| TF_RETURN_IF_ERROR( |
| CreateExecutors(callable_options, &ek, &func_info, &run_state_args)); |
| { |
| mutex_lock l(callables_lock_); |
| *out_handle = next_callable_handle_++; |
| callables_[*out_handle] = {std::move(ek), std::move(func_info)}; |
| } |
| return Status::OK(); |
| } |
| |
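| // The executors reach this call frame through the `_Arg` and `_Retval` |
| // kernels that the feeds and fetches are rewritten into: each `_Arg` op |
| // pulls its value via GetArg() and each `_Retval` op deposits its result |
| // via SetRetval(), bypassing the rendezvous-based feeding used by the |
| // string-keyed Run() path. |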
| class DirectSession::RunCallableCallFrame : public CallFrameInterface { |
| public: |
| RunCallableCallFrame(DirectSession* session, |
| ExecutorsAndKeys* executors_and_keys, |
| const std::vector<Tensor>* feed_tensors, |
| std::vector<Tensor>* fetch_tensors) |
| : session_(session), |
| executors_and_keys_(executors_and_keys), |
| feed_tensors_(feed_tensors), |
| fetch_tensors_(fetch_tensors) {} |
| |
| size_t num_args() const override { |
| return executors_and_keys_->input_types.size(); |
| } |
| size_t num_retvals() const override { |
| return executors_and_keys_->output_types.size(); |
| } |
| |
| Status GetArg(int index, Tensor* val) const override { |
| if (index >= static_cast<int>(feed_tensors_->size())) { |
| return errors::Internal("Args index out of bounds: ", index); |
| } else if (executors_and_keys_->input_types[index] == DT_RESOURCE) { |
| TF_RETURN_IF_ERROR( |
| session_->ResourceHandleToInputTensor((*feed_tensors_)[index], val)); |
| } else { |
| *val = (*feed_tensors_)[index]; |
| } |
| return Status::OK(); |
| } |
| |
| Status SetRetval(int index, const Tensor& val) override { |
| if (index >= static_cast<int>(fetch_tensors_->size())) { |
| return errors::Internal("RetVal index out of bounds: ", index); |
| } |
| (*fetch_tensors_)[index] = val; |
| return Status::OK(); |
| } |
| |
| private: |
| DirectSession* const session_; // Not owned. |
| ExecutorsAndKeys* const executors_and_keys_; // Not owned. |
| const std::vector<Tensor>* const feed_tensors_; // Not owned. |
| std::vector<Tensor>* const fetch_tensors_; // Not owned. |
| }; |
| |
| ::tensorflow::Status DirectSession::RunCallable( |
| CallableHandle handle, const std::vector<Tensor>& feed_tensors, |
| std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) { |
| TF_RETURN_IF_ERROR(CheckNotClosed()); |
| TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()")); |
| direct_session_runs->GetCell()->IncrementBy(1); |
| |
| // Check if we already have an executor for these arguments. |
| std::shared_ptr<ExecutorsAndKeys> executors_and_keys; |
| const int64 step_id = step_id_counter_.fetch_add(1); |
| |
| { |
| tf_shared_lock l(callables_lock_); |
| if (handle >= next_callable_handle_) { |
| return errors::InvalidArgument("No such callable handle: ", handle); |
| } |
| // Use find() rather than operator[]: operator[] would insert an empty |
| // entry for a released handle while only a shared lock is held. |
| auto it = callables_.find(handle); |
| if (it != callables_.end()) { |
| executors_and_keys = it->second.executors_and_keys; |
| } |
| } |
| |
| if (!executors_and_keys) { |
| return errors::InvalidArgument( |
| "Attempted to run callable after handle was released: ", handle); |
| } |
| |
| // NOTE(mrry): Debug options are not currently supported in the |
| // callable interface. |
| DebugOptions debug_options; |
| RunStateArgs run_state_args(debug_options); |
| |
| // Configure a call frame for the step, which we use to feed and |
| // fetch values to and from the executors. |
| if (feed_tensors.size() != executors_and_keys->input_types.size()) { |
| return errors::InvalidArgument( |
| "Expected ", executors_and_keys->input_types.size(), |
| " feed tensors, but got ", feed_tensors.size()); |
| } |
| if (fetch_tensors != nullptr) { |
| fetch_tensors->resize(executors_and_keys->output_types.size()); |
| } else if (!executors_and_keys->output_types.empty()) { |
| return errors::InvalidArgument( |
| "`fetch_tensors` must be provided when the callable has one or more " |
| "outputs."); |
| } |
| |
| // A specialized CallFrame implementation that takes advantage of the |
| // optimized RunCallable interface. |
| RunCallableCallFrame call_frame(this, executors_and_keys.get(), &feed_tensors, |
| fetch_tensors); |
| |
| if (LogMemory::IsEnabled()) { |
| LogMemory::RecordStep(step_id, run_state_args.handle); |
| } |
| |
| TF_RETURN_IF_ERROR( |
| RunInternal(step_id, executors_and_keys->callable_options.run_options(), |
| &call_frame, executors_and_keys.get(), run_metadata)); |
| |
| return Status::OK(); |
| } |
| |
| ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) { |
| mutex_lock l(callables_lock_); |
| if (handle >= next_callable_handle_) { |
| return errors::InvalidArgument("No such callable handle: ", handle); |
| } |
| callables_.erase(handle); |
| return Status::OK(); |
| } |
| |
| DirectSession::Callable::~Callable() { |
| // We must delete the fields in this order, because the destructor |
| // of `executors_and_keys` will call into an object owned by |
| // `function_info` (in particular, when deleting a kernel, it relies |
| // on the `FunctionLibraryRuntime` to know if the kernel is stateful |
| // or not). |
| executors_and_keys.reset(); |
| function_info.reset(); |
| } |
| |
| } // namespace tensorflow |