| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/core/grappler/costs/virtual_scheduler.h" |
| |
| #include "absl/strings/str_format.h" |
| #include "absl/strings/str_replace.h" |
| #include "tensorflow/core/framework/allocation_description.pb.h" |
| #include "tensorflow/core/framework/attr_value.pb.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/tensor.pb.h" |
| #include "tensorflow/core/framework/tensor_description.pb.h" |
| #include "tensorflow/core/framework/tensor_shape.pb.h" |
| #include "tensorflow/core/grappler/clusters/utils.h" |
| #include "tensorflow/core/grappler/costs/utils.h" |
| #include "tensorflow/core/grappler/op_types.h" |
| #include "tensorflow/core/grappler/utils.h" |
| #include "tensorflow/core/grappler/utils/transitive_fanin.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/strings/numbers.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/util/device_name_utils.h" |
| |
| namespace tensorflow { |
| namespace grappler { |
| |
// Attr set on scheduler-created _Send / _Recv nodes; holds the original input
// string (possibly with a "^" prefix or ":<port>" suffix) they stand in for.
const char kAttrInputSrc[] = "input_source_";
// Attr on a created _Send naming the device of the producing node.
const char kAttrSrcDevice[] = "send_device";
// Attr on a created _Send naming the device of the consuming node.
const char kAttrDstDevice[] = "recv_device";
// Attr used to pair _Send and _Recv nodes that carry the same tensor.
const char kAttrTensorName[] = "tensor_name";
// Prefix of channel device names; also the device type reported for _Send ops.
const char kChannelDevice[] = "Channel";
// Attr holding a list of bools, indexed by output port.
const char kStreaming[] = "_streaming";
| |
| namespace { |
| |
| using ::tensorflow::strings::HumanReadableNumBytes; |
| |
// Rounds `x` to the nearest hundredth.
float Round2(const float x) {
  // The global ::round is used on purpose rather than std::round from <cmath>:
  // not every supported platform (specifically Android) provides the latter.
  const double scaled = 100.0 * x;
  return ::round(scaled) / 100.0;
}
| |
| Costs& FindOrCreateZero(const string& op_name, |
| std::map<string, Costs>* op_cost) { |
| auto it = op_cost->find(op_name); |
| if (it == op_cost->end()) { |
| // Note that default constructor of Costs sets some memory related fields |
| // to unknown values so we should explicitly initialize it with ZeroCosts. |
| it = op_cost->emplace(op_name, Costs::ZeroCosts()).first; |
| } |
| return it->second; |
| } |
| |
| // Key to the cached _Recv ops map, and its hash and predicate structures. |
| struct RecvNodeDescriptor { |
| const NodeDef* node; |
| const int port_num; |
| const string device; |
| |
| RecvNodeDescriptor(const NodeDef* node_, const int port_num_, |
| const string& device_) |
| : node(node_), port_num(port_num_), device(device_) {} |
| }; |
| |
| struct RecvNodeDescriptorHash { |
| std::size_t operator()(const RecvNodeDescriptor& recv_node) const { |
| return std::hash<const NodeDef*>()(recv_node.node) ^ |
| std::hash<int>()(recv_node.port_num) ^ |
| std::hash<string>()(recv_node.device); |
| } |
| }; |
| |
| struct RecvNodeDescriptorEqual { |
| bool operator()(const RecvNodeDescriptor& a, |
| const RecvNodeDescriptor& b) const { |
| return a.node == b.node && a.port_num == b.port_num && a.device == b.device; |
| } |
| }; |
| |
| void UpdateDeviceAnnotationState(const NodeDef* node, |
| const NodeState& node_state, |
| DeviceState* device) { |
| if (node->attr().count(kOutputShapes) == 0) return; |
| |
| int64 execution_count = node->attr().count(kExecutionCount) == 0 |
| ? 1 |
| : node->attr().at(kExecutionCount).i(); |
| |
| auto& shape_annotation_stats = device->shape_annotation_stats; |
| shape_annotation_stats.num_ops_annotated += 1; |
| shape_annotation_stats.num_ops_executed += execution_count; |
| shape_annotation_stats.num_ops_executed_more_than_once += |
| execution_count > 1 ? 1 : 0; |
| shape_annotation_stats.num_ops_with_incompatible_shapes += |
| node_state.shape_incompatible ? 1 : 0; |
| shape_annotation_stats.num_ops_with_dynamic_shapes += |
| (execution_count > 1 && node->attr().count(kOutputSame) == 0) ? 1 : 0; |
| } |
| |
| bool IsStreamingPort(const NodeDef& node, const int port) { |
| if (!node.attr().contains(kStreaming)) return false; |
| |
| auto& attr_list = node.attr().at(kStreaming).list(); |
| bool is_streaming_port = false; |
| if (port >= 0 && port < attr_list.b().size()) { |
| is_streaming_port = attr_list.b(port); |
| } |
| return is_streaming_port; |
| } |
| |
| } // namespace |
| |
| void LIFOManager::AddNode(const NodeDef* node) { |
| // Merge nodes are scheduled with the lowest priority in LIFO manager; virtual |
| // scheduler may run multiple input nodes of Merge (when we don't have |
| // annotation, which is quite common); simply scheduling Merge after one of |
| // its input may break scheduling constraints; some inputs of Merge may be |
| // scheduled after the Merge. So, we place Merge at the beginning of the queue |
| // to guarantee all the inputs of Merge are scheduled before the Merge. |
| if (IsMerge(*node)) { |
| nodes_.push_front(node); |
| } else { |
| nodes_.push_back(node); |
| } |
| } |
| |
| const NodeDef* LIFOManager::GetCurrNode() { |
| CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; |
| if (curr_pos_ == nodes_.end()) { |
| curr_pos_ = --(nodes_.rbegin().base()); // Last one in the list. |
| } |
| // Once curr_pos_ is set to a valid entry in the list, we keep using the |
| // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not |
| // change the GetCurrNode() return value. |
| return *curr_pos_; |
| } |
| |
| void LIFOManager::RemoveCurrNode() { |
| // Make sure we have curr_pos_ ready to be removed. |
| GetCurrNode(); |
| // Note curr_pos_ may not be pointing the last element if some nodes are |
| // added. |
| nodes_.erase(curr_pos_); |
| |
| curr_pos_ = nodes_.end(); // Reset curr_pos_. |
| } |
| |
| HeapReadyManager::HeapReadyManager() : ReadyNodeManager() { |
| std::make_heap(nodes_.begin(), nodes_.end()); |
| } |
| |
| Status HeapReadyManager::Init( |
| const std::unordered_map<const NodeDef*, NodeState>* node_map) { |
| // Resets the node state since different instances of the scheduler can reuse |
| // the same node_manager. |
| node_map_ = node_map; |
| nodes_.clear(); |
| curr_node_ = nullptr; |
| |
| // Sets up the comparator for the heap. |
| greater_ = Greater(); |
| |
| return Status::OK(); |
| } |
| |
| void HeapReadyManager::AddNode(const NodeDef* node) { |
| // push_heap in AddNode and pop_heap in RemoveCurrNode() guarantees that the |
| // first element is the node with minimum time_ready. |
| nodes_.push_back(node); |
| std::push_heap(nodes_.begin(), nodes_.end(), greater_); |
| } |
| |
| const NodeDef* HeapReadyManager::GetCurrNode() { |
| if (curr_node_) return curr_node_; |
| if (nodes_.empty()) { |
| CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; |
| } |
| const std::string node_name = nodes_.front()->name(); |
| // Next time we call GetCurrNode(), it just returns the cached copy |
| // curr_node_, until we call the RemoveCurrNode(). |
| curr_node_ = nodes_.front(); |
| // Remove current node from the heap immediately. Because if we wait until |
| // later, the heap could have gotten re-organized if AddNode is called. The |
| // current node is anyways cached, incase GetCurrNode() is called again. |
| std::pop_heap(nodes_.begin(), nodes_.end(), greater_); |
| nodes_.pop_back(); |
| return curr_node_; |
| } |
| |
| void HeapReadyManager::RemoveCurrNode() { |
| if (curr_node_) { |
| // If cached copy exists, remove that. |
| // Reset curr_node_ so that GetCurrNode() finds another node. |
| curr_node_ = nullptr; |
| } else { |
| // If cached copy not present, then remove entry from the heap queue. |
| std::pop_heap(nodes_.begin(), nodes_.end(), greater_); |
| nodes_.pop_back(); |
| } |
| } |
| |
| bool HeapReadyManager::Empty() const { |
| return nodes_.empty() && curr_node_ == nullptr; |
| } |
| |
| bool FirstReadyCmp( |
| const std::unordered_map<const NodeDef*, NodeState>* node_map, |
| const NodeDef* a, const NodeDef* b) { |
| if (node_map->at(a).time_ready == node_map->at(b).time_ready) { |
| // Use Node name as tie-breaker for deterministic node scheduling. |
| return a->name().compare(b->name()) > 0; |
| } else { |
| // Note: we need a node with minimum time_ready, not maximum; hence, using |
| // a > b for comparison function. |
| return node_map->at(a).time_ready > node_map->at(b).time_ready; |
| } |
| } |
| |
| std::function<bool(const NodeDef*, const NodeDef*)> |
| FirstReadyManager::Greater() { |
| auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool { |
| return FirstReadyCmp(node_map_, a, b); |
| }; |
| return greater; |
| } |
| |
| std::function<bool(const NodeDef*, const NodeDef*)> |
| PriorityReadyManager::Greater() { |
| auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool { |
| auto pri_a = node_priority_.at(a->name()); |
| auto pri_b = node_priority_.at(b->name()); |
| if (pri_a == pri_b) { |
| // Fallback to default (FirstReady) behaviour. |
| return FirstReadyCmp(node_map_, a, b); |
| } |
| return pri_a > pri_b; |
| }; |
| return greater; |
| } |
| |
| void PriorityReadyManager::AddNode(const NodeDef* node) { |
| if (node_priority_.count(node->name()) == 0) { |
| VLOG(3) << "Priority of node " << node->name() << " not found."; |
| node_priority_[node->name()] = 0; |
| } |
| HeapReadyManager::AddNode(node); |
| } |
| |
| Status PriorityReadyManager::SetPriority( |
| const std::unordered_map<string, int>& node_priority) { |
| node_priority_ = node_priority; |
| return Status::OK(); |
| } |
| |
| CompositeNodeManager::CompositeNodeManager() |
| : ReadyNodeManager(), send_manager_(), recv_manager_() {} |
| |
| Status CompositeNodeManager::Init( |
| const std::unordered_map<const NodeDef*, NodeState>* node_map) { |
| node_map_ = node_map; |
| TF_RETURN_IF_ERROR(send_manager_.Init(node_map)); |
| TF_RETURN_IF_ERROR(recv_manager_.Init(node_map)); |
| curr_node_ = nullptr; |
| return Status::OK(); |
| } |
| |
| void CompositeNodeManager::AddNode(const NodeDef* node) { |
| if (IsSend(*node)) { |
| send_manager_.AddNode(node); |
| } else if (IsRecv(*node)) { |
| recv_manager_.AddNode(node); |
| } else { |
| const auto& device = node_map_->at(node).device_name; |
| ops_lifo_map_[device].AddNode(node); |
| } |
| } |
| |
| const NodeDef* CompositeNodeManager::GetCurrNode() { |
| if (curr_node_) return curr_node_; |
| |
| // Per-device LIFO for normal ops (not _Send / _Recv), |
| // FirstReady for _Send and _Recv (separately), |
| // Globally (among the LIFO-selected ops from each device and _Send and |
| // _Recv) FirstReady, |
| // Priority order: _Send, _Recv, and then the rest, if time_ready is equal. |
| std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates; |
| for (auto& ops_lifo : ops_lifo_map_) { |
| if (!ops_lifo.second.Empty()) { |
| const auto* op = ops_lifo.second.GetCurrNode(); |
| candidates.emplace_back(op, node_map_->at(op).time_ready); |
| } |
| } |
| if (!send_manager_.Empty()) { |
| const auto* send = send_manager_.GetCurrNode(); |
| candidates.emplace_back(send, node_map_->at(send).time_ready); |
| } |
| if (!recv_manager_.Empty()) { |
| const auto* recv = recv_manager_.GetCurrNode(); |
| candidates.emplace_back(recv, node_map_->at(recv).time_ready); |
| } |
| CHECK(!candidates.empty()); |
| auto first_ready = std::min_element( |
| candidates.begin(), candidates.end(), |
| [](const std::pair<const NodeDef*, Costs::Duration>& a, |
| const std::pair<const NodeDef*, Costs::Duration>& b) { |
| if (a.second == b.second) { |
| // Note that there can be only 1 Send and only 1 Recv in candidates, |
| // at most; hence, score is 2 for Send, 1 for Recv, and 0 for a |
| // normap op, and a_score and b_score are equal only if both are |
| // normal ops. |
| int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first); |
| int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first); |
| if (a_score == b_score) { |
| // Both are normal ops; use node name as tie breaker. |
| return a.first->name().compare(b.first->name()) < 0; |
| } else { |
| // Prioritize by op type: _Send, _Recv, and normap ops. |
| return a_score > b_score; |
| } |
| } else { |
| return a.second < b.second; |
| } |
| }); |
| // Next time we call GetCurrNode(), it just returns the cached one, |
| // curr_node_ until we call RemovCurrNode(). |
| curr_node_ = first_ready->first; |
| |
| return curr_node_; |
| } |
| |
| void CompositeNodeManager::RemoveCurrNode() { |
| const auto* node = GetCurrNode(); |
| if (IsSend(*node)) { |
| send_manager_.RemoveCurrNode(); |
| } else if (IsRecv(*node)) { |
| recv_manager_.RemoveCurrNode(); |
| } else { |
| const auto device = node_map_->at(node).device_name; |
| ops_lifo_map_[device].RemoveCurrNode(); |
| } |
| // Reset curr_node_ so that GetCurrNode() finds another node. |
| curr_node_ = nullptr; |
| } |
| |
| bool CompositeNodeManager::Empty() const { |
| // Empty if all the ready managers are empty. |
| bool empty = true; |
| for (const auto& ops_lifo : ops_lifo_map_) { |
| empty &= ops_lifo.second.Empty(); |
| } |
| return empty && send_manager_.Empty() && recv_manager_.Empty(); |
| } |
| |
| std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory( |
| const string& ready_node_manager) { |
| if (ready_node_manager == "FIFO") { |
| return absl::make_unique<FIFOManager>(); |
| } else if (ready_node_manager == "LIFO") { |
| return absl::make_unique<LIFOManager>(); |
| } else if (ready_node_manager == "FirstReady") { |
| return absl::make_unique<FirstReadyManager>(); |
| } else if (ready_node_manager == "Composite") { |
| return absl::make_unique<CompositeNodeManager>(); |
| } |
| LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager; |
| return nullptr; |
| } |
| |
| SchedulerState::~SchedulerState() {} |
| |
| SchedulerState::SchedulerState(const bool use_static_shapes, |
| const bool use_aggressive_shape_inference, |
| Cluster* cluster, |
| std::unique_ptr<VirtualPlacer> placer) |
| : graph_costs_(Costs::ZeroCosts()), |
| cluster_(cluster), |
| use_static_shapes_(use_static_shapes), |
| use_aggressive_shape_inference_(use_aggressive_shape_inference), |
| placer_(std::move(placer)) { |
| DCHECK(placer_); // check if the pointer is valid. |
| graph_costs_.num_ops_total = 0; |
| initialized_ = false; |
| track_mem_usage_snapshot_ = VLOG_IS_ON(1); |
| } |
| |
| Status SchedulerState::Init(const GrapplerItem* item, |
| std::vector<const NodeDef*>* initial_nodes, |
| bool create_explicit_channel_device) { |
| initialized_ = false; |
| |
| // Clear all internal states so that the SchedulerState is reusable for |
| // different GrapplerItems |
| node_map_.clear(); |
| device_.clear(); |
| additional_nodes_.clear(); |
| |
| graph_costs_ = Costs::ZeroCosts(); |
| graph_costs_.num_ops_total = 0; |
| op_to_cost_.clear(); |
| |
| op_counts_.clear(); |
| op_costs_.clear(); |
| |
| initial_nodes->clear(); |
| |
| // Constructs graph properties and performs shape inference. |
| graph_properties_ = absl::make_unique<GraphProperties>(*item); |
| // TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want |
| // to get rid of use_static_shapes_ and cluster_. |
| if (use_static_shapes_) { |
| TF_RETURN_IF_ERROR(graph_properties_->InferStatically( |
| true, use_aggressive_shape_inference_, true)); |
| } else { |
| TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_)); |
| } |
| |
| grappler_item_ = item; |
| const auto& graph = grappler_item_->graph; |
| const auto& fetch_nodes = grappler_item_->fetch; |
| std::set<string> feed_nodes; |
| |
| for (const auto& f : grappler_item_->feed) { |
| auto iter_and_inserted_flag = feed_nodes.insert(f.first); |
| QCHECK(iter_and_inserted_flag.second) |
| << "Duplicate feed node found: " << f.first; |
| } |
| |
| // Get the nodes that would run to output fetch_nodes. |
| std::unordered_map<string, const NodeDef*> name_to_node; |
| std::vector<const NodeDef*> fetch_fanin_nodes; |
| TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node, |
| &fetch_fanin_nodes)); |
| |
| // Once ComputeTransitiveFanin is complete, only the nodes that can be reached |
| // from the fetch nodes are scheduled. So the scheduled nodes should be |
| // exactly the same as those executed for real. One possible discrepancy could |
| // be the control flow nodes, where tf only executes one path. |
| |
| // Traverses the graph to record _Send nodes. |
| // TODO(dyoon): Instead of identifying _Send node here manually, add _Send |
| // to _Recv as control dependency when creating GrapplerItem. |
| std::unordered_map<string, const NodeDef*> name_to_send; |
| for (const auto& node : graph.node()) { |
| if (IsSend(node)) { |
| const auto& attr = node.attr(); |
| name_to_send[attr.at("tensor_name").s()] = &node; |
| } |
| } |
| |
| // To reuse _Recv ops. |
| std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescriptorHash, |
| RecvNodeDescriptorEqual> |
| cached_recv_nodes; |
| |
| // Build node_map; for each node, create its NodeState and connect its inputs |
| // and outputs. |
| for (const auto* curr_node : fetch_fanin_nodes) { |
| auto& curr_node_state = GetNodeStateOrCreateIt(curr_node); |
| const string curr_node_device = DeviceName(curr_node); |
| std::vector<string> inputs; |
| if (IsRecv(*curr_node)) { |
| const auto& attr = curr_node->attr(); |
| if (attr.count("tensor_name")) { |
| const auto& send_node_name = attr.at("tensor_name").s(); |
| auto it = name_to_send.find(send_node_name); |
| // If there is a _Send associated with the curr_node (_Recv), add it as |
| // input. |
| if (it != name_to_send.end()) { |
| const NodeDef* send = it->second; |
| inputs = {send->name()}; |
| } |
| } |
| } else { |
| for (const string& input : curr_node->input()) { |
| inputs.push_back(input); |
| } |
| } |
| for (const string& input_node_name : inputs) { |
| // Note that input_node_name may be in <prefix><node_name>:<port_num> |
| // format, where <prefix> (e.g., "^" for control dependency) and |
| // ":<port_num>" may be omitted. NodeName() extracts only the node_name. |
| const NodeDef* input_node = name_to_node[NodeName(input_node_name)]; |
| |
| CHECK(input_node); |
| const string in_device = DeviceName(input_node); |
| const auto input_node_port_num = NodePosition(input_node_name); |
| |
| // Control dependencies should be treated as high priority. Current |
| // Channel device doesn't model a separate virual channel for control v/s |
| // data transfers. So in the interim, it may be okay to let control |
| // dependencies magically flow across devices bypassing the channel |
| // device. |
| if (curr_node_device == in_device || IsControlInput(input_node_name)) { |
| // Same device: connect input_node and curr_node directly. |
| curr_node_state.inputs.push_back( |
| std::make_pair(input_node, input_node_port_num)); |
| auto& input_node_state = GetNodeStateOrCreateIt(input_node); |
| input_node_state.outputs[input_node_port_num].push_back(curr_node); |
| } else { |
| RecvNodeDescriptor recv_node(input_node, input_node_port_num, |
| curr_node_device); |
| auto it = cached_recv_nodes.find(recv_node); |
| if (it != cached_recv_nodes.end()) { |
| // Different device, but found an already-cached copy (a _Recv op); |
| // connect the _Recv to curr_node. |
| const NodeDef* recv_op = it->second; |
| // recv_op's output port is hard-coded to zero. |
| curr_node_state.inputs.push_back(std::make_pair(recv_op, 0)); |
| auto& input_node_state = node_map_.at(recv_op); |
| input_node_state.outputs[0].push_back(curr_node); |
| } else { |
| // Different device, no cached copy; transfer input_node to the |
| // curr_node's device. |
| auto send_and_recv = |
| CreateSendRecv(input_node, curr_node, input_node, input_node_name, |
| create_explicit_channel_device); |
| // Note that CreateSendRecv() already connected input/output between |
| // _Send and _Recv ops. |
| const auto* send = send_and_recv.first; |
| const auto* recv = send_and_recv.second; |
| // recv_op's output port is hard-coded to zero. |
| curr_node_state.inputs.push_back(std::make_pair(recv, 0)); |
| auto& input_node_state = GetNodeStateOrCreateIt(input_node); |
| input_node_state.outputs[input_node_port_num].push_back(send); |
| |
| // Cache the _Recv op for future use. |
| cached_recv_nodes[recv_node] = recv; |
| } |
| } |
| } |
| |
| // Special case: given feed nodes are ready at time 0. |
| const bool given_as_feed = |
| feed_nodes.find(curr_node->name()) != feed_nodes.end(); |
| |
| // Default case: node without inputs are ready at time 0. |
| // Note that we check inputs vector which may be different to |
| // curr_node->input(); e.g., we add Send as input to Recv. |
| const bool has_no_inputs = inputs.empty(); |
| |
| if (given_as_feed || has_no_inputs) { |
| curr_node_state.time_ready = Costs::Duration(); |
| initial_nodes->push_back(curr_node); |
| VLOG(3) << "Added ready node: " << curr_node->name(); |
| } |
| feed_nodes.erase(curr_node->name()); |
| |
| if (IsPersistent(*curr_node)) { |
| auto& device_state = device_[curr_node_device]; |
| for (int port_num = 0, |
| port_num_end = curr_node_state.output_properties.size(); |
| port_num < port_num_end; ++port_num) { |
| device_state.persistent_nodes.insert( |
| std::make_pair(curr_node, port_num)); |
| } |
| } |
| } |
| |
| if (initial_nodes->empty()) { |
| return errors::InvalidArgument("No ready nodes in the graph."); |
| } |
| |
| if (!feed_nodes.empty()) { |
| // This isn't always a bug: when the caller hasn't specified the exact list |
| // of feed and fetch nodes, by default we consider all placeholders as feed |
| // nodes, but some of them may not be needed for the default fetch node. |
| VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: " |
| << absl::StrJoin(feed_nodes, ","); |
| } |
| |
| initialized_ = true; |
| return Status::OK(); |
| } |
| |
| void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) { |
| CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init()."; |
| // This method is called when NodeState is created and adds input and output |
| // properties for a few exceptional cases that GraphProperties cannot provide |
| // input/output properties. |
| if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) { |
| // _Send and _Recv ops created from SchedulerState have kAttrInputSrc |
| // attr; normal _Send and _Recv ops (from the input graph) do not have that |
| // attr. |
| auto& node_state = node_map_[node]; |
| auto& inputs = node_state.input_properties; |
| auto& outputs = node_state.output_properties; |
| |
| // _Send and _Recv ops are created from SchedulerState, so |
| // there should be no inputs TensorProperties. |
| CHECK(inputs.empty()); |
| CHECK(outputs.empty()); |
| const auto& attr = node->attr(); |
| // This is the original input source to the _Send and _Recv, and this |
| // string includes "^" if it was control dependency, and output port |
| /// (e.g., ":2") if the input source had multiple outputs. |
| const auto& input_source_name = attr.at(kAttrInputSrc).s(); |
| if (IsControlInput(input_source_name)) { |
| // Control dependency; regardless of the input source tensor size, |
| // send 4B. |
| OpInfo::TensorProperties control_message; |
| control_message.set_dtype(DT_FLOAT); |
| control_message.mutable_shape()->add_dim()->set_size(1); |
| auto* value = control_message.mutable_value(); |
| value->add_float_val(1); |
| inputs.push_back(control_message); |
| outputs.push_back(control_message); |
| } else { |
| const auto& output_properties = |
| graph_properties_->GetOutputProperties(NodeName(input_source_name)); |
| // Like with HasInputProperties, if a node does not have output |
| // properties, it's likely it was pruned during the shape inference run. |
| if (!output_properties.empty()) { |
| const auto input_node_port_num = NodePosition(input_source_name); |
| // Use the input source's output property as _Send and _Recv's input |
| // property. |
| CHECK_GT(output_properties.size(), input_node_port_num); |
| inputs.push_back(output_properties[input_node_port_num]); |
| outputs.push_back(output_properties[input_node_port_num]); |
| } |
| } |
| } |
| } |
| |
| string SchedulerState::DeviceName(const NodeDef* node) const { |
| return placer_->get_canonical_device_name(*node); |
| } |
| |
| string SchedulerState::SanitizedDeviceName(const NodeDef* node) const { |
| // Replace the ":" characters that may be present in the device name with "_". |
| // This makes it possible to then use the resulting string in a node name. |
| return absl::StrReplaceAll(placer_->get_canonical_device_name(*node), |
| {{":", "_"}}); |
| } |
| |
| string SchedulerState::ChannelDeviceName(const NodeDef* from, |
| const NodeDef* to) const { |
| CHECK(!initialized_) << "ChannelDeviceName is called after Init()."; |
| return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from), |
| "_to_", SanitizedDeviceName(to)); |
| } |
| |
| std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv( |
| const NodeDef* from, const NodeDef* to, const NodeDef* input_node, |
| const string& input_name, bool create_channel_device) { |
| CHECK(!initialized_) << "CreateSendRecv is called after Init()."; |
| |
| // Connect "from" node to "to" node with _Send and _Recv such that |
| // from -> _Send -> _Recv -> to. |
| // _Send is placed on "Channel" device, and _Recv is on the same device |
| // as "to" node. |
| // input_node_name is the string from the "to" node to identify which output |
| // we get from the "from" node. |
| |
| // Note that we use NodeState for scheduling, so _Send and _Recv |
| // NodeDefs created here need not be correct: in terms of name, |
| // input names, attrs, etc. |
| |
| auto input_node_port_num = NodePosition(input_name); |
| string src_name; |
| bool control_input = false; |
| if (input_node_port_num >= 0) { |
| src_name = absl::StrCat(from->name(), "_", input_node_port_num); |
| } else { |
| src_name = absl::StrCat(from->name(), "_minus1"); |
| control_input = true; |
| } |
| |
| // _Send op. |
| auto* send = new NodeDef(); |
| send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) + |
| "_to_" + SanitizedDeviceName(to)); |
| send->set_op("_Send"); |
| send->add_input(from->name()); |
| auto send_device = |
| create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from); |
| send->set_device(send_device); |
| auto& send_attr = *(send->mutable_attr()); |
| send_attr[kAttrInputSrc].set_s(input_name); |
| send_attr[kAttrSrcDevice].set_s(DeviceName(from)); |
| send_attr[kAttrDstDevice].set_s(DeviceName(to)); |
| // GraphDef generated by AutoGrappler has tensor_name field when removing |
| // _Send/_Recv nodes. |
| if (input_node->attr().count(kAttrTensorName)) { |
| send_attr[kAttrTensorName].set_s( |
| input_node->attr().at(kAttrTensorName).s()); |
| } |
| |
| // _Recv op. |
| auto* recv = new NodeDef(); |
| recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to)); |
| recv->set_op("_Recv"); |
| recv->add_input(send->name()); |
| recv->set_device(DeviceName(to)); |
| auto& recv_attr = *(recv->mutable_attr()); |
| recv_attr[kAttrInputSrc].set_s(input_name); |
| if (input_node->attr().count(kAttrTensorName)) { |
| recv_attr[kAttrTensorName].set_s( |
| input_node->attr().at(kAttrTensorName).s()); |
| } |
| |
| // Propagate the streaming attribute to the send/recv nodes. |
| if (from->attr().contains(kStreaming) && !control_input) { |
| if (input_node_port_num >= from->attr().at(kStreaming).list().b_size()) { |
| LOG(ERROR) |
| << from->name() |
| << " port index larger than length of _streaming attribute list."; |
| } else if (from->attr().at(kStreaming).list().b(input_node_port_num)) { |
| send_attr[kStreaming].mutable_list()->add_b(true); |
| recv_attr[kStreaming].mutable_list()->add_b(true); |
| } |
| } |
| |
| // NodeState for _Send op. |
| auto& send_node_state = GetNodeStateOrCreateIt(send); |
| send_node_state.device_name = send->device(); // Set Channel device. |
| send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num)); |
| send_node_state.outputs[0].push_back(recv); |
| |
| // NodeState for _Recv op. |
| auto& recv_node_state = GetNodeStateOrCreateIt(recv); |
| recv_node_state.inputs.push_back(std::make_pair(send, 0)); |
| recv_node_state.outputs[0].push_back(to); |
| |
| // Keep the created nodes. |
| additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send)); |
| additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv)); |
| |
| // Return _Send and _Recv. |
| return std::make_pair(send, recv); |
| } |
| |
| OpContext SchedulerState::CreateOpContext(const NodeDef* node) const { |
| // Get the device from the placer. |
| DeviceProperties device; |
| device = placer_->get_device(*node); |
| |
| // Special case for _Send op. |
| if (IsSend(*node)) { |
| device.set_type(kChannelDevice); |
| } |
| |
| // Construct OpContext. |
| OpContext op_context; |
| const auto& node_state = node_map_.at(node); |
| op_context.name = node->name(); |
| op_context.device_name = node_state.device_name; |
| auto& op_info = op_context.op_info; |
| op_info.set_op(node->op()); |
| *op_info.mutable_attr() = node->attr(); |
| for (auto& input : node_state.input_properties) { |
| *op_info.add_inputs() = input; |
| } |
| for (auto& output : node_state.output_properties) { |
| *op_info.add_outputs() = output; |
| } |
| op_info.mutable_device()->Swap(&device); |
| |
| if (grappler_item_->graph.has_library()) { |
| op_context.function_library = &grappler_item_->graph.library(); |
| } |
| return op_context; |
| } |
| |
| NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) { |
| CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init()."; |
| |
| auto it = node_map_.find(node); |
| if (it != node_map_.end()) { |
| return it->second; |
| } |
| |
| // Not found; create a NodeState for this node. |
| it = node_map_.emplace(node, NodeState()).first; |
| auto& node_state = it->second; |
| node_state.input_properties = |
| graph_properties_->GetInputProperties(node->name()); |
| node_state.output_properties = |
| graph_properties_->GetOutputProperties(node->name()); |
| node_state.shape_incompatible = |
| graph_properties_->CheckShapeIncompatible(node->name()); |
| |
| // Some ops may need further processing to the input / output properties: |
| // _Send and _Recv. |
| MaybeUpdateInputOutput(node); |
| |
| if (!IsSend(*node)) { |
| node_state.device_name = DeviceName(node); |
| // For _Send op, device_name will be set to Channel in CreateSendRecv(). |
| } |
| |
| // Initialize output port related data: |
| // Assume the size of OutputProperties represents the number of output ports |
| // of this node. |
| for (size_t i = 0; i < node_state.output_properties.size(); ++i) { |
| node_state.time_no_references[i] = Costs::Duration::max(); |
| node_state.num_outputs_executed[i] = 0; |
| // Populate an empty vector for each port. The caller will add nodes |
| // that use this port as input. |
| node_state.outputs[i] = {}; |
| } |
| // Port_num -1 is for control dependency. |
| node_state.time_no_references[-1] = Costs::Duration::max(); |
| node_state.num_outputs_executed[-1] = 0; |
| node_state.outputs[-1] = {}; |
| |
| // Initialize time_scheduled to infinity, so we know whether it has been |
| // assigned a non-default value later. |
| node_state.time_scheduled = Costs::Duration().infinity(); |
| |
| return it->second; |
| } |
| |
| void SchedulerState::GetOutputNodes(const NodeDef* node, |
| const Costs::Duration& curr_time, |
| std::vector<const NodeDef*>* output_nodes) { |
| // Checks whether the Switch's output slots change over iterations. |
| int slot = -1; |
| if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 && |
| node->attr().at(kOutputSlots).list().i_size() > 0) { |
| slot = node->attr().at(kOutputSlots).list().i(0); |
| for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) { |
| if (slot != node->attr().at(kOutputSlots).list().i(i)) { |
| slot = -1; |
| break; |
| } |
| } |
| } |
| // Increment num_inputs_ready of the output nodes and maybe add to ready |
| // nodes. |
| auto& node_state = node_map_[node]; |
| for (const auto& port_num_output_pair : node_state.outputs) { |
| // If Switch is annotated and its output slots are always the same, we only |
| // schedule the slot that was executed. Otherwise, scheduler both slots. |
| if (slot >= 0 && port_num_output_pair.first != slot) continue; |
| |
| for (auto* output_node : port_num_output_pair.second) { |
| auto& output_state = node_map_[output_node]; |
| output_state.num_inputs_ready++; |
| // Execute a node as soon as all its inputs are ready. Merge nodes are |
| // special since they run as soon as one of their inputs becomes |
| // available. |
| int output_state_inputs_size = output_state.inputs.size(); |
| if (output_state.num_inputs_ready == output_state_inputs_size || |
| IsMerge(*output_node)) { |
| // This output node is now ready. |
| output_state.time_ready = curr_time; |
| output_nodes->push_back(output_node); |
| VLOG(3) << " Add output: " << output_node->name(); |
| } |
| } |
| } |
| } |
| |
| std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted( |
| const NodeDef* node, const Costs& node_costs, const OpContext& op_context) { |
| auto& node_state = node_map_[node]; |
| // TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and |
| // Merge -- it can create a loop which may include loop-carried dependency, |
| // diverge-merge, and other complex execution patterns. |
| bool previously_executed_merge = |
| IsMerge(*node) && (node_state.time_finished != Costs::Duration::max()); |
| |
| // If there is annotation in the graph about execution times, we use that |
| // number, otherwise, we assume the node is executed once. |
| node_state.execution_count = node->attr().count(kExecutionCount) == 0 |
| ? 1 |
| : node->attr().at(kExecutionCount).i(); |
| |
| node_state.node_costs = node_costs; |
| // TotalNodeCosts() Should be called after node_costs and execution_count. |
| Costs total_node_costs = node_state.TotalNodeCosts(); |
| |
| graph_costs_ = CombineCosts(graph_costs_, total_node_costs); |
| const string& op_name = node->op(); |
| |
| auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_); |
| op_cost = CombineCosts(op_cost, total_node_costs); |
| |
| if (VLOG_IS_ON(2)) { |
| // Also keep track of op counts and costs per op (with their shapes). |
| string node_description = GetOpDescription(op_context.op_info); |
| op_counts_[node_description] += 1; |
| op_costs_[node_description] = |
| std::make_pair(total_node_costs.execution_time.asMicroSeconds().count(), |
| !node_costs.inaccurate); |
| } |
| |
| // Update node and device states. |
| auto& device = device_[node_state.device_name]; |
| device.nodes_executed.push_back(node); |
| // Node is scheduled when the device is available AND all the inputs are |
| // ready; hence, time_scheduled is time_ready if time_ready > device curr |
| // time. |
| // NodeState times are assigned infinity at initialization. If they are |
| // still infinity here, we need to assign them. If not, it has been assigned |
| // already, so skip. This latter case may occur when a scheduler in-lines |
| // function calls, and thus schedules only function sub-nodes. |
| if (node_state.time_scheduled == Costs::Duration().infinity()) { |
| node_state.time_scheduled = |
| std::max(device.GetCurrTime(), node_state.time_ready); |
| // Override device curr time with the time_scheduled. |
| device.device_costs.execution_time = node_state.time_scheduled; |
| } |
| device.device_costs = CombineCosts(device.device_costs, total_node_costs); |
| auto curr_time = device.GetCurrTime(); |
| node_state.time_finished = curr_time; |
| |
| // Update shape annotation states. |
| UpdateDeviceAnnotationState(node, node_state, &device); |
| |
| // Update device memory usage. |
| if (!IsPersistent(*node)) { |
| for (const auto& port_num_output_pair : node_state.outputs) { |
| int port_num = port_num_output_pair.first; |
| |
| // There's a chance that a specific output is not used at all. |
| if (node_state.outputs[port_num].empty()) { |
| node_state.time_no_references[port_num] = curr_time; |
| } else { |
| // Streaming outputs do not allocate memory, they are directly consumed |
| // by the target node. |
| if (!IsStreamingPort(*node, port_num)) { |
| device.memory_usage += |
| CalculateOutputSize(node_state.output_properties, port_num) * |
| node_state.execution_count; |
| } |
| device.nodes_in_memory.insert(std::make_pair(node, port_num)); |
| } |
| } |
| } |
| |
| // Update device's per-op cost. |
| auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost); |
| device_op_cost = CombineCosts(device_op_cost, total_node_costs); |
| |
| VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op() |
| << ", device: " << node->device() |
| << ", execution_count: " << node_state.execution_count |
| << ", ready: " << node_state.time_ready.count() |
| << ", scheduled: " << node_state.time_scheduled.count() |
| << ", finished: " << node_state.time_finished.count(); |
| std::vector<const NodeDef*> new_nodes; |
| if (previously_executed_merge) { |
| // Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge. |
| VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] " |
| << "is executed more than once. " |
| << "Skip scheduling its output nodes."; |
| } else { |
| // Checks outputs, and adds ready nodes to queue. |
| GetOutputNodes(node, curr_time, &new_nodes); |
| } |
| |
| // When op is scheduled, both input and output tensors must be allocated in |
| // memory. Now that output memory is added, check max memory usage. |
| if (!IsPersistent(*node)) { |
| if (device.memory_usage > device.max_memory_usage) { |
| device.max_memory_usage = device.memory_usage; |
| |
| if (track_mem_usage_snapshot_) { |
| device.mem_usage_snapshot_at_peak = device.nodes_in_memory; |
| } |
| } |
| } |
| |
| // Increment num_outputs_executed of the input nodes and maybe update memory. |
| for (const auto& input_port : node_state.inputs) { |
| auto* input = input_port.first; |
| auto port = input_port.second; |
| |
| auto& input_state = node_map_[input]; |
| input_state.num_outputs_executed[port]++; |
| int input_state_outputs_size_ = input_state.outputs[port].size(); |
| if (input_state.num_outputs_executed[port] == input_state_outputs_size_ && |
| !IsPersistent(*input)) { |
| // All the outputs are executed; no reference to this output port of |
| // input node. |
| input_state.time_no_references[port] = curr_time; |
| auto& input_device = device_[input_state.device_name]; |
| // If the node input is marked as streaming, then it wasn't allocated |
| // in memory. A streaming input is still reference counted, but it doesn't |
| // de-allocate memory. |
| if (!IsStreamingPort(*input, port)) { |
| input_device.memory_usage -= |
| CalculateOutputSize(input_state.output_properties, port) * |
| node_state.execution_count; |
| } |
| |
| input_device.nodes_in_memory.erase(std::make_pair(input, port)); |
| } |
| } |
| |
| return new_nodes; |
| } |
| |
| Costs SchedulerState::Summary() const { |
| // Overall statement about accuracy |
| VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with " |
| << graph_costs_.num_ops_with_unknown_shapes |
| << " having unknown shapes"; |
| |
| // Print out basic execution summary. |
| VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count(); |
| VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count(); |
| VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count(); |
| VLOG(1) << "Expected intermediate memory time: " |
| << graph_costs_.intermediate_memory_time.count(); |
| VLOG(1) << "Expected max memory: " << graph_costs_.max_memory; |
| VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers; |
| VLOG(1) << "Expected max per-op streaming buffers: " |
| << graph_costs_.max_per_op_streaming; |
| |
| VLOG(1) << "Per-op execution time / compute time / memory time" |
| << " / intermediate memory time:"; |
| for (const auto& op_cost_pair : op_to_cost_) { |
| const auto& op = op_cost_pair.first; |
| const auto& cost = op_cost_pair.second.execution_time.count(); |
| const auto& compute_cost = op_cost_pair.second.compute_time.count(); |
| const auto& memory_cost = op_cost_pair.second.memory_time.count(); |
| const auto& intermediate_memory_cost = |
| op_cost_pair.second.intermediate_memory_time.count(); |
| const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; |
| if (cost) { // Skip printing out zero-cost ops. |
| VLOG(1) << absl::StrFormat(" + %30s : %c %10d / %10d / %10d / %10d", op, |
| (is_op_cost_accurate ? ' ' : '~'), cost, |
| compute_cost, memory_cost, |
| intermediate_memory_cost); |
| } |
| } |
| |
| // Print per device summary |
| VLOG(1) << "Devices:"; |
| Costs critical_path_costs = Costs::ZeroCosts(); |
| std::vector<string> device_names; |
| device_names.reserve(device_.size()); |
| for (auto& it : device_) { |
| device_names.push_back(it.first); |
| } |
| std::sort(device_names.begin(), device_names.end()); |
| |
| for (const auto& name : device_names) { |
| const auto& state = device_.at(name); |
| |
| std::map<string, int64> op_to_memory; |
| // First profile only persistent memory usage. |
| int64 persistent_memory_usage = 0; |
| std::set<string> persistent_ops; |
| for (const auto& node_port : state.persistent_nodes) { |
| const auto* node = node_port.first; |
| const auto port = node_port.second; |
| auto output_size = 0; |
| // Check if the node is in the node_map. It may be that the node executed |
| // on this device was executed by a different Scheduler. |
| if (node_map_.find(node) != node_map_.end()) { |
| output_size = |
| CalculateOutputSize(node_map_.at(node).output_properties, port); |
| } |
| persistent_memory_usage += output_size; |
| op_to_memory[node->op()] += output_size; |
| persistent_ops.insert(node->op()); |
| } |
| int64 max_memory_usage = persistent_memory_usage + state.max_memory_usage; |
| critical_path_costs.estimated_max_memory_per_device[name] = |
| max_memory_usage; |
| |
| const Costs::NanoSeconds wall_time_ns = state.GetCurrTime(); |
| VLOG(1) << "Device = " << name |
| << ", num_nodes = " << state.nodes_executed.size() |
| << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: " |
| << "persistent = " << HumanReadableNumBytes(persistent_memory_usage) |
| << ", peak = " << HumanReadableNumBytes(state.max_memory_usage) |
| << ", total = " << HumanReadableNumBytes(max_memory_usage) |
| << ", at the end: " << HumanReadableNumBytes(state.memory_usage); |
| |
| // Overall statement about accuracy |
| VLOG(1) << state.device_costs.num_ops_total |
| << " ops processed in total, with " |
| << state.device_costs.num_ops_with_unknown_shapes |
| << " having unknown shapes"; |
| |
| // Device shape annotation statistics. |
| const auto& device_annotation_stats = state.shape_annotation_stats; |
| if (device_annotation_stats.num_ops_annotated > 0) { |
| VLOG(1) << device_annotation_stats.num_ops_annotated |
| << " ops with shape annotation, with " |
| << device_annotation_stats.num_ops_executed_more_than_once |
| << " executed more than once, " |
| << device_annotation_stats.num_ops_with_dynamic_shapes |
| << " with dynamic shapes, " |
| << device_annotation_stats.num_ops_with_incompatible_shapes |
| << " with incompatible shapes, " |
| << device_annotation_stats.num_ops_executed |
| << " ops executed in total."; |
| } |
| |
| VLOG(1) << "Per-op execution time / compute time / memory time " |
| << " / intermediate memory time" |
| << " (and memory usage at peak memory usage):"; |
| |
| // Profile non-persistent op memory usage. |
| for (const auto& node_port : state.mem_usage_snapshot_at_peak) { |
| const auto* node = node_port.first; |
| const auto port = node_port.second; |
| // Check if the node is in the node_map. It may be that the node executed |
| // on this device was executed by a different Scheduler. |
| if (node_map_.find(node) != node_map_.end()) { |
| op_to_memory[node->op()] += |
| CalculateOutputSize(node_map_.at(node).output_properties, port); |
| } |
| } |
| Costs::NanoSeconds total_compute_time_ns; |
| bool is_total_cost_accurate = true; |
| for (const auto& op_cost_pair : state.op_to_cost) { |
| const auto& op = op_cost_pair.first; |
| const auto& cost = op_cost_pair.second.execution_time.count(); |
| const auto& compute_cost = op_cost_pair.second.compute_time.count(); |
| const auto& memory_cost = op_cost_pair.second.memory_time.count(); |
| const auto& intermediate_memory_cost = |
| op_cost_pair.second.intermediate_memory_time.count(); |
| total_compute_time_ns += op_cost_pair.second.execution_time; |
| const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate; |
| if (!is_op_cost_accurate) { |
| is_total_cost_accurate = false; |
| } |
| |
| int64 op_mem_usage = 0; |
| auto it = op_to_memory.find(op); |
| if (it != op_to_memory.end()) { |
| op_mem_usage = it->second; |
| } |
| |
| const float mem_usage_percent = |
| max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage) |
| : 0.0; |
| if (cost || mem_usage_percent > 1.0) { |
| // Print out only non-zero cost ops or ops with > 1% memory usage. |
| VLOG(1) << absl::StrFormat( |
| " + %30s : %c %10d / %10d / %10d / %10d", op.c_str(), |
| (is_op_cost_accurate ? ' ' : '~'), cost, compute_cost, |
| memory_cost, intermediate_memory_cost) |
| << " (" << HumanReadableNumBytes(op_mem_usage) << " [" |
| << mem_usage_percent << "%] " |
| << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")"); |
| } |
| } |
| |
| int utilization = 0; |
| if (wall_time_ns.count() > 0) { |
| utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count(); |
| } |
| VLOG(1) << "Device = " << name << ", total_compute_time_ns = " |
| << (is_total_cost_accurate ? "" : "~") |
| << total_compute_time_ns.count() |
| << ", utilization = " << utilization << "%"; |
| |
| if (critical_path_costs.execution_time <= state.GetCurrTime()) { |
| critical_path_costs = state.device_costs; |
| } |
| } |
| |
| if (VLOG_IS_ON(2)) { |
| // Also log the op description and their corresponding counts. |
| VLOG(2) << "Node description, counts, cost:"; |
| for (const auto& item : op_counts_) { |
| int cost; |
| bool is_cost_accurate; |
| std::tie(cost, is_cost_accurate) = op_costs_.at(item.first); |
| VLOG(2) << "Node: " << item.first << ", Count: " << item.second |
| << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost |
| << " us"; |
| } |
| } |
| |
| VLOG(1) << "Critical path execution time: " |
| << critical_path_costs.execution_time.count(); |
| return critical_path_costs; |
| } |
| |
| Costs SchedulerState::Summary(RunMetadata* metadata) { |
| if (metadata) GenerateRunMetadata(metadata); |
| return Summary(); |
| } |
| |
| void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) { |
| // Fill RunMetadata's step_stats and partition_graphs fields. |
| StepStats* stepstats = metadata->mutable_step_stats(); |
| for (const auto& device : device_) { |
| GraphDef* device_partition_graph = metadata->add_partition_graphs(); |
| DeviceStepStats* device_stepstats = stepstats->add_dev_stats(); |
| device_stepstats->set_device(device.first); |
| for (const auto& node_def : device.second.nodes_executed) { |
| // Only proceed if the node is in the node_map. This is to cover the case |
| // where a device has executed a node that is not in the node_map of |
| // this scheduler. |
| if (node_map_.find(node_def) == node_map_.end()) { |
| continue; |
| } |
| const NodeState& nodestate = node_map_.at(node_def); |
| NodeExecStats* node_stats = device_stepstats->add_node_stats(); |
| uint64 total_output_size = 0; |
| for (int slot = 0, slot_end = nodestate.output_properties.size(); |
| slot < slot_end; slot++) { |
| const auto& properties = nodestate.output_properties[slot]; |
| NodeOutput* no = node_stats->add_output(); |
| no->set_slot(slot); |
| TensorDescription* tensor_descr = no->mutable_tensor_description(); |
| tensor_descr->set_dtype(properties.dtype()); |
| *tensor_descr->mutable_shape() = properties.shape(); |
| // Optional allocation description. |
| const auto tensor_size = |
| CalculateOutputSize(nodestate.output_properties, slot); |
| total_output_size += tensor_size; |
| tensor_descr->mutable_allocation_description()->set_requested_bytes( |
| tensor_size); |
| tensor_descr->mutable_allocation_description()->set_allocated_bytes( |
| tensor_size); |
| } |
| if (node_def->op() != "HloGenericOp") { |
| node_stats->set_timeline_label(node_def->op()); |
| } else { |
| // For HloGenericOp, display hlo_opcode as timeline label. |
| string timeline_label; |
| if (node_def->attr().count("hlo_opcode") > 0) { |
| absl::StrAppend(&timeline_label, |
| node_def->attr().at("hlo_opcode").s()); |
| } |
| if (node_def->attr().count("_hlo_metadata_op_type") > 0) { |
| absl::StrAppend(&timeline_label, "/", |
| node_def->attr().at("_hlo_metadata_op_type").s()); |
| } |
| node_stats->set_timeline_label(timeline_label); |
| } |
| node_stats->set_node_name(node_def->name()); |
| // Timestamps in microseconds (can be used by timeline_server). |
| node_stats->set_op_start_rel_micros(0); |
| node_stats->set_all_start_micros( |
| nodestate.time_scheduled.asMicroSeconds().count()); |
| node_stats->set_op_end_rel_micros( |
| nodestate.time_finished.asMicroSeconds().count() - |
| nodestate.time_scheduled.asMicroSeconds().count()); |
| node_stats->set_all_end_rel_micros( |
| nodestate.time_finished.asMicroSeconds().count() - |
| nodestate.time_scheduled.asMicroSeconds().count()); |
| // Timestamps in nanoseconds (can be used by xprof trace). |
| node_stats->set_op_start_rel_nanos(0); |
| node_stats->set_all_start_nanos(nodestate.time_scheduled.count()); |
| node_stats->set_op_end_rel_nanos(nodestate.time_finished.count() - |
| nodestate.time_scheduled.count()); |
| node_stats->set_all_end_rel_nanos(nodestate.time_finished.count() - |
| nodestate.time_scheduled.count()); |
| |
| auto* mem_stats = node_stats->mutable_memory_stats(); |
| // SchedulerState does not specify scratch pad memory usage. |
| mem_stats->set_temp_memory_size(0); |
| int64 persistent_memory_size = 0; |
| if (IsPersistent(*node_def)) { |
| persistent_memory_size = total_output_size; |
| } |
| mem_stats->set_persistent_memory_size(persistent_memory_size); |
| *device_partition_graph->add_node() = *node_def; |
| } |
| } |
| } |
| |
| const std::unordered_map<string, int64> SchedulerState::GetPeakMemoryUsage() |
| const { |
| std::unordered_map<string, int64> result; |
| for (const auto& device : device_) { |
| const string& name = device.first; |
| const DeviceState& state = device.second; |
| result[name] = state.max_memory_usage; |
| } |
| return result; |
| } |
| |
| const std::unordered_map<string, int64> |
| SchedulerState::GetPersistentMemoryUsage() const { |
| std::unordered_map<string, int64> result; |
| for (const auto& device : device_) { |
| const string& name = device.first; |
| const DeviceState& state = device.second; |
| int64 persistent_memory_usage = 0; |
| for (const auto& node_port : state.persistent_nodes) { |
| const auto* node = node_port.first; |
| const auto port = node_port.second; |
| const auto output_size = |
| CalculateOutputSize(node_map_.at(node).output_properties, port); |
| persistent_memory_usage += output_size; |
| } |
| result[name] = persistent_memory_usage; |
| } |
| return result; |
| } |
| |
| void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) { |
| auto& node_state = node_map_.at(node); |
| auto& device = device_[node_state.device_name]; |
| node_state.time_scheduled = device.GetCurrTime(); |
| } |
| |
| VirtualScheduler::~VirtualScheduler() {} |
| |
| VirtualScheduler::VirtualScheduler(const bool use_static_shapes, |
| const bool use_aggressive_shape_inference, |
| Cluster* cluster, |
| ReadyNodeManager* ready_nodes, |
| std::unique_ptr<VirtualPlacer> placer) |
| : scheduler_state_(absl::make_unique<SchedulerState>( |
| use_static_shapes, use_aggressive_shape_inference, cluster, |
| std::move(placer))), |
| ready_nodes_(ready_nodes) {} |
| |
| VirtualScheduler::VirtualScheduler( |
| ReadyNodeManager* ready_nodes, |
| std::unique_ptr<SchedulerState> scheduler_state) |
| : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {} |
| |
| Status VirtualScheduler::Init(const GrapplerItem* item) { |
| // SchedulerState::Init() preprocesses the input grappler_item and |
| // graph_properties to extract necessary information for emulating tensorflow |
| // op scheduling and construct internal data structures (NodeState and |
| // DeviceState) for virtual scheduling. |
| TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates())); |
| std::vector<const NodeDef*> initial_nodes; |
| auto status = scheduler_state_->Init(item, &initial_nodes); |
| if (status.ok()) { |
| // Add the set of initial nodes to ready_nodes_ |
| for (auto node : initial_nodes) { |
| ready_nodes_->AddNode(node); |
| } |
| } |
| return status; |
| } |
| |
| OpContext VirtualScheduler::GetCurrNode() { |
| const NodeDef* node = ready_nodes_->GetCurrNode(); |
| return scheduler_state_->CreateOpContext(node); |
| } |
| |
| bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { |
| // Update graph_costs_ and per-op costs. |
| const NodeDef* node = ready_nodes_->GetCurrNode(); |
| auto new_nodes = scheduler_state_->MarkNodeExecuted( |
| node, node_costs, |
| scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode())); |
| // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_. |
| for (auto node : new_nodes) { |
| ready_nodes_->AddNode(node); |
| } |
| ready_nodes_->RemoveCurrNode(); |
| return !ready_nodes_->Empty(); |
| } |
| |
| } // end namespace grappler |
| } // end namespace tensorflow |