// blob: d8ea85f1955f572e71d218a1c62cacb397a136e8 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#include <functional>
#include <memory>
#include <utility>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// NOTE(review): StepStatsCollector is not referenced anywhere in this header
// (Executor::Args uses StepStatsCollectorInterface); presumably this forward
// declaration is kept for includers that rely on it -- confirm before removing.
class StepStatsCollector;
// Once this is called, execution cost is sampled for kernels classified as
// "inexpensive", and such a kernel is switched to "expensive" when its
// estimated cost exceeds the expensiveness threshold.
//
// This is a temporary flag for validating the performance impact of this
// feature. For simplicity, a global flag is used: once it is turned on, it
// cannot be turned off. The flag will be removed once the feature has been
// validated.
void EnableAlwaysTrackKernelExecutionCost();
// Executor runs a graph computation.
// Example (schematic; Send/Recv arguments are abbreviated):
//   Graph* graph = ...;
//   ... construct graph ...
//   LocalExecutorParams params = ...;
//   Executor* executor;
//   TF_CHECK_OK(NewLocalExecutor(params, *graph, &executor));
//   Rendezvous* rendezvous = NewNaiveRendezvous();
//   TF_CHECK_OK(rendezvous->Send("input", some_input_tensor));
//   Executor::Args args;
//   args.rendezvous = rendezvous;
//   args.runner = ...;
//   TF_CHECK_OK(executor->Run(args));
//   TF_CHECK_OK(rendezvous->Recv("output", &output_tensor));
//   ... ...
//
// Multiple threads can call Executor::Run concurrently.
class Executor {
 public:
  virtual ~Executor() = default;

  // RunAsync() executes the graph computation. "done" is run when the
  // graph computation completes. If any error happens during the
  // computation, "done" is run and the error is passed to "done".
  //
  // RunAsync() is given a few arguments in Args. The caller must
  // ensure objects passed in Args (rendezvous, stats_collector, etc.)
  // are alive at least until done is invoked. All pointers to the
  // argument objects can be nullptr.
  //
  // "step_id" is a process-wide unique identifier for the step being
  // run. Executors on different devices may receive the same step_id
  // in the case that a step runs Ops on more than one device. The
  // step_id is used for tracking resource usage of a given step.
  //
  // RunAsync() uses the given "rendezvous", if not null, as the
  // mechanism to communicate inputs and outputs of the underlying
  // graph computation.
  //
  // RunAsync() calls "stats_collector", if not null, to keep track of
  // stats. This allows us to collect statistics and traces on demand.
  //
  // If the executor is used to execute a function, RunAsync() uses
  // "call_frame" to pass arguments and return values between the
  // caller and the callee.
  //
  // RunAsync() uses "cancellation_manager", if not nullptr, to
  // register callbacks that should be called if the graph computation
  // is canceled. Note that the callbacks merely unblock any
  // long-running computation, and a canceled step will terminate by
  // returning/calling the DoneCallback as usual.
  //
  // RunAsync() dispatches closures to "runner". Typically, "runner"
  // is backed by a bounded threadpool.
  struct Args {
    int64 step_id = 0;
    RendezvousInterface* rendezvous = nullptr;
    StepStatsCollectorInterface* stats_collector = nullptr;
    CallFrameInterface* call_frame = nullptr;
    CancellationManager* cancellation_manager = nullptr;
    SessionState* session_state = nullptr;
    // Unique session identifier. Can be empty.
    string session_handle;
    TensorStore* tensor_store = nullptr;
    ScopedStepContainer* step_container = nullptr;
    CollectiveExecutor* collective_executor = nullptr;
    thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr;

    // If true, calls Sync() on the device.
    bool sync_on_finish = false;

    typedef std::function<void()> Closure;
    typedef std::function<void(Closure)> Runner;
    Runner runner = nullptr;

    // If true, all kernels will be treated as "inexpensive", and hence executed
    // on the scheduling thread.
    bool run_all_kernels_inline = false;
  };

  typedef std::function<void(const Status&)> DoneCallback;
  virtual void RunAsync(const Args& args, DoneCallback done) = 0;

  // Synchronous wrapper for RunAsync(): blocks the calling thread until the
  // "done" callback passed to RunAsync() has fired, then returns the status
  // that callback received.
  virtual Status Run(const Args& args) {
    Status ret;
    Notification n;
    RunAsync(args, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};
// Creates an Executor that computes the given "graph".
//
// If successful, returns the constructed executor in "*executor". Otherwise,
// returns an error status.
//
// "params" provides the context the executor runs in. Different callers are
// expected to supply different contexts, and therefore different
// implementations of the pieces the executor depends on.
// NOTE(review): ownership of "*executor" is not stated here; presumably the
// caller owns it and deletes it when done -- confirm against the implementation.
::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params,
const Graph& graph, Executor** executor);
// A class to help run multiple executors in parallel and wait until
// all of them are complete.
//
// ExecutorBarrier deletes itself after the function returned by Get()
// is called.
class ExecutorBarrier {
public:
typedef std::function<void(const Status&)> StatusCallback;
// Create an ExecutorBarrier for 'num' different executors.
//
// 'r' is the shared Rendezvous object that is used to communicate
// state. If any of the executors experiences an error, the
// rendezvous object will be aborted exactly once.
//
// 'done' is called after the last executor completes, and
// ExecutorBarrier is deleted.
ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done)
: rendez_(r), done_cb_(done), pending_(num) {}
~ExecutorBarrier() {}
// Returns a closure that Executors must call when they are done
// computing, passing the status of their execution as an argument.
StatusCallback Get() {
return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1);
}
private:
Rendezvous* rendez_ = nullptr;
StatusCallback done_cb_ = nullptr;
mutable mutex mu_;
int pending_ TF_GUARDED_BY(mu_) = 0;
StatusGroup status_group_ TF_GUARDED_BY(mu_);
void WhenDone(const Status& s) {
Rendezvous* error_rendez = nullptr;
StatusCallback done = nullptr;
Status status;
{
mutex_lock l(mu_);
// If we are the first error encountered, trigger an abort of the
// Rendezvous object by this thread only.
if (status_group_.ok() && !s.ok()) {
error_rendez = rendez_;
error_rendez->Ref();
}
if (!s.ok() && !StatusGroup::IsDerived(s) &&
!status_group_.HasLogMessages()) {
status_group_.AttachLogMessages();
}
status_group_.Update(s);
// If this is the last call to WhenDone, call the final callback
// below.
if (--pending_ == 0) {
CHECK(done_cb_ != nullptr);
std::swap(done, done_cb_);
status = status_group_.as_summary_status();
}
}
if (error_rendez != nullptr) {
error_rendez->StartAbort(
errors::Aborted("Stopping remaining executors."));
error_rendez->Unref();
}
if (done != nullptr) {
delete this;
if (!status.ok()) {
VLOG(1) << "ExecutorBarrier finished with bad status: " << status;
}
done(status);
}
}
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorBarrier);
};
// A few helpers to facilitate creating and deleting kernels.
// Creates a kernel based on "props" on device "device". The kernel can
// access the functions in "flib". The caller takes ownership of the
// returned "*kernel" and should release it with DeleteNonCachedKernel().
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
const std::shared_ptr<const NodeProperties>& props,
int graph_def_version, OpKernel** kernel);
// Deletes a "kernel" returned by CreateNonCachedKernel().
void DeleteNonCachedKernel(OpKernel* kernel);
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_