blob: e5d3c6ae3544b0372d329fba5b322ed19709fd5d [file] [log] [blame]
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/master_session.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/profile_handler.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
// MasterSession wraps ClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
//
// TODO(zhifengc): Cleanup this class. It's becoming messy.
class MasterSession::ReffedClientGraph : public core::RefCounted {
 public:
  // Takes ownership of `client_graph`, which is held only until
  // RegisterPartitions() consumes it. `worker_cache` is not owned. When
  // `should_deregister` is true the destructor deregisters the partitions;
  // otherwise it only releases the partition workers back to `worker_cache`.
  ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
                    std::unique_ptr<ClientGraph> client_graph,
                    const SessionOptions& session_opts,
                    const StatsPublisherFactory& stats_publisher_factory,
                    bool is_partial, WorkerCacheInterface* worker_cache,
                    bool should_deregister)
      : session_handle_(handle),
        bg_opts_(bopts),
        client_graph_before_register_(std::move(client_graph)),
        session_opts_(session_opts),
        is_partial_(is_partial),
        callable_opts_(bopts.callable_options),
        worker_cache_(worker_cache),
        should_deregister_(should_deregister),
        collective_graph_key_(
            client_graph_before_register_->collective_graph_key) {
    VLOG(1) << "Created ReffedClientGraph for node with "
            << client_graph_before_register_->graph.num_node_ids();
    stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
    // Initialize a name to node map for processing device stats. The detail
    // text is the parenthesized, comma-separated list of requested inputs,
    // e.g. "(x, y)", which DetailText() appends after the op type.
    for (Node* n : client_graph_before_register_->graph.nodes()) {
      name_to_node_details_.emplace(
          n->name(),
          NodeDetails(n->type_string(),
                      strings::StrCat(
                          "(", absl::StrJoin(n->requested_inputs(), ", "),
                          ")")));
    }
  }

  // Deregisters the partitions when `should_deregister_` is set; otherwise
  // hands each partition's worker back to the worker cache.
  ~ReffedClientGraph() override {
    if (should_deregister_) {
      DeregisterPartitions();
    } else {
      for (Part& part : partitions_) {
        worker_cache_->ReleaseWorker(part.name, part.worker);
      }
    }
  }

  const CallableOptions& callable_options() const { return callable_opts_; }

  const BuildGraphOptions& build_graph_options() const { return bg_opts_; }

  int64 collective_graph_key() const { return collective_graph_key_; }

  // Returns the stats publisher's profile handler for this step.
  std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
                                                    int64 execution_count,
                                                    const RunOptions& ropts) {
    return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
  }

  // Atomically returns the current execution count and then increments it.
  int64 get_and_increment_execution_count() {
    return execution_count_.fetch_add(1);
  }

  // Turn RPC logging on or off, both at the WorkerCache used by this
  // master process, and at each remote worker in use for the current
  // partitions.
  void SetRPCLogging(bool active) {
    worker_cache_->SetLogging(active);
    // Logging is a best-effort activity, so we make async calls to turn
    // it on/off and don't make use of the responses.
    for (auto& p : partitions_) {
      LoggingRequest* req = new LoggingRequest;
      if (active) {
        req->set_enable_rpc_logging(true);
      } else {
        req->set_disable_rpc_logging(true);
      }
      LoggingResponse* resp = new LoggingResponse;
      Ref();
      p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) {
        delete req;
        delete resp;
        // ReffedClientGraph owns p.worker so we need to hold a ref to
        // ensure that the method doesn't attempt to access p.worker after
        // ReffedClient graph has deleted it.
        // TODO(suharshs): Simplify this ownership model.
        Unref();
      });
    }
  }

  // Retrieve all RPC logs data accumulated for the current step, both
  // from the local WorkerCache in use by this master process and from
  // all the remote workers executing the remote partitions.
  void RetrieveLogs(int64 step_id, StepStats* ss) {
    // Get the local data first, because it sets *ss without merging.
    worker_cache_->RetrieveLogs(step_id, ss);

    // Then merge in data from all the remote workers. `req` lives on this
    // stack frame, which is safe because we block until every callback ran.
    LoggingRequest req;
    req.add_fetch_step_id(step_id);
    int waiting_for = partitions_.size();
    if (waiting_for > 0) {
      mutex scoped_mu;
      BlockingCounter all_done(waiting_for);
      for (auto& p : partitions_) {
        LoggingResponse* resp = new LoggingResponse;
        p.worker->LoggingAsync(
            &req, resp,
            [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
              {
                mutex_lock l(scoped_mu);
                if (s.ok()) {
                  for (auto& lss : resp->step()) {
                    if (step_id != lss.step_id()) {
                      LOG(ERROR) << "Wrong step_id in LoggingResponse";
                      continue;
                    }
                    ss->MergeFrom(lss.step_stats());
                  }
                }
                delete resp;
              }
              // Must not decrement all_done until out of critical section where
              // *ss is updated.
              all_done.DecrementCount();
            });
      }
      all_done.Wait();
    }
  }

  // Local execution methods.

  // Partitions the graph into subgraphs and registers them on
  // workers.
  Status RegisterPartitions(PartitionOptions popts);

  // Runs one step of all partitions.
  Status RunPartitions(const MasterEnv* env, int64 step_id,
                       int64 execution_count, PerStepState* pss,
                       CallOptions* opts, const RunStepRequestWrapper& req,
                       MutableRunStepResponseWrapper* resp,
                       CancellationManager* cm, const bool is_last_partial_run);
  Status RunPartitions(const MasterEnv* env, int64 step_id,
                       int64 execution_count, PerStepState* pss,
                       CallOptions* call_opts, const RunCallableRequest& req,
                       RunCallableResponse* resp, CancellationManager* cm);

  // Calls workers to cleanup states for the step "step_id". Calls
  // `done` when all cleanup RPCs have completed.
  void CleanupPartitionsAsync(int64 step_id, StatusCallback done);

  // Post-processing of any runtime statistics gathered during execution.
  void ProcessStats(int64 step_id, PerStepState* pss, ProfileHandler* ph,
                    const RunOptions& options, RunMetadata* resp);
  void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds,
                          bool is_rpc);
  // Checks that the requested fetches can be computed from the provided feeds.
  Status CheckFetches(const RunStepRequestWrapper& req,
                      const RunState* run_state,
                      GraphExecutionState* execution_state);

 private:
  const string session_handle_;
  const BuildGraphOptions bg_opts_;

  // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
  std::unique_ptr<ClientGraph> client_graph_before_register_ GUARDED_BY(mu_);
  const SessionOptions session_opts_;
  const bool is_partial_;
  const CallableOptions callable_opts_;
  WorkerCacheInterface* const worker_cache_;  // Not owned.

  // Per-node data retained after the ClientGraph is released, used to render
  // device stats.
  struct NodeDetails {
    explicit NodeDetails(string type_string, string detail_text)
        : type_string(std::move(type_string)),
          detail_text(std::move(detail_text)) {}
    const string type_string;
    const string detail_text;
  };
  std::unordered_map<string, NodeDetails> name_to_node_details_;

  const bool should_deregister_;
  const int64 collective_graph_key_;
  std::atomic<int64> execution_count_ = {0};

  // Graph partitioned into per-location subgraphs.
  struct Part {
    // Worker name.
    string name;

    // Maps feed names to rendezvous keys. Empty most of the time.
    std::unordered_map<string, string> feed_key;

    // Maps rendezvous keys to fetch names. Empty most of the time.
    std::unordered_map<string, string> key_fetch;

    // The interface to the worker. Owned.
    WorkerInterface* worker = nullptr;

    // After registration with the worker, graph_handle identifies
    // this partition on the worker.
    string graph_handle;

    Part() : feed_key(3), key_fetch(3) {}
  };

  // partitions_ is immutable after RegisterPartitions() call
  // finishes. RunPartitions() can access partitions_ safely without
  // acquiring locks.
  std::vector<Part> partitions_;

  mutable mutex mu_;

  // Partition initialization and registration only needs to happen
  // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
  // indicates the initialization is ongoing.
  Notification init_done_;

  // init_result_ remembers the initialization error if any.
  Status init_result_ GUARDED_BY(mu_);

  std::unique_ptr<StatsPublisherInterface> stats_publisher_;

  // Renders "[<size>MB] <node> = <OpType>(<inputs>)" for a node's stats; the
  // size prefix is only included when outputs request at least ~0.1MB.
  string DetailText(const NodeDetails& details, const NodeExecStats& stats) {
    int64 tot = 0;
    for (auto& no : stats.output()) {
      tot += no.tensor_description().allocation_description().requested_bytes();
    }
    string bytes;
    if (tot >= 0.1 * 1048576.0) {
      bytes = strings::Printf("[%.1fMB] ", tot / 1048576.0);
    }
    return strings::StrCat(bytes, stats.node_name(), " = ", details.type_string,
                           details.detail_text);
  }

  // Send/Recv nodes that are the result of client-added
  // feeds and fetches must be tracked so that the tensors
  // can be added to the local rendezvous.
  static void TrackFeedsAndFetches(Part* part, const GraphDef& graph_def,
                                   const PartitionOptions& popts);

  // The actual graph partitioning and registration implementation.
  Status DoBuildPartitions(
      PartitionOptions popts, ClientGraph* client_graph,
      std::unordered_map<string, GraphDef>* out_partitions);
  Status DoRegisterPartitions(
      const PartitionOptions& popts,
      std::unordered_map<string, GraphDef> graph_partitions);

  // Prepares a number of calls to workers. One call per partition.
  // This is a generic method that handles Run, PartialRun, and RunCallable.
  template <class FetchListType, class ClientRequestType,
            class ClientResponseType>
  Status RunPartitionsHelper(
      const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
      const FetchListType& fetches, const MasterEnv* env, int64 step_id,
      int64 execution_count, PerStepState* pss, CallOptions* call_opts,
      const ClientRequestType& req, ClientResponseType* resp,
      CancellationManager* cm, bool is_last_partial_run);

  // Deregisters the partitions on the workers. Called in the
  // destructor and does not wait for the rpc completion.
  void DeregisterPartitions();

  TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
};
// Partitions and registers the client graph exactly once. The first caller
// claims `client_graph_before_register_`, performs the work, and records the
// result in `init_result_` before notifying `init_done_`. Every later (or
// concurrent) caller waits for that notification and returns the same status.
Status MasterSession::ReffedClientGraph::RegisterPartitions(
    PartitionOptions popts) {
  // The `ClientGraph` is no longer needed after partitions are registered.
  // Since it can account for a large amount of memory, we consume it here,
  // and it will be freed after concluding with registration.
  std::unique_ptr<ClientGraph> client_graph;
  {
    mutex_lock l(mu_);
    std::swap(client_graph_before_register_, client_graph);
  }

  if (client_graph == nullptr) {
    // Registration was claimed by another caller; wait for its outcome.
    init_done_.WaitForNotification();
    mutex_lock l(mu_);
    return init_result_;
  }

  std::unordered_map<string, GraphDef> graph_defs;
  popts.flib_def = client_graph->flib_def.get();
  Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
  if (s.ok()) {
    // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
    // valid after the call to DoRegisterPartitions begins, so
    // `stats_publisher_` must make a copy if it wants to retain the
    // GraphDef objects.
    std::vector<const GraphDef*> graph_defs_for_publishing;
    // `partitions_` is still empty here; size the buffer from `graph_defs`.
    graph_defs_for_publishing.reserve(graph_defs.size());
    for (const auto& name_def : graph_defs) {
      graph_defs_for_publishing.push_back(&name_def.second);
    }
    stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
    s = DoRegisterPartitions(popts, std::move(graph_defs));
  }

  mutex_lock l(mu_);
  init_result_ = s;
  init_done_.Notify();
  return init_result_;
}
// Returns the task component of the device assigned to `node`, as produced by
// DeviceNameUtils::SplitDeviceName. CHECK-fails if the name cannot be split.
static string SplitByWorker(const Node* node) {
  const auto& assigned = node->assigned_device_name();
  string task_part;
  string device_part;
  CHECK(DeviceNameUtils::SplitDeviceName(assigned, &task_part, &device_part))
      << "node: " << node->name() << " dev: " << assigned;
  return task_part;
}
// Scans `graph_def` for client-terminated _Recv/_Send nodes and records their
// rendezvous keys in `part`. Each fed tensor name maps to its _Recv key in
// `feed_key`, and each _Send key maps to its fetched tensor name in
// `key_fetch`. `popts` is not consulted here.
void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
    Part* part, const GraphDef& graph_def, const PartitionOptions& popts) {
  for (const NodeDef& ndef : graph_def.node()) {
    const bool is_recv = ndef.op() == "_Recv";
    if (!is_recv && ndef.op() != "_Send") continue;

    // Only send/recv nodes that were added as feeds and fetches
    // (client-terminated) should be tracked. Other send/recv nodes
    // are for transferring data between partitions / memory spaces.
    bool client_terminated;
    TF_CHECK_OK(GetNodeAttr(ndef, "client_terminated", &client_terminated));
    if (!client_terminated) continue;

    string tensor_name;
    TF_CHECK_OK(GetNodeAttr(ndef, "tensor_name", &tensor_name));
    string send_device;
    TF_CHECK_OK(GetNodeAttr(ndef, "send_device", &send_device));
    string recv_device;
    TF_CHECK_OK(GetNodeAttr(ndef, "recv_device", &recv_device));
    // The incarnation attribute is read as a signed int64 but is used as an
    // unsigned value when building the rendezvous key.
    int64 incarnation_attr;
    TF_CHECK_OK(
        GetNodeAttr(ndef, "send_device_incarnation", &incarnation_attr));
    const uint64 send_device_incarnation =
        static_cast<uint64>(incarnation_attr);

    const string rendezvous_key =
        Rendezvous::CreateKey(send_device, send_device_incarnation,
                              recv_device, tensor_name, FrameAndIter(0, 0));
    if (is_recv) {
      part->feed_key.insert({tensor_name, rendezvous_key});
    } else {
      part->key_fetch.insert({rendezvous_key, tensor_name});
    }
  }
}
// Splits `client_graph`'s graph into per-location GraphDefs written to
// `out_partitions`. When start times must be recorded, they are first
// estimated with an ASAP slack analysis over a fresh cost model.
Status MasterSession::ReffedClientGraph::DoBuildPartitions(
    PartitionOptions popts, ClientGraph* client_graph,
    std::unordered_map<string, GraphDef>* out_partitions) {
  Graph* const graph = &client_graph->graph;
  if (popts.need_to_record_start_times) {
    CostModel cost_model(true);
    cost_model.InitFromGraph(*graph);
    // TODO(yuanbyu): Use the real cost model.
    // execution_state_->MergeFromGlobal(&cost_model);
    SlackAnalysis slack(graph, &cost_model);
    slack.ComputeAsap(&popts.start_times);
  }
  // Partition the graph.
  return Partition(popts, graph, out_partitions);
}
// Creates one Part per entry of `graph_partitions` and acquires its worker
// from `worker_cache_`. If any worker cannot be created, every worker acquired
// so far is released and NotFound is returned without registering anything.
// Otherwise each partition graph is sent with RegisterGraphAsync, and the
// function waits for all responses. The returned status aggregates the
// per-partition registration statuses.
Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status status;
  for (auto& name_and_graph : graph_partitions) {
    partitions_.emplace_back();
    Part& new_part = partitions_.back();
    new_part.name = name_and_graph.first;
    TrackFeedsAndFetches(&new_part, name_and_graph.second, popts);
    new_part.worker = worker_cache_->GetOrCreateWorker(new_part.name);
    if (new_part.worker == nullptr) {
      status = errors::NotFound("worker ", new_part.name);
      break;
    }
  }
  if (!status.ok()) {
    // Hand back every worker acquired above before failing.
    for (Part& acquired : partitions_) {
      worker_cache_->ReleaseWorker(acquired.name, acquired.worker);
      acquired.worker = nullptr;
    }
    return status;
  }

  struct RegisterCall {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num_parts = partitions_.size();
  gtl::InlinedVector<RegisterCall, 4> calls(num_parts);
  BlockingCounter pending(num_parts);
  for (int idx = 0; idx < num_parts; ++idx) {
    const Part& part = partitions_[idx];
    RegisterCall* call = &calls[idx];
    RegisterGraphRequest* req = &call->req;
    req->set_session_handle(session_handle_);
    req->set_create_worker_session_called(!should_deregister_);
    // Swap the partition's GraphDef into the request rather than copying it.
    req->mutable_graph_def()->Swap(&graph_partitions[part.name]);
    *req->mutable_config_proto() = session_opts_.config;
    *req->mutable_graph_options() = session_opts_.config.graph_options();
    *req->mutable_debug_options() =
        callable_opts_.run_options().debug_options();
    req->set_collective_graph_key(collective_graph_key_);
    VLOG(2) << "Register " << req->graph_def().DebugString();
    part.worker->RegisterGraphAsync(req, &call->resp,
                                    [call, &pending](const Status& s) {
                                      call->status = s;
                                      pending.DecrementCount();
                                    });
  }
  pending.Wait();

  for (int idx = 0; idx < num_parts; ++idx) {
    status.Update(calls[idx].status);
    partitions_[idx].graph_handle = calls[idx].resp.graph_handle();
  }
  return status;
}
// Helper class to manage "num" parallel RunGraph calls.
class RunManyGraphs {
public:
explicit RunManyGraphs(int num) : calls_(num), pending_(num) {}
~RunManyGraphs() {}
// Returns the index-th call.
struct Call {
CallOptions opts;
std::unique_ptr<MutableRunGraphRequestWrapper> req;
std::unique_ptr<MutableRunGraphResponseWrapper> resp;
};
Call* get(int index) { return &calls_[index]; }
// When the index-th call is done, updates the overall status.
void WhenDone(int index, const std::string& worker_name, const Status& s) {
TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
auto resp = get(index)->resp.get();
if (resp->status_code() != error::Code::OK) {
// resp->status_code will only be non-OK if s.ok().
mutex_lock l(mu_);
ReportBadStatus(Status(resp->status_code(),
strings::StrCat("From ", worker_name, ":\n",
resp->status_error_message())));
} else if (!s.ok()) {
mutex_lock l(mu_);
ReportBadStatus(Status(
s.code(),
strings::StrCat("From ", worker_name, ":\n", s.error_message())));
}
pending_.DecrementCount();
}
void StartCancel() {
mutex_lock l(mu_);
ReportBadStatus(errors::Cancelled("RunManyGraphs"));
}
void Wait() { pending_.Wait(); }
Status status() const {
mutex_lock l(mu_);
// Concat status objects in this StatusGroup to get the aggregated status,
// as each status in status_group_ is already summarized status.
return status_group_.as_concatenated_status();
}
private:
gtl::InlinedVector<Call, 4> calls_;
BlockingCounter pending_;
mutable mutex mu_;
StatusGroup status_group_ GUARDED_BY(mu_);
bool cancel_issued_ GUARDED_BY(mu_) = false;
void ReportBadStatus(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(1) << "Master received error status " << s;
if (!cancel_issued_ && !StatusGroup::IsDerived(s)) {
// Only start cancelling other workers upon receiveing a non-derived
// error
cancel_issued_ = true;
VLOG(1) << "Master received error report. Cancelling remaining workers.";
for (Call& call : calls_) {
call.opts.StartCancel();
}
}
status_group_.Update(s);
}
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};
namespace {

// RunStep flavor: forwards feed number `index` of `client_req` to
// `worker_req->AddSendFromRunStepRequest` under rendezvous key `send_key`.
Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
                                MutableRunGraphRequestWrapper* worker_req,
                                size_t index, const string& send_key) {
  MutableRunGraphRequestWrapper* const target = worker_req;
  return target->AddSendFromRunStepRequest(client_req, index, send_key);
}

// RunCallable flavor: forwards feed number `index` of `client_req` to
// `worker_req->AddSendFromRunCallableRequest` under rendezvous key
// `send_key`.
Status AddSendFromClientRequest(const RunCallableRequest& client_req,
                                MutableRunGraphRequestWrapper* worker_req,
                                size_t index, const string& send_key) {
  MutableRunGraphRequestWrapper* const target = worker_req;
  return target->AddSendFromRunCallableRequest(client_req, index, send_key);
}

// Response adapter used when running a callable: fetched tensors are
// collected by tensor name in `fetch_key_to_protos` instead of being written
// directly into the RunCallableResponse.
// TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
// in-process messages.
struct RunCallableResponseWrapper {
  RunCallableResponse* resp;  // Not owned.
  std::unordered_map<string, TensorProto> fetch_key_to_protos;

  RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }

  // Receives value `index` of `worker_resp` into the proto stored under
  // `tensor_name` (default-constructing that entry if it is missing).
  Status AddTensorFromRunGraphResponse(
      const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
      size_t index) {
    // TODO(b/74355905): Add a specialized implementation that avoids
    // copying the tensor into the RunCallableResponse when at least
    // two of the {client, master, worker} are in the same process.
    TensorProto* slot = &fetch_key_to_protos[tensor_name];
    return worker_resp->RecvValue(index, slot);
  }
};

}  // namespace
// Shared implementation of RunPartitions() for RunStep, PartialRunStep and
// RunCallable. It works in four steps:
//   1. Build one RunGraph request per partition, including feeds, recv keys
//      and executor options.
//   2. Issue all the requests asynchronously.
//   3. Wait for them, honoring both client and CancellationManager
//      cancellation.
//   4. Copy fetched tensors and any requested metadata (step stats, cost
//      graph, partition graphs) into `resp`.
//
// `feeds` maps each fed tensor name to its index in `req`. `fetches` is only
// consulted for partial runs. `env` and `execution_count` are not used here.
template <class FetchListType, class ClientRequestType,
          class ClientResponseType>
Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
    const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
    const FetchListType& fetches, const MasterEnv* env, int64 step_id,
    int64 execution_count, PerStepState* pss, CallOptions* call_opts,
    const ClientRequestType& req, ClientResponseType* resp,
    CancellationManager* cm, bool is_last_partial_run) {
  // Translate this step's collection flags into executor options.
  ExecutorOpts exec_opts;
  if (pss->report_tensor_allocations_upon_oom) {
    exec_opts.set_report_tensor_allocations_upon_oom(true);
  }
  if (pss->collect_costs) {
    exec_opts.set_record_costs(true);
  }
  if (pss->collect_timeline) {
    exec_opts.set_record_timeline(true);
  }
  if (pss->collect_rpcs) {
    SetRPCLogging(true);
  }
  if (pss->collect_partition_graphs) {
    exec_opts.set_record_partition_graphs(true);
  }
  if (pss->collect_costs || pss->collect_timeline) {
    pss->step_stats.resize(partitions_.size());
  }

  const int num = partitions_.size();
  RunManyGraphs calls(num);

  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* c = calls.get(i);
    c->req.reset(part.worker->CreateRunGraphRequest());
    c->resp.reset(part.worker->CreateRunGraphResponse());
    if (is_partial_) {
      c->req->set_is_partial(is_partial_);
      c->req->set_is_last_partial_run(is_last_partial_run);
    }
    c->req->set_session_handle(session_handle_);
    c->req->set_create_worker_session_called(!should_deregister_);
    c->req->set_graph_handle(part.graph_handle);
    c->req->set_step_id(step_id);
    *c->req->mutable_exec_opts() = exec_opts;
    c->req->set_store_errors_in_response_body(true);
    c->req->set_request_id(GetUniqueRequestId());
    // If any feeds are provided, send the feed values together
    // in the RunGraph request.
    // In the partial case, we only want to include feeds provided in the req.
    // In the non-partial case, all feeds in the request are in the part.
    // We keep these as separate paths for now, to ensure we aren't
    // inadvertently slowing down the normal run path.
    if (is_partial_) {
      for (const auto& name_index : feeds) {
        const auto iter = part.feed_key.find(string(name_index.first));
        if (iter == part.feed_key.end()) {
          // The provided feed must be for a different partition.
          continue;
        }
        const string& key = iter->second;
        TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
                                                    name_index.second, key));
      }
      // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
      // For now, we just iterate through partitions to find the matching key.
      for (const string& req_fetch : fetches) {
        for (const auto& key_fetch : part.key_fetch) {
          if (key_fetch.second == req_fetch) {
            c->req->add_recv_key(key_fetch.first);
            break;
          }
        }
      }
    } else {
      for (const auto& feed_key : part.feed_key) {
        const string& feed = feed_key.first;
        const string& key = feed_key.second;
        auto iter = feeds.find(feed);
        if (iter == feeds.end()) {
          return errors::Internal("No feed index found for feed: ", feed);
        }
        const int64 feed_index = iter->second;
        TF_RETURN_IF_ERROR(
            AddSendFromClientRequest(req, c->req.get(), feed_index, key));
      }
      for (const auto& key_fetch : part.key_fetch) {
        const string& key = key_fetch.first;
        c->req->add_recv_key(key);
      }
    }
  }

  // Issues RunGraph calls.
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* call = calls.get(i);
    TRACEPRINTF("Partition %d %s", i, part.name.c_str());
    part.worker->RunGraphAsync(&call->opts, call->req.get(), call->resp.get(),
                               std::bind(&RunManyGraphs::WhenDone, &calls, i,
                                         part.name, std::placeholders::_1));
  }

  // Waits for the RunGraph calls. Both client-side cancellation and the
  // CancellationManager fan out to RunManyGraphs::StartCancel().
  call_opts->SetCancelCallback([&calls]() {
    LOG(INFO) << "Client requested cancellation for RunStep, cancelling "
                 "worker operations.";
    calls.StartCancel();
  });
  auto token = cm->get_cancellation_token();
  const bool success =
      cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
  if (!success) {
    // `cm` is already cancelled: cancel the calls just issued.
    calls.StartCancel();
  }
  calls.Wait();
  call_opts->ClearCancelCallback();
  if (success) {
    cm->DeregisterCallback(token);
  } else {
    return errors::Cancelled("Step was cancelled");
  }
  TF_RETURN_IF_ERROR(calls.status());

  // Collects fetches and metadata.
  Status status;
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
    for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
      auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
      if (iter == part.key_fetch.end()) {
        status.Update(errors::Internal("Unexpected fetch key: ",
                                       run_graph_resp->recv_key(j)));
        break;
      }
      const string& fetch = iter->second;
      status.Update(
          resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
      if (!status.ok()) {
        break;
      }
    }
    if (pss->collect_timeline) {
      pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
    }
    if (pss->collect_costs) {
      CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
      for (int j = 0; j < cost_graph->node_size(); ++j) {
        resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
            cost_graph->mutable_node(j));
      }
    }
    if (pss->collect_partition_graphs) {
      protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
          resp->mutable_metadata()->mutable_partition_graphs();
      // Use a distinct index so the outer partition index `i` is not shadowed.
      for (size_t k = 0; k < run_graph_resp->num_partition_graphs(); ++k) {
        partition_graph_defs->Add()->Swap(
            run_graph_resp->mutable_partition_graph(k));
      }
    }
  }
  return status;
}
// RunStep / PartialRunStep entry point. Indexes the request's feeds by name
// (rejecting duplicates), collects the fetch names, and delegates to
// RunPartitionsHelper().
Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64 step_id, int64 execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp, CancellationManager* cm,
    const bool is_last_partial_run) {
  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
          << execution_count;
  // Fed tensor name -> position of that feed within `req`.
  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
  for (size_t feed_idx = 0; feed_idx < req.num_feeds(); ++feed_idx) {
    const bool is_new =
        feeds.emplace(req.feed_name(feed_idx), feed_idx).second;
    if (!is_new) {
      return errors::InvalidArgument("Duplicated feeds: ",
                                     req.feed_name(feed_idx));
    }
  }

  std::vector<string> fetch_names;
  fetch_names.reserve(req.num_fetches());
  for (size_t fetch_idx = 0; fetch_idx < req.num_fetches(); ++fetch_idx) {
    fetch_names.emplace_back(req.fetch_name(fetch_idx));
  }

  return RunPartitionsHelper(feeds, fetch_names, env, step_id,
                             execution_count, pss, call_opts, req, resp, cm,
                             is_last_partial_run);
}
// RunCallable entry point. Indexes the callable's feeds by name, runs all
// partitions through RunPartitionsHelper(), then writes one fetched
// TensorProto per entry of `callable_opts_.fetch()` into `resp`, in that
// order. Returns Internal if a fetch has no value; in that case no
// placeholder entry is appended for the missing fetch.
Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64 step_id, int64 execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
    RunCallableResponse* resp, CancellationManager* cm) {
  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
          << execution_count;
  // Maps the names of fed tensors to their index in `req`.
  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
  // `feed_size()` is an int; iterate with a matching signed index.
  for (int i = 0; i < callable_opts_.feed_size(); ++i) {
    if (!feeds.insert({callable_opts_.feed(i), static_cast<size_t>(i)})
             .second) {
      // MakeCallable will fail if there are two feeds with the same name.
      return errors::Internal("Duplicated feeds in callable: ",
                              callable_opts_.feed(i));
    }
  }

  // Create a wrapped response object to collect the fetched values and
  // rearrange them for the RunCallableResponse.
  RunCallableResponseWrapper wrapped_resp;
  wrapped_resp.resp = resp;

  TF_RETURN_IF_ERROR(RunPartitionsHelper(
      feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
      call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));

  // Collects fetches.
  // TODO(b/74355905): Add a specialized implementation that avoids
  // copying the tensor into the RunCallableResponse when at least
  // two of the {client, master, worker} are in the same process.
  for (const string& fetch : callable_opts_.fetch()) {
    auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
    if (iter == wrapped_resp.fetch_key_to_protos.end()) {
      return errors::Internal("Worker did not return a value for fetch: ",
                              fetch);
    }
    // Only append once a value is known to exist.
    resp->mutable_fetch()->Add()->Swap(&iter->second);
  }

  return Status::OK();
}
namespace {
class CleanupBroadcastHelper {
public:
CleanupBroadcastHelper(int64 step_id, int num_calls, StatusCallback done)
: resps_(num_calls), num_pending_(num_calls), done_(std::move(done)) {
req_.set_step_id(step_id);
}
// Returns a non-owned pointer to a request buffer for all calls.
CleanupGraphRequest* request() { return &req_; }
// Returns a non-owned pointer to a response buffer for the ith call.
CleanupGraphResponse* response(int i) { return &resps_[i]; }
// Called when the ith response is received.
void call_done(int i, const Status& s) {
bool run_callback = false;
Status status_copy;
{
mutex_lock l(mu_);
status_.Update(s);
if (--num_pending_ == 0) {
run_callback = true;
status_copy = status_;
}
}
if (run_callback) {
done_(status_copy);
// This is the last call, so delete the helper object.
delete this;
}
}
private:
// A single request shared between all workers.
CleanupGraphRequest req_;
// One response buffer for each worker.
gtl::InlinedVector<CleanupGraphResponse, 4> resps_;
mutex mu_;
// Number of requests remaining to be collected.
int num_pending_ GUARDED_BY(mu_);
// Aggregate status of the operation.
Status status_ GUARDED_BY(mu_);
// Callback to be called when all operations complete.
StatusCallback done_;
TF_DISALLOW_COPY_AND_ASSIGN(CleanupBroadcastHelper);
};
} // namespace
// Asks every partition's worker to clean up state for `step_id` and invokes
// `done` with the aggregated status once all cleanup RPCs have completed.
// With no partitions, `done` is invoked immediately with OK.
void MasterSession::ReffedClientGraph::CleanupPartitionsAsync(
    int64 step_id, StatusCallback done) {
  const int num = partitions_.size();
  if (num == 0) {
    // No RPC would ever complete, so CleanupBroadcastHelper would neither run
    // `done` nor delete itself. Finish synchronously instead.
    done(Status::OK());
    return;
  }
  // Helper object will be deleted when the final call completes.
  CleanupBroadcastHelper* helper =
      new CleanupBroadcastHelper(step_id, num, std::move(done));
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    part.worker->CleanupGraphAsync(
        helper->request(), helper->response(i),
        [helper, i](const Status& s) { helper->call_done(i, s); });
  }
}
// Post-processes the statistics gathered for `step_id`. It does nothing
// unless costs or a timeline were collected. Otherwise it:
//   - feeds device and RPC stats to `ph` when a profile handler is given;
//   - when a timeline was collected, merges RPC and per-partition stats into
//     one StepStats, which goes to `resp` for FULL_TRACE runs and to the
//     stats publisher for all other runs.
void MasterSession::ReffedClientGraph::ProcessStats(int64 step_id,
                                                    PerStepState* pss,
                                                    ProfileHandler* ph,
                                                    const RunOptions& options,
                                                    RunMetadata* resp) {
  const bool has_stats = pss->collect_costs || pss->collect_timeline;
  if (!has_stats) return;

  // Out-of-band logging data is collected now, during post-processing.
  if (pss->collect_timeline) {
    SetRPCLogging(false);
    RetrieveLogs(step_id, &pss->rpc_stats);
  }

  if (ph != nullptr) {
    // Per-partition device stats first, then the RPC stats.
    for (size_t part_idx = 0; part_idx < partitions_.size(); ++part_idx) {
      for (const auto& ds : pss->step_stats[part_idx].dev_stats()) {
        ProcessDeviceStats(ph, ds, false /*is_rpc*/);
      }
    }
    for (const auto& ds : pss->rpc_stats.dev_stats()) {
      ProcessDeviceStats(ph, ds, true /*is_rpc*/);
    }
    ph->StepDone(pss->start_micros, pss->end_micros,
                 Microseconds(0) /*cleanup_time*/, 0 /*total_runops*/,
                 Status::OK());
  }

  if (!pss->collect_timeline) return;

  // Assemble all stats for this timeline into a merged StepStats.
  StepStats merged_stats;
  merged_stats.Swap(&pss->rpc_stats);
  for (size_t part_idx = 0; part_idx < partitions_.size(); ++part_idx) {
    merged_stats.MergeFrom(pss->step_stats[part_idx]);
    pss->step_stats[part_idx].Clear();
  }
  pss->step_stats.clear();

  if (options.trace_level() == RunOptions::FULL_TRACE) {
    // On-demand profiling: hand the stats back to the caller. They can be
    // fetched through the Session API, so they are not published as well.
    resp->mutable_step_stats()->Swap(&merged_stats);
  } else {
    // Automatic profiling: publish instead of copying back, to avoid slowing
    // down the calls that trigger it.
    stats_publisher_->PublishStatsProto(merged_stats);
  }
}
// Records every NodeExecStats of device `ds` with the profile handler `ph`.
// - RPC stats are recorded as copies, using only what the stats contain.
// - Other stats are matched against `name_to_node_details_` for the op type
//   and detail text.
// - Entries with neither a known node nor a timeline label are skipped. At
//   most ten such misses are logged.
void MasterSession::ReffedClientGraph::ProcessDeviceStats(
    ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc) {
  const string& dev_name = ds.device();
  VLOG(1) << "Device " << dev_name << " reports stats for "
          << ds.node_stats_size() << " nodes";
  for (const auto& ns : ds.node_stats()) {
    if (is_rpc) {
      // We don't have access to a good Node pointer, so we rely on
      // sufficient data being present in the NodeExecStats.
      ph->RecordOneOp(dev_name, ns, true /*is_copy*/, "", ns.node_name(),
                      ns.timeline_label());
      continue;
    }

    const auto details_it = name_to_node_details_.find(ns.node_name());
    const bool known_node = details_it != name_to_node_details_.end();
    const bool has_label = !ns.timeline_label().empty();
    if (!known_node && !has_label) {
      // The counter incrementing is not thread-safe. But we don't really
      // care.
      // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N for
      // more general usage.
      static int log_counter = 0;
      if (log_counter < 10) {
        log_counter++;
        LOG(WARNING) << "Failed to find node " << ns.node_name()
                     << " for dev " << dev_name;
      }
      continue;
    }

    const string& optype =
        known_node ? details_it->second.type_string : ns.node_name();
    // After the check above, an empty label guarantees the node is known.
    const string details = has_label ? ns.timeline_label()
                                     : DetailText(details_it->second, ns);
    ph->RecordOneOp(dev_name, ns, false /*is_copy*/, ns.node_name(), optype,
                    details);
  }
}
// TODO(suharshs): Merge with CheckFetches in DirectSession.
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// everytime.
//
// Verifies that every fetch requested by the partial-run call `req` can be
// computed. A fetch fails if it transitively depends, through the full graph
// of `execution_state`, on a feed of `run_state` that is still pending after
// this call's own feeds are taken into account.
//
// Returns:
//   - NotFound if a pending feed or a fetch names a node that is not in the
//     graph.
//   - InvalidArgument if a fetch depends on a feed that has not been fed yet.
//   - OK otherwise.
Status MasterSession::ReffedClientGraph::CheckFetches(
    const RunStepRequestWrapper& req, const RunState* run_state,
    GraphExecutionState* execution_state) {
  // Build the set of pending feeds that we haven't seen.
  std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
  for (const auto& input : run_state->pending_inputs) {
    // Skip if already fed.
    if (input.second) continue;
    TensorId id(ParseTensorName(input.first));
    const Node* n = execution_state->get_node_by_name(string(id.first));
    if (n == nullptr) {
      return errors::NotFound("Feed ", input.first, ": not found");
    }
    pending_feeds.insert(id);
  }
  // Feeds supplied by this very call are no longer pending.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const TensorId id(ParseTensorName(req.feed_name(i)));
    pending_feeds.erase(id);
  }
  // Initialize the stack with the fetch nodes.
  std::vector<const Node*> stack;
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    const TensorId id(ParseTensorName(fetch));
    const Node* n = execution_state->get_node_by_name(string(id.first));
    if (n == nullptr) {
      return errors::NotFound("Fetch ", fetch, ": not found");
    }
    stack.push_back(n);
  }
  // Any tensor needed for fetches can't be in pending_feeds.
  // We need to use the original full graph from execution state.
  // Depth-first walk over the in-edges of the fetch nodes. `visited` ensures
  // each node reached through an edge is pushed at most once.
  const Graph* graph = execution_state->full_graph();
  std::vector<bool> visited(graph->num_node_ids(), false);
  while (!stack.empty()) {
    const Node* n = stack.back();
    stack.pop_back();
    for (const Edge* in_edge : n->in_edges()) {
      const Node* in_node = in_edge->src();
      // The edge consumes output `src_output` of `in_node`; that tensor must
      // not be a still-pending feed.
      if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
        return errors::InvalidArgument("Fetch ", in_node->name(), ":",
                                       in_edge->src_output(),
                                       " can't be computed from the feeds"
                                       " that have been fed so far.");
      }
      if (!visited[in_node->id()]) {
        visited[in_node->id()] = true;
        stack.push_back(in_node);
      }
    }
  }
  return Status::OK();
}
// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
//
// For every partition with a non-empty graph handle, this issues a
// DeregisterGraph request to the partition's worker. Each request/response
// pair is heap-allocated and deleted by the completion callback. The callback
// also hands the partition's worker back to the worker cache
// (ReleaseWorker). Errors are only logged, at INFO level.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {
    // The graph handle may be empty if we failed during partition registration.
    if (!part.graph_handle.empty()) {
      // Owned by the callback below, which deletes it.
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_create_worker_session_called(!should_deregister_);
      c->req.set_graph_handle(part.graph_handle);
      // NOTE(mrry): We must capture `worker_cache_` since `this`
      // could be deleted before the callback is called.
      // For the same reason, the worker name is copied by value.
      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
        if (!s.ok()) {
          // This error is potentially benign, so we don't log at the
          // error level.
          LOG(INFO) << "DeregisterGraph error: " << s;
        }
        delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}
namespace {

// Appends the `size` strings `input_accessor(0)` .. `input_accessor(size - 1)`
// to `output`, then sorts the whole of `output` lexicographically.
//
// Sorting gives a canonical order, so requests that list the same names in a
// different order produce identical options.
void CopyAndSortStrings(size_t size,
                        const std::function<string(size_t)>& input_accessor,
                        protobuf::RepeatedPtrField<string>* output) {
  // Grow the pointer array once rather than repeatedly while appending. (The
  // previous version reserved an unused local vector instead.)
  output->Reserve(output->size() + static_cast<int>(size));
  for (size_t i = 0; i < size; ++i) {
    output->Add(input_accessor(i));
  }
  std::sort(output->begin(), output->end());
}

}  // namespace
void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
const ConfigProto& config,
BuildGraphOptions* opts) {
CallableOptions* callable_opts = &opts->callable_options;
CopyAndSortStrings(
req.num_feeds(), [&req](size_t i) { return req.feed_name(i); },
callable_opts->mutable_feed());
CopyAndSortStrings(
req.num_fetches(), [&req](size_t i) { return req.fetch_name(i); },
callable_opts->mutable_fetch());
CopyAndSortStrings(
req.num_targets(), [&req](size_t i) { return req.target_name(i); },
callable_opts->mutable_target());
if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
*callable_opts->mutable_run_options()->mutable_debug_options() =
req.options().debug_options();
}
opts->collective_graph_key =
req.options().experimental().collective_graph_key();
if (config.experimental().collective_deterministic_sequential_execution()) {
opts->collective_order = GraphCollectiveOrder::kEdges;
} else if (config.experimental().collective_nccl()) {
opts->collective_order = GraphCollectiveOrder::kAttrs;
}
}
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
BuildGraphOptions* opts) {
CallableOptions* callable_opts = &opts->callable_options;
CopyAndSortStrings(
req.feed_size(), [&req](size_t i) { return req.feed(i); },
callable_opts->mutable_feed());
CopyAndSortStrings(
req.fetch_size(), [&req](size_t i) { return req.fetch(i); },
callable_opts->mutable_fetch());
CopyAndSortStrings(
req.target_size(), [&req](size_t i) { return req.target(i); },
callable_opts->mutable_target());
// TODO(cais): Add TFDBG support to partial runs.
}
// Computes the key under which the client graph built from `opts` is cached
// (see MasterSession::StartStep).
//
// Each section (feeds, targets, fetches) mixes in its element count before
// its names, so a name cannot move from one section to another without
// changing the hash. Without this, {feed: x:0, fetch: y:0} and
// {fetch: [x:0, y:0]} chained the same strings in the same order and
// collided, and one request could be served the other's cached graph.
//
// The collective graph key is mixed in as well, because the cached graph is
// built from `opts` and its collective_graph_key() is used for step ids.
uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
  uint64 h = 0x2b992ddfa23249d6ull;
  const auto mix_section =
      [&h](const protobuf::RepeatedPtrField<string>& names) {
        h = Hash64Combine(h, static_cast<uint64>(names.size()));
        for (const string& name : names) {
          h = Hash64(name.c_str(), name.size(), h);
        }
      };
  mix_section(opts.callable_options.feed());
  mix_section(opts.callable_options.target());
  mix_section(opts.callable_options.fetch());

  h = Hash64Combine(h, static_cast<uint64>(opts.collective_graph_key));

  const DebugOptions& debug_options =
      opts.callable_options.run_options().debug_options();
  if (!debug_options.debug_tensor_watch_opts().empty()) {
    const string watch_summary =
        SummarizeDebugTensorWatches(debug_options.debug_tensor_watch_opts());
    h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
  }
  return h;
}
// Returns a human-readable description of `opts` for logging. The output has
// one line each for feeds ("FdE"), targets ("TN") and fetches ("FeE"),
// followed by the collective graph key ("GK") when one is set.
string BuildGraphOptionsString(const BuildGraphOptions& opts) {
  const CallableOptions& callable = opts.callable_options;
  string out;
  for (const string& feed : callable.feed()) {
    strings::StrAppend(&out, " FdE: ", feed);
  }
  out += "\n";
  for (const string& target : callable.target()) {
    strings::StrAppend(&out, " TN: ", target);
  }
  out += "\n";
  for (const string& fetch : callable.fetch()) {
    strings::StrAppend(&out, " FeE: ", fetch);
  }
  const bool has_graph_key =
      opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey;
  if (has_graph_key) {
    strings::StrAppend(&out, "\nGK: ", opts.collective_graph_key);
  }
  out += "\n";
  return out;
}
// Creates a master session identified by a fresh random handle.
//
// Takes ownership of:
//   - the remote device list,
//   - the session-specific worker cache (may be null, in which case
//     get_worker_cache() falls back to the environment's cache),
//   - the device set, which must be non-null (CHECKed below).
// NOTE(review): `run_graphs_(5)` / `partial_run_graphs_(5)` presumably set an
// initial bucket count for the graph caches — confirm against RCGMap.
MasterSession::MasterSession(
    const SessionOptions& opt, const MasterEnv* env,
    std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
    std::unique_ptr<WorkerCacheInterface> worker_cache,
    std::unique_ptr<DeviceSet> device_set,
    std::vector<string> filtered_worker_list,
    StatsPublisherFactory stats_publisher_factory)
    : session_opts_(opt),
      env_(env),
      handle_(strings::FpToString(random::New64())),
      remote_devs_(std::move(remote_devs)),
      worker_cache_(std::move(worker_cache)),
      devices_(std::move(device_set)),
      filtered_worker_list_(std::move(filtered_worker_list)),
      stats_publisher_factory_(std::move(stats_publisher_factory)),
      graph_version_(0),
      run_graphs_(5),
      partial_run_graphs_(5) {
  UpdateLastAccessTime();
  CHECK(devices_) << "device_set was null!";
  // NOTE(review): dereferences remote_devs_, so `remote_devs` is assumed to be
  // non-null — confirm with callers.
  VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
          << " #remote " << remote_devs_->size();
  VLOG(1) << "Start master session " << handle_
          << " with config: " << session_opts_.config.ShortDebugString();
}
// Releases the session's reference on every cached run and partial-run graph.
MasterSession::~MasterSession() {
  for (auto& entry : run_graphs_) {
    entry.second->Unref();
  }
  for (auto& entry : partial_run_graphs_) {
    entry.second->Unref();
  }
}
// Records the current time (in microseconds) as the session's last access.
void MasterSession::UpdateLastAccessTime() {
  const auto now_micros = Env::Default()->NowMicros();
  last_access_time_usec_.store(now_micros);
}
// Initializes the session's execution state from `graph_def`, then creates a
// worker session on every worker in the filtered worker list.
//
// Returns InvalidArgument if per-session thread pool options are set.
// The unsupported `place_pruned_graph` option is turned off after logging a
// warning. Otherwise returns the status of graph initialization, or of worker
// session creation.
Status MasterSession::Create(GraphDef&& graph_def,
                             const WorkerCacheFactoryOptions& options) {
  if (session_opts_.config.use_per_session_threads() ||
      session_opts_.config.session_inter_op_thread_pool_size() > 0) {
    return errors::InvalidArgument(
        "Distributed session does not support session thread pool options.");
  }
  if (session_opts_.config.graph_options().place_pruned_graph()) {
    // TODO(b/29900832): Fix this or remove the option.
    LOG(WARNING) << "Distributed session does not support the "
                    "place_pruned_graph option.";
    session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
  }
  GraphExecutionStateOptions execution_options;
  execution_options.device_set = devices_.get();
  execution_options.session_options = &session_opts_;
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        std::move(graph_def), execution_options, &execution_state_));
  }
  // Set before the worker sessions are created. Client graphs built later
  // receive `!should_delete_worker_sessions_`.
  should_delete_worker_sessions_ = true;
  return CreateWorkerSessions(options);
}
// Creates the worker session `handle_` on every worker in
// `filtered_worker_list_`, then waits for all requests to finish.
//
// When `options.cluster_def` is set, each request carries a ServerDef
// describing the cluster and the worker's own job/task.
//
// Returns the first error encountered, if any:
//   - an unparsable or incomplete worker name,
//   - a worker unknown to the worker cache,
//   - a failed CreateWorkerSession call.
// Otherwise returns OK.
Status MasterSession::CreateWorkerSessions(
    const WorkerCacheFactoryOptions& options) {
  const std::vector<string> worker_names = filtered_worker_list_;
  WorkerCacheInterface* worker_cache = get_worker_cache();

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;
    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;
    // Request and responses used for a given worker.
    CreateWorkerSessionRequest request;
    CreateWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Look up the workers and prepare one request per worker. No RPC has been
  // issued yet, so returning early from this loop is safe. `cleanup`
  // releases the workers acquired so far.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
    if (workers[i].worker == nullptr) {
      // GetOrCreateWorker() can return nullptr (the cleanup above allows for
      // it). Fail here rather than dereferencing it when issuing the RPC.
      status = errors::Internal("Could not find worker ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }
    workers[i].request.set_session_handle(handle_);
    if (session_opts_.config.experimental()
            .share_cluster_devices_in_session()) {
      for (const auto& remote_dev : devices_->devices()) {
        *workers[i].request.add_cluster_device_attributes() =
            remote_dev->attributes();
      }
    }

    DeviceNameUtils::ParsedName name;
    if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
      status = errors::Internal("Could not parse name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }
    if (!name.has_job || !name.has_task) {
      status = errors::Internal("Incomplete worker name ", worker_names[i]);
      LOG(WARNING) << status;
      return status;
    }

    if (options.cluster_def) {
      *workers[i].request.mutable_server_def()->mutable_cluster() =
          *options.cluster_def;
      workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
      workers[i].request.mutable_server_def()->set_job_name(name.job);
      workers[i].request.mutable_server_def()->set_task_index(name.task);
      // Session state is always isolated when ClusterSpec propagation
      // is in use.
      workers[i].request.set_isolate_session_state(true);
    } else {
      // NOTE(mrry): Do not set any component of the ServerDef,
      // because the worker will use its local configuration.
      workers[i].request.set_isolate_session_state(
          session_opts_.config.isolate_session_state());
    }
    if (session_opts_.config.experimental()
            .share_session_state_in_clusterspec_propagation()) {
      // In a dynamic cluster, the ClusterSpec info is usually propagated by
      // master sessions. However, in data parallel training with multiple
      // masters ("between-graph replication"), we need to disable isolation
      // for different worker sessions to update the same variables in PS
      // tasks.
      workers[i].request.set_isolate_session_state(false);
    }
  }

  // Kick off all the requests, then wait for every callback.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}
// Deletes the worker session `handle_` on every worker in
// `filtered_worker_list_`, then waits for all requests to complete.
//
// Each request has a timeout, so a worker that has gone away cannot block
// closing the session. Every worker is attempted even if some fail. Returns
// the first error, if any.
Status MasterSession::DeleteWorkerSessions() {
  WorkerCacheInterface* worker_cache = get_worker_cache();
  const std::vector<string>& worker_names = filtered_worker_list_;
  // Since the worker may have gone away, set a timeout to avoid blocking the
  // session-close operation.
  constexpr int64 kDeleteWorkerSessionTimeoutMs = 10000;

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;
    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;
    CallOptions call_opts;
    // Request and responses used for a given worker.
    DeleteWorkerSessionRequest request;
    DeleteWorkerSessionResponse response;
    Status status = Status::OK();
  };
  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  Status status = Status::OK();
  // Look up the workers and prepare their requests.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
    workers[i].call_opts.SetTimeout(kDeleteWorkerSessionTimeoutMs);
  }

  // Kick off the deletions.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    if (workers[i].worker == nullptr) {
      // GetOrCreateWorker() can return nullptr (the cleanup above allows for
      // it). Record the failure but keep going, so the sessions on the
      // remaining workers are still deleted.
      workers[i].status =
          errors::Internal("Could not find worker ", worker_names[i]);
      done.DecrementCount();
      continue;
    }
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->DeleteWorkerSessionAsync(
        &workers[i].call_opts, &workers[i].request, &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}
// Lists the devices visible to this session in `resp`.
//
// ClusterSpec-propagated session (one with its own worker cache):
//   - the client device is reported as the only local device,
//   - every other device in the session's device set is reported as remote.
// Otherwise:
//   - the master's local devices are reported as local,
//   - the session's remote devices are reported as remote.
Status MasterSession::ListDevices(ListDevicesResponse* resp) const {
  if (worker_cache_) {
    // This is a ClusterSpec-propagated session, and thus env_->local_devices
    // are invalid.
    // Mark the "client_device" as the sole local device.
    const Device* client_device = devices_->client_device();
    for (const Device* dev : devices_->devices()) {
      if (dev != client_device) {
        *(resp->add_remote_device()) = dev->attributes();
      }
    }
    *(resp->add_local_device()) = client_device->attributes();
  } else {
    for (Device* dev : env_->local_devices) {
      *(resp->add_local_device()) = dev->attributes();
    }
    // Remote devices go in `remote_device`, consistent with the branch above.
    for (auto&& dev : *remote_devs_) {
      *(resp->add_remote_device()) = dev->attributes();
    }
  }
  return Status::OK();
}
// Extends the session's graph with `req->graph_def()`.
//
// Fails with:
//   - FailedPrecondition if the session is closed,
//   - Aborted if the caller's `current_graph_version` does not match the
//     session's graph version.
// On success, the execution state is replaced by the extended one, the graph
// version is incremented, and the new version is returned in `resp`.
Status MasterSession::Extend(const ExtendSessionRequest* req,
                             ExtendSessionResponse* resp) {
  UpdateLastAccessTime();
  std::unique_ptr<GraphExecutionState> extended_execution_state;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    if (graph_version_ != req->current_graph_version()) {
      return errors::Aborted("Current version is ", graph_version_,
                             " but caller expected ",
                             req->current_graph_version(), ".");
    }
    CHECK(execution_state_);
    TF_RETURN_IF_ERROR(
        execution_state_->Extend(req->graph_def(), &extended_execution_state));
    CHECK(extended_execution_state);
    // The old execution state will be released outside the lock.
    // (After the swap it lives in `extended_execution_state`, which is
    // destroyed when this function returns.)
    execution_state_.swap(extended_execution_state);
    ++graph_version_;
    resp->set_new_graph_version(graph_version_);
  }
  return Status::OK();
}
// Returns the worker cache this session should use: its own cache when one
// was supplied at construction, otherwise the master environment's cache.
WorkerCacheInterface* MasterSession::get_worker_cache() const {
  return worker_cache_ ? worker_cache_.get() : env_->worker_cache;
}
// Looks up the ReffedClientGraph for `opts`, building and caching it if
// needed.
//
// Graphs are cached by HashBuildGraphOptions(opts), in separate maps for
// partial and regular runs. On success:
//   - `*out_rcg` carries an extra reference that the caller must Unref();
//   - `*out_count` receives the value of
//     get_and_increment_execution_count() (presumably the count before this
//     step).
// Note that a missing graph is built while holding `mu_`.
Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
                                ReffedClientGraph** out_rcg, int64* out_count) {
  const uint64 hash = HashBuildGraphOptions(opts);
  {
    mutex_lock l(mu_);
    // TODO(suharshs): We cache partial run graphs and run graphs separately
    // because there is preprocessing that needs to only be run for partial
    // run calls.
    RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
    auto iter = m->find(hash);
    if (iter == m->end()) {
      // We have not seen this subgraph before. Build the subgraph and
      // cache it.
      VLOG(1) << "Unseen hash " << hash << " for "
              << BuildGraphOptionsString(opts) << " is_partial = " << is_partial
              << "\n";
      std::unique_ptr<ClientGraph> client_graph;
      TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
      WorkerCacheInterface* worker_cache = get_worker_cache();
      // The cache map owns the initial reference of the new graph.
      auto entry = new ReffedClientGraph(
          handle_, opts, std::move(client_graph), session_opts_,
          stats_publisher_factory_, is_partial, worker_cache,
          !should_delete_worker_sessions_);
      iter = m->insert({hash, entry}).first;
      VLOG(1) << "Preparing to execute new graph";
    }
    *out_rcg = iter->second;
    // Reference handed to the caller.
    (*out_rcg)->Ref();
    *out_count = (*out_rcg)->get_and_increment_execution_count();
  }
  return Status::OK();
}
// Empties `rcg_map`. Each graph's reference is either appended to `to_unref`,
// so the caller can release it later, or released immediately when
// `to_unref` is null.
void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
                                   RCGMap* rcg_map) {
  VLOG(1) << "Discarding all reffed graphs";
  for (const auto& entry : *rcg_map) {
    ReffedClientGraph* graph = entry.second;
    if (to_unref == nullptr) {
      graph->Unref();
      continue;
    }
    to_unref->push_back(graph);
  }
  rcg_map->clear();
}
// Returns a new step id for a step of the graph with collective key
// `graph_key`.
//
// Without collectives, the id is random with its 7 most significant bits
// cleared. With collectives, the id comes from the collective executor
// manager. Whenever NextStepId() returns kInvalidId, the id sequence is
// refreshed synchronously and the lookup repeated.
// NOTE: failed refreshes are retried indefinitely, after a linear backoff
// (1 s, 2 s, ...) capped at 60 s.
uint64 MasterSession::NewStepId(int64 graph_key) {
  if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // StepId must leave the most-significant 7 bits empty for future use.
    // The mask equals (1 << 57) - 1, i.e. it keeps only the low 57 bits.
    return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
  } else {
    uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
    int32 retry_count = 0;
    while (step_id == CollectiveExecutor::kInvalidId) {
      // Block until the asynchronous refresh reports back.
      Notification note;
      Status status;
      env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
          graph_key, [&status, &note](const Status& s) {
            status = s;
            note.Notify();
          });
      note.WaitForNotification();
      if (!status.ok()) {
        LOG(ERROR) << "Bad status from "
                      "collective_executor_mgr->RefreshStepIdSequence: "
                   << status << ". Retrying.";
        int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
        Env::Default()->SleepForMicroseconds(delay_micros);
      } else {
        step_id = env_->collective_executor_mgr->NextStepId(graph_key);
      }
    }
    return step_id;
  }
}
// Sets up a partial run over the feeds, fetches and targets in `req`.
//
// Builds (or reuses) the partial-run client graph, records a new RunState
// under a fresh handle, and registers the graph's partitions. On success the
// handle is returned in `resp`. On registration failure the RunState is
// discarded again, since its handle is never handed to the client.
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
                                      PartialRunSetupResponse* resp) {
  // Targets only affect the graph that is built (through `opts` below). The
  // RunState tracks just the pending feeds and fetches.
  std::vector<string> inputs(req->feed().begin(), req->feed().end());
  std::vector<string> outputs(req->fetch().begin(), req->fetch().end());
  const string handle =
      std::to_string(partial_run_handle_counter_.fetch_add(1));

  ReffedClientGraph* rcg = nullptr;
  // Prepare.
  BuildGraphOptions opts;
  BuildBuildGraphOptions(*req, &opts);
  int64 count = 0;
  TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));

  // NOTE(review): StartStep() already returned `rcg` with a reference held
  // for the caller, and a second reference is taken here for the RunState.
  // Confirm that both are eventually released; otherwise one reference leaks
  // per partial-run setup.
  rcg->Ref();
  RunState* run_state =
      new RunState(inputs, outputs, rcg,
                   NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
  {
    mutex_lock l(mu_);
    partial_runs_.emplace(
        std::make_pair(handle, std::unique_ptr<RunState>(run_state)));
  }

  Status s = BuildAndRegisterPartitions(rcg);
  if (!s.ok()) {
    // The client never learns `handle`, so nothing could use or remove this
    // run state later. Drop it now; it is destroyed after `mu_` is released.
    std::unique_ptr<RunState> orphaned;
    {
      mutex_lock l(mu_);
      auto it = partial_runs_.find(handle);
      if (it != partial_runs_.end()) {
        orphaned = std::move(it->second);
        partial_runs_.erase(it);
      }
    }
    return s;
  }

  resp->set_partial_run_handle(handle);
  return Status::OK();
}
// Runs one step for `req`. Requests that carry a partial-run handle go to
// DoPartialRun(); all others go to DoRunWithLocalExecution(). Fails with
// FailedPrecondition once the session has been closed.
Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                          MutableRunStepResponseWrapper* resp) {
  UpdateLastAccessTime();
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    // Every path below must eventually call MarkRunCompletion() so that this
    // increment is undone.
    ++num_running_;
  }
  const bool is_partial_run = !req.partial_run_handle().empty();
  return is_partial_run ? DoPartialRun(opts, req, resp)
                        : DoRunWithLocalExecution(opts, req, resp);
}
// Undoes one increment of num_running_. Once no run remains in flight, wakes
// every thread waiting on num_running_is_zero_.
void MasterSession::MarkRunCompletion() {
  mutex_lock l(mu_);
  if (--num_running_ == 0) {
    num_running_is_zero_.notify_all();
  }
}
// Registers the partitions of `rcg` (its graph split by worker) with the
// workers. The partitioning options are derived from the session's graph
// options: bfloat16 send/recv casting and recv scheduling.
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
  PartitionOptions popts;
  popts.node_to_loc = SplitByWorker;
  // RegisterPartitions() below invokes the new_name and get_incarnation
  // closures synchronously, so capturing `this` without a Ref()/Unref() pair
  // is safe.
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_S", next_node_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    const Device* device = devices_->FindDeviceByName(name);
    return device == nullptr ? PartitionOptions::kIllegalIncarnation
                             : device->attributes().incarnation();
  };
  popts.control_flow_added = false;

  const auto& graph_opts = session_opts_.config.graph_options();
  const bool use_bfloat16_sendrecv = graph_opts.enable_bfloat16_sendrecv();
  popts.should_cast = [use_bfloat16_sendrecv](const Edge* e) {
    if (e->IsControlEdge()) return DT_FLOAT;
    const DataType dtype = BaseType(e->src()->output_type(e->src_output()));
    // When enabled, float tensors cross devices as bfloat16.
    return (use_bfloat16_sendrecv && dtype == DT_FLOAT) ? DT_BFLOAT16 : dtype;
  };
  if (graph_opts.enable_recv_scheduling()) {
    popts.scheduling_for_recvs = true;
    popts.need_to_record_start_times = true;
  }

  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
  return Status::OK();
}
// Executes one call of the partial run named by `req.partial_run_handle()`.
//
// On the first call for a handle, the step's start time is recorded and its
// per-step state (tracing, cost model, profiling) is initialized.
// Each call must:
//   - feed and fetch only names declared in PartialRunSetup() that have not
//     been fed/fetched yet;
//   - request only fetches computable from the feeds supplied so far.
// When every pending feed and fetch is done, or the run fails:
//   - stats are post-processed,
//   - the partitions are cleaned up asynchronously (that callback calls
//     MarkRunCompletion()),
//   - the run state is removed.
Status MasterSession::DoPartialRun(CallOptions* opts,
                                   const RunStepRequestWrapper& req,
                                   MutableRunStepResponseWrapper* resp) {
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
  const string& prun_handle = req.partial_run_handle();
  RunState* run_state = nullptr;
  {
    mutex_lock l(mu_);
    auto it = partial_runs_.find(prun_handle);
    if (it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run PartialRunSetup before performing partial runs");
    }
    run_state = it->second.get();
  }
  // CollectiveOps are not supported in partial runs.
  if (req.options().experimental().collective_graph_key() !=
      BuildGraphOptions::kNoCollectiveGraphKey) {
    return errors::InvalidArgument(
        "PartialRun does not support Collective ops. collective_graph_key "
        "must be kNoCollectiveGraphKey.");
  }

  // If this is the first partial run, initialize the PerStepState.
  if (!run_state->step_started) {
    run_state->step_started = true;
    // Record the start time, as DoRunWithLocalExecution() and DoRunCallable()
    // do, so that ProcessStats() reports a real start time to the profile
    // handler instead of an unset value.
    run_state->pss.start_micros = Env::Default()->NowMicros();
    // Same tracing / cost-model / profiling setup as for a regular step.
    FillPerStepState(run_state->rcg, req.options(), run_state->step_id,
                     run_state->count, &run_state->pss, &run_state->ph);
  }

  // Make sure that this is a new set of feeds that are still pending.
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    const string& feed = req.feed_name(i);
    auto it = run_state->pending_inputs.find(feed);
    if (it == run_state->pending_inputs.end()) {
      return errors::InvalidArgument(
          "The feed ", feed, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The feed ", feed,
                                     " has already been fed.");
    }
  }
  // Check that this is a new set of fetches that are still pending.
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    const string& fetch = req.fetch_name(i);
    auto it = run_state->pending_outputs.find(fetch);
    if (it == run_state->pending_outputs.end()) {
      return errors::InvalidArgument(
          "The fetch ", fetch, " was not specified in partial_run_setup.");
    } else if (it->second) {
      return errors::InvalidArgument("The fetch ", fetch,
                                     " has already been fetched.");
    }
  }

  // Ensure that the requested fetches can be computed from the provided feeds.
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(
        run_state->rcg->CheckFetches(req, run_state, execution_state_.get()));
  }

  // Determine if this partial run satisfies all the pending inputs and outputs.
  // (Every name was validated above, so these lookups always succeed.)
  for (size_t i = 0; i < req.num_feeds(); ++i) {
    auto it = run_state->pending_inputs.find(req.feed_name(i));
    it->second = true;
  }
  for (size_t i = 0; i < req.num_fetches(); ++i) {
    auto it = run_state->pending_outputs.find(req.fetch_name(i));
    it->second = true;
  }
  bool is_last_partial_run = run_state->PendingDone();

  Status s = run_state->rcg->RunPartitions(
      env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
      resp, &cancellation_manager_, is_last_partial_run);

  // Delete the run state if there is an error or all fetches are done.
  if (!s.ok() || is_last_partial_run) {
    ReffedClientGraph* rcg = run_state->rcg;
    run_state->pss.end_micros = Env::Default()->NowMicros();
    // Post-process stats now; partition cleanup runs asynchronously. `this`
    // and `rcg` are kept alive until the cleanup callback runs.
    Ref();
    rcg->Ref();
    rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
                      req.options(), resp->mutable_metadata());
    cleanup.release();  // MarkRunCompletion called in done closure.
    rcg->CleanupPartitionsAsync(
        run_state->step_id, [this, rcg, prun_handle](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Cleanup partition error: " << s;
          }
          rcg->Unref();
          MarkRunCompletion();
          Unref();
        });
    mutex_lock l(mu_);
    partial_runs_.erase(prun_handle);
  }
  return s;
}
// Creates the debugger state for a step of `req` from `debug_options`, then
// publishes the step's debug metadata: global step, execution count, and the
// feed, fetch and target names.
Status MasterSession::CreateDebuggerState(
    const DebugOptions& debug_options, const RunStepRequestWrapper& req,
    int64 rcg_execution_count,
    std::unique_ptr<DebuggerStateInterface>* debugger_state) {
  std::vector<string> feeds;
  for (size_t idx = 0; idx < req.num_feeds(); ++idx) {
    feeds.push_back(req.feed_name(idx));
  }
  std::vector<string> fetches;
  for (size_t idx = 0; idx < req.num_fetches(); ++idx) {
    fetches.push_back(req.fetch_name(idx));
  }
  std::vector<string> targets;
  for (size_t idx = 0; idx < req.num_targets(); ++idx) {
    targets.push_back(req.target_name(idx));
  }

  TF_RETURN_IF_ERROR(
      DebuggerStateRegistry::CreateState(debug_options, debugger_state));

  // TODO(cais): Non-direct sessions have no straightforward session run
  // count, so the client graph's execution count is passed for both
  // run-count arguments below. Devise a better value when the need arises.
  TF_RETURN_IF_ERROR((*debugger_state)->PublishDebugMetadata(
      debug_options.global_step(), rcg_execution_count, rcg_execution_count,
      feeds, fetches, targets));
  return Status::OK();
}
// Initializes `*out_pss` and `*out_ph` for one step of `rcg`.
//
// - FULL_TRACE turns on timeline and RPC stats collection.
// - The cost model is collected every `build_cost_model` steps once the
//   first `build_cost_model_after` steps have been skipped (`count` is the
//   graph's execution count for this step).
// - If `rcg` supplies a ProfileHandler for the step, timeline collection is
//   forced on and RPC collection follows the handler's preference.
void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
                                     const RunOptions& run_options,
                                     uint64 step_id, int64 count,
                                     PerStepState* out_pss,
                                     std::unique_ptr<ProfileHandler>* out_ph) {
  out_pss->collect_timeline =
      run_options.trace_level() == RunOptions::FULL_TRACE;
  out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
  out_pss->report_tensor_allocations_upon_oom =
      run_options.report_tensor_allocations_upon_oom();
  // Build the cost model every 'build_cost_model_every' steps after skipping an
  // initial 'build_cost_model_after' steps.
  const int64 build_cost_model_after =
      session_opts_.config.graph_options().build_cost_model_after();
  const int64 build_cost_model_every =
      session_opts_.config.graph_options().build_cost_model();
  // The `count >= build_cost_model_after` guard is required. For the skipped
  // steps, (count + 1 - build_cost_model_after) is negative, and a negative
  // multiple of build_cost_model_every also has remainder 0. Without the
  // guard, cost collection would be enabled during steps meant to be skipped.
  out_pss->collect_costs =
      build_cost_model_every > 0 && count >= build_cost_model_after &&
      ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
  out_pss->collect_partition_graphs = run_options.output_partition_graphs();
  *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
  if (*out_ph) {
    out_pss->collect_timeline = true;
    out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
  }
}
// Finishes a step run by DoRunWithLocalExecution() or DoRunCallable().
//
// If `run_status` is OK:
//   - records the end time,
//   - retires `step_id` for collective graphs,
//   - post-processes the step's stats into `out_run_metadata`.
// If the step was cancelled because the session was closed, the error is
// replaced with a message explaining why.
// In every case, an asynchronous cleanup of the step's partitions is then
// started. Its callback calls MarkRunCompletion(), so callers must not call
// it themselves. `this` and `rcg` stay referenced until the callback runs.
// Returns `run_status`, possibly with a more descriptive message.
Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
                                     uint64 step_id,
                                     const RunOptions& run_options,
                                     PerStepState* pss,
                                     const std::unique_ptr<ProfileHandler>& ph,
                                     const Status& run_status,
                                     RunMetadata* out_run_metadata) {
  Status s = run_status;
  if (s.ok()) {
    pss->end_micros = Env::Default()->NowMicros();
    if (rcg->collective_graph_key() !=
        BuildGraphOptions::kNoCollectiveGraphKey) {
      env_->collective_executor_mgr->RetireStepId(rcg->collective_graph_key(),
                                                  step_id);
    }
    // Post-process the step's stats now (synchronously); partition cleanup
    // is scheduled asynchronously below.
    rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
  } else if (errors::IsCancelled(s)) {
    mutex_lock l(mu_);
    if (closed_) {
      if (garbage_collected_) {
        s = errors::Cancelled(
            "Step was cancelled because the session was garbage collected due "
            "to inactivity.");
      } else {
        s = errors::Cancelled(
            "Step was cancelled by an explicit call to `Session::Close()`.");
      }
    }
  }
  // Keep the session and the graph alive until the cleanup callback runs.
  Ref();
  rcg->Ref();
  rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
    if (!s.ok()) {
      LOG(ERROR) << "Cleanup partition error: " << s;
    }
    rcg->Unref();
    MarkRunCompletion();
    Unref();
  });
  return s;
}
// Executes one regular (non-partial) step for `req`.
//
// Looks up or builds the client graph, optionally creates debugger state,
// registers the graph's partitions, allocates a step id and runs the
// partitions. The run's outcome is handed to PostRunCleanup(), which also
// eventually calls MarkRunCompletion(). Any early return calls
// MarkRunCompletion() through the local `cleanup` instead.
Status MasterSession::DoRunWithLocalExecution(
    CallOptions* opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp) {
  VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
  // Prepare.
  BuildGraphOptions bgopts;
  BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
  ReffedClientGraph* rcg = nullptr;
  int64 count;
  TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
  // Unref "rcg" when out of scope.
  core::ScopedUnref unref(rcg);
  // Held until this function returns (only created when tensor watches are
  // requested).
  std::unique_ptr<DebuggerStateInterface> debugger_state;
  const DebugOptions& debug_options = req.options().debug_options();
  if (!debug_options.debug_tensor_watch_opts().empty()) {
    TF_RETURN_IF_ERROR(
        CreateDebuggerState(debug_options, req, count, &debugger_state));
  }
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
  // Non-collective graphs get a random step id with the top 7 bits cleared
  // (reserved for future use). Collective graphs get their id from the
  // collective executor manager.
  uint64 step_id = NewStepId(rcg->collective_graph_key());
  TRACEPRINTF("stepid %llu", step_id);
  std::unique_ptr<ProfileHandler> ph;
  FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
  if (pss.collect_partition_graphs &&
      session_opts_.config.experimental().disable_output_partition_graphs()) {
    return errors::InvalidArgument(
        "RunOptions.output_partition_graphs() is not supported when "
        "disable_output_partition_graphs is true.");
  }
  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_, false);
  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
  return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
                        resp->mutable_metadata());
}
// Creates a callable from `req.options()`.
//
// Builds a fresh client graph (callables are not cached by hash), registers
// its partitions and stores it under a new callable handle, which is
// returned in `resp`. Fails with FailedPrecondition if the session is
// closed. If partition registration fails, the graph is released and the
// error returned.
Status MasterSession::MakeCallable(const MakeCallableRequest& req,
                                   MakeCallableResponse* resp) {
  UpdateLastAccessTime();
  BuildGraphOptions opts;
  opts.callable_options = req.options();
  opts.use_function_convention = false;
  ReffedClientGraph* callable;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    std::unique_ptr<ClientGraph> client_graph;
    TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
    // Starts with one reference, which is transferred to callables_ below.
    callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
                                     session_opts_, stats_publisher_factory_,
                                     false /* is_partial */, get_worker_cache(),
                                     !should_delete_worker_sessions_);
  }
  Status s = BuildAndRegisterPartitions(callable);
  if (!s.ok()) {
    callable->Unref();
    return s;
  }
  uint64 handle;
  {
    mutex_lock l(mu_);
    handle = next_callable_handle_++;
    callables_[handle] = callable;
  }
  resp->set_handle(handle);
  return Status::OK();
}
// Runs one step of the callable graph `rcg` for `req`.
//
// Expects the caller (RunCallable()) to have incremented num_running_.
// MarkRunCompletion() is invoked through PostRunCleanup(). A non-zero
// `timeout_in_ms` in the callable's RunOptions overrides the timeout in
// `opts`.
Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
                                    const RunCallableRequest& req,
                                    RunCallableResponse* resp) {
  VLOG(2) << "DoRunCallable req: " << req.DebugString();
  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
  // Prepare.
  int64 count = rcg->get_and_increment_execution_count();
  const uint64 step_id = NewStepId(rcg->collective_graph_key());
  TRACEPRINTF("stepid %llu", step_id);
  const RunOptions& run_options = rcg->callable_options().run_options();
  if (run_options.timeout_in_ms() != 0) {
    opts->SetTimeout(run_options.timeout_in_ms());
  }
  std::unique_ptr<ProfileHandler> ph;
  FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_);
  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
  return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
                        resp->mutable_metadata());
}
// Runs the callable registered under `req.handle()`.
//
// Under `mu_` this looks up the callable, takes a reference on it, and
// increments `num_running_` so that Close() waits for this run. The step
// itself runs in DoRunCallable() without the lock. The extra reference is
// dropped when this function returns.
//
// Returns FailedPrecondition if the session is closed. Returns
// InvalidArgument if the handle was never allocated (including negative
// values) or has already been released.
Status MasterSession::RunCallable(CallOptions* opts,
                                  const RunCallableRequest& req,
                                  RunCallableResponse* resp) {
  UpdateLastAccessTime();
  ReffedClientGraph* callable;
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    int64 handle = req.handle();
    // Handles come from the increasing `next_callable_handle_` counter in
    // MakeCallable(). A negative or not-yet-issued value was therefore never
    // a valid handle. (Assumes the counter starts non-negative; confirm
    // against the header.) Without the `< 0` check, a negative handle would
    // be reported as "released".
    if (handle < 0 || handle >= next_callable_handle_) {
      return errors::InvalidArgument("No such callable handle: ", handle);
    }
    auto iter = callables_.find(handle);
    if (iter == callables_.end()) {
      return errors::InvalidArgument(
          "Attempted to run callable after handle was released: ", handle);
    }
    callable = iter->second;
    callable->Ref();
    ++num_running_;
  }
  core::ScopedUnref unref_callable(callable);
  return DoRunCallable(opts, callable, req, resp);
}
// Removes the callable registered under `req.handle()` from `callables_` and
// drops the session's reference to it. An unknown or already-released handle
// is silently ignored, so this always returns OK.
Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
                                      ReleaseCallableResponse* resp) {
  UpdateLastAccessTime();
  ReffedClientGraph* released = nullptr;
  {
    mutex_lock l(mu_);
    auto found = callables_.find(req.handle());
    if (found != callables_.end()) {
      released = found->second;
      callables_.erase(found);
    }
  }
  // Unref after releasing `mu_`: dropping the last reference may delete the
  // graph.
  if (released) released->Unref();
  return Status::OK();
}
// Shuts the session down:
//   1. Marks it closed so new requests are rejected.
//   2. Starts cancellation on `cancellation_manager_`.
//   3. Waits until `num_running_` drops to zero (presumably decremented by
//      MarkRunCompletion()).
//   4. Empties the run-graph, partial-run-graph and callable tables and
//      unrefs their graphs outside the lock.
// When `should_delete_worker_sessions_` is set, it also deletes the worker
// sessions. A failure there is logged, not returned; Close() always returns
// OK.
Status MasterSession::Close() {
  {
    mutex_lock l(mu_);
    closed_ = true;  // All subsequent calls to Run() or Extend() will fail.
  }
  cancellation_manager_.StartCancel();

  std::vector<ReffedClientGraph*> graphs_to_release;
  {
    mutex_lock l(mu_);
    // Block until no step is in flight before tearing down the tables.
    while (num_running_ != 0) {
      num_running_is_zero_.wait(l);
    }
    ClearRunsTable(&graphs_to_release, &run_graphs_);
    ClearRunsTable(&graphs_to_release, &partial_run_graphs_);
    ClearRunsTable(&graphs_to_release, &callables_);
  }
  for (ReffedClientGraph* graph : graphs_to_release) {
    graph->Unref();
  }

  if (!should_delete_worker_sessions_) {
    return Status::OK();
  }
  const Status delete_status = DeleteWorkerSessions();
  if (!delete_status.ok()) {
    LOG(WARNING) << delete_status;
  }
  return Status::OK();
}
// Flags the session as garbage-collected (and therefore closed), starts
// cancellation of outstanding work, then drops a reference on this session.
// That Unref() may delete `this`, so nothing may touch members afterwards.
void MasterSession::GarbageCollect() {
  {
    mutex_lock l(mu_);
    garbage_collected_ = true;
    closed_ = true;
  }
  cancellation_manager_.StartCancel();
  // Must be the final statement.
  Unref();
}
// Tracks one partial run over `rcg` at `step_id` / `count`.
// Every feed in `input_names` and fetch in `output_names` is entered as false
// (outstanding); PendingDone() only reports completion once all are true.
// The constructor does not Ref() `rcg`, but the destructor Unref()s it, so the
// caller presumably transfers one reference to this object.
MasterSession::RunState::RunState(const std::vector<string>& input_names,
                                  const std::vector<string>& output_names,
                                  ReffedClientGraph* rcg, const uint64 step_id,
                                  const int64 count)
    : rcg(rcg), step_id(step_id), count(count) {
  for (const string& feed : input_names) {
    pending_inputs.emplace(feed, false);
  }
  for (const string& fetch : output_names) {
    pending_outputs.emplace(fetch, false);
  }
}
// Releases the reference this run state holds on its client graph, if any.
MasterSession::RunState::~RunState() {
  if (rcg != nullptr) {
    rcg->Unref();
  }
}
// Returns true only when every entry of both `pending_inputs` and
// `pending_outputs` has been set to true; returns false at the first entry
// that is still outstanding.
bool MasterSession::RunState::PendingDone() const {
  for (const auto& feed : pending_inputs) {
    if (!feed.second) {
      return false;
    }
  }
  for (const auto& fetch : pending_outputs) {
    if (!fetch.second) {
      return false;
    }
  }
  return true;
}
} // end namespace tensorflow