/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/utils/group_events.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
namespace {
// Creates stat metadata for the stats which may be added by grouping.
void CreateStatMetadata(XPlane* plane) {
XPlaneBuilder builder(plane);
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
}
// Returns the event type if the event has a correlation_id stat, i.e., it is
// a KernelLaunch (host) or KernelExecute (device) event.
absl::optional<int64_t> GetKernelEventType(bool is_host_plane,
const EventNode& event) {
if (event.GetEventVisitor().GetStat(StatType::kCorrelationId).has_value()) {
return is_host_plane ? HostEventType::kKernelLaunch
: HostEventType::kKernelExecute;
}
return absl::nullopt;
}
int64_t GetEventType(bool is_host_plane, const EventNode& event) {
if (absl::optional<int64_t> event_type = event.GetEventVisitor().Type()) {
return *event_type;
} else if (absl::optional<int64_t> kernel_event_type =
GetKernelEventType(is_host_plane, event)) {
// KernelLaunch and KernelExecute event types are not supported by
// XPlaneVisitor and should be checked separately.
// TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
// KernelExecute event types.
return *kernel_event_type;
} else {
return HostEventType::kUnknownHostEventType;
}
}
bool IsLegacyProducerEvent(const XEventVisitor& event) {
static const auto* const kProducerEvents = new absl::flat_hash_set<int64_t>{
HostEventType::kTraceContext, HostEventType::kFunctionRun,
HostEventType::kSessionRun, HostEventType::kRunGraph};
return event.Type().has_value() && kProducerEvents->contains(*event.Type());
}
bool IsLegacyConsumerEvent(const XEventVisitor& event) {
static const auto* const kConsumerEvents = new absl::flat_hash_set<int64_t>{
HostEventType::kExecutorStateProcess,
HostEventType::kExecutorDoneCallback, HostEventType::kRunGraphDone};
return event.Type().has_value() && kConsumerEvents->contains(*event.Type());
}
bool IsLegacyRootEvent(const XEventVisitor& event) {
static const auto* const kRootEvents = new absl::flat_hash_set<int64_t>{
HostEventType::kTraceContext, HostEventType::kFunctionRun,
HostEventType::kSessionRun, HostEventType::kRunGraph};
return event.Type().has_value() && kRootEvents->contains(*event.Type());
}
// Stats used in ConnectIntraThread.
struct GroupingEventStats {
explicit GroupingEventStats(const XEventVisitor& event);
std::optional<int> producer_type;
std::optional<uint64_t> producer_id;
std::optional<int> consumer_type;
std::optional<uint64_t> consumer_id;
std::optional<int> root_level;
bool is_async = false;
};
GroupingEventStats::GroupingEventStats(const XEventVisitor& event) {
std::optional<int64_t> step_id;
event.ForEachStat([&](const XStatVisitor& stat) {
if (!stat.Type().has_value()) return;
switch (*stat.Type()) {
case StatType::kProducerType:
producer_type = stat.IntValue();
break;
case StatType::kProducerId:
producer_id = stat.IntOrUintValue();
break;
case StatType::kConsumerType:
consumer_type = stat.IntValue();
break;
case StatType::kConsumerId:
consumer_id = stat.IntOrUintValue();
break;
case StatType::kIsRoot:
root_level = stat.IntValue();
break;
case StatType::kIsAsync:
is_async = stat.BoolValue();
break;
case StatType::kStepId:
step_id = stat.IntValue();
break;
default:
break;
}
});
if (!producer_type.has_value() || !producer_id.has_value()) {
if (step_id.has_value() && IsLegacyProducerEvent(event)) {
producer_type = static_cast<int>(ContextType::kTfExecutor);
producer_id = *step_id;
}
}
if (!consumer_type.has_value() || !consumer_id.has_value()) {
if (step_id.has_value() && IsLegacyConsumerEvent(event)) {
consumer_type = static_cast<int>(ContextType::kTfExecutor);
consumer_id = *step_id;
}
}
if (!root_level.has_value() && IsLegacyRootEvent(event)) {
root_level = 1;
}
}
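// Illustrative example of the legacy fallback above (hypothetical stat
// values): a legacy SessionRun event carrying only step_id = 42 and no
// explicit producer/consumer stats resolves to
//   producer_type = ContextType::kTfExecutor, producer_id = 42,
//   root_level = 1,
// because SessionRun is both a legacy producer event and a legacy root event.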
void SetContextGroup(const GroupingEventStats& stats, EventNode* event,
ContextGroupMap* context_groups) {
if (stats.producer_type.has_value() && stats.producer_id.has_value()) {
((*context_groups)[*stats.producer_type][*stats.producer_id])
.producers.push_back(event);
}
if (stats.consumer_type.has_value() && stats.consumer_id.has_value()) {
((*context_groups)[*stats.consumer_type][*stats.consumer_id])
.consumers.push_back(event);
}
}
void ConnectContextGroups(const ContextGroupMap& context_groups) {
for (auto& type_id_group : context_groups) {
for (auto& id_group : type_id_group.second) {
const ContextGroup& group = id_group.second;
for (EventNode* parent : group.producers) {
for (EventNode* child : group.consumers) {
parent->AddChild(child);
}
}
}
}
}
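// Sketch of the `context_groups` shape consumed above (ContextGroup and
// ContextGroupMap are defined in group_events.h; the ids are hypothetical):
//   context_groups[/*type=*/kTfExecutor][/*id=*/42].producers = {SessionRun};
//   context_groups[/*type=*/kTfExecutor][/*id=*/42].consumers =
//       {ExecutorState::Process events with the same step_id};
// ConnectContextGroups then makes every consumer in a group a child of every
// producer in the same group.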
bool HasFunctionRun(EventNode* event_node) {
for (EventNode* child : event_node->GetChildren()) {
if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
return true;
}
}
return false;
}
bool IsImplicitRootEvent(const XEventVisitor& event) {
static const auto* const kImplicitRootEvents =
new absl::flat_hash_set<int64_t>{
HostEventType::kFunctionRun, HostEventType::kSessionRun,
HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
return event.Type().has_value() &&
kImplicitRootEvents->contains(*event.Type());
}
void ProcessRootEvent(int64_t group_id, EventNode* root_event,
GroupMetadataMap* group_metadata_map) {
root_event->PropagateGroupId(group_id, group_metadata_map);
std::string group_name = root_event->GetGroupName();
// TODO(b/160255693): Change the event name instead.
if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
// Add the `step_name` stat for the user-defined root events only. When an
// XEvent is converted to a trace event, the trace event name is set to the
// `step_name` stat's value if present.
root_event->AddStepName(group_name);
}
(*group_metadata_map)[group_id].name = std::move(group_name);
}
using Comparator = std::function<bool(const EventNode*)>;
const EventNode* FindParentWithComparator(const Comparator& comparator,
const EventNode* node,
bool include_self) {
std::queue<const EventNode*> nodes;
absl::flat_hash_set<const EventNode*> seen = {node};
if (include_self) {
nodes.push(node);
} else {
for (const EventNode* parent : node->GetParents()) {
nodes.push(parent);
seen.insert(parent);
}
}
while (!nodes.empty()) {
const EventNode* node = nodes.front();
nodes.pop();
if (comparator(node)) return node;
for (const EventNode* parent : node->GetParents()) {
if (seen.contains(parent)) continue;
nodes.push(parent);
seen.insert(parent);
}
}
return nullptr;
}
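// Example usage (a sketch; see EventNode::FindParent below for the real call
// site): find the closest ancestor, optionally including the node itself,
// that has a given event type:
//   const EventNode* parent = FindParentWithComparator(
//       [](const EventNode* node) {
//         return node->GetEventVisitor().Type() ==
//                HostEventType::kEagerKernelExecute;
//       },
//       node, /*include_self=*/true);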
// Returns true if the event node map contains JAX-related events.
bool HasJaxEvent(const EventNodeMap& event_node_map) {
return event_node_map.contains(HostEventType::kExecuteOnLocalDevices);
}
bool IsIteratorEventType(absl::optional<int64_t> event_type) {
return event_type == HostEventType::kIterator ||
event_type == HostEventType::kDeviceInputPipelineSecondIterator;
}
} // namespace
// Returns true if TF's loop ops exist in the given XSpace's metadata.
bool CheckLoopOp(const XSpace& space) {
for (const XPlane& plane : space.planes()) {
for (const auto& event_metadata : plane.event_metadata()) {
absl::optional<int64_t> event_type =
FindHostEventType(event_metadata.second.name());
if (!event_type.has_value()) continue;
switch (*event_type) {
case HostEventType::kWhileOpEvalCond:
case HostEventType::kWhileOpStartBody:
case HostEventType::kForOp:
case HostEventType::kParallelForOp:
case HostEventType::kForeverOp:
return true;
default:
break;
}
}
}
return false;
}
absl::optional<XStatVisitor> EventNode::GetContextStat(
int64_t stat_type) const {
std::queue<const EventNode*> nodes;
absl::flat_hash_set<const EventNode*> seen = {this};
nodes.push(this);
while (!nodes.empty()) {
const EventNode* node = nodes.front();
nodes.pop();
if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
return stat;
}
for (const EventNode* parent : node->GetParents()) {
if (seen.contains(parent)) continue;
nodes.push(parent);
seen.insert(parent);
}
}
return absl::nullopt;
}
std::string EventNode::GetGroupName() const {
std::string name;
if (absl::optional<XStatVisitor> stat =
GetContextStat(StatType::kGraphType)) {
absl::StrAppend(&name, stat->StrOrRefValue(), " ");
} else if (!(IsImplicitRootEvent(visitor_))) {
absl::StrAppend(&name, GetEventVisitor().Name(), " ");
}
int64_t step_num = group_id_.value_or(0);
if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
step_num = stat->IntValue();
} else if (absl::optional<XStatVisitor> stat =
GetContextStat(StatType::kStepNum)) {
step_num = stat->IntValue();
}
absl::StrAppend(&name, step_num);
return name;
}
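// Illustrative group names produced above (hypothetical stat values): a root
// event with graph_type="train" and step_num=12 yields "train 12"; an
// implicit root event with no graph_type and no iter_num/step_num stats falls
// back to its group id, e.g. "3".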
XStat* EventNode::FindOrAddStatByType(int64_t stat_type) {
const XPlaneVisitor& plane = visitor_.Plane();
const XStatMetadata* stat_metadata = plane.GetStatMetadataByType(stat_type);
DCHECK(stat_metadata != nullptr);
auto* raw_event = const_cast<XEvent*>(&visitor_.RawEvent()); // NOLINT
return FindOrAddMutableStat(*stat_metadata, raw_event);
}
void EventNode::SetGroupId(int64_t group_id) {
group_id_ = group_id;
FindOrAddStatByType(StatType::kGroupId)->set_int64_value(group_id);
}
void EventNode::PropagateGroupId(int64_t group_id,
GroupMetadataMap* group_metadata_map) {
std::queue<EventNode*> nodes;
absl::flat_hash_set<EventNode*> seen = {this};
nodes.push(this);
while (!nodes.empty()) {
EventNode* node = nodes.front();
nodes.pop();
absl::optional<int64_t> node_group_id = node->GetGroupId();
if (node_group_id.has_value()) {
if (*node_group_id != group_id) {
(*group_metadata_map)[group_id].children.insert(*node_group_id);
(*group_metadata_map)[*node_group_id].parents.insert(group_id);
}
} else {
node->SetGroupId(group_id);
for (EventNode* child : node->GetChildren()) {
if (seen.contains(child)) continue;
nodes.push(child);
seen.insert(child);
}
}
}
}
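// Illustrative example for PropagateGroupId: propagating group id 1 into a
// subtree that already contains a node of group 0 stops at that node and
// records the relationship instead of overwriting it:
//   group_metadata_map[1].children contains 0;
//   group_metadata_map[0].parents contains 1.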
void EventNode::AddStepName(absl::string_view step_name) {
FindOrAddStatByType(StatType::kStepName)
->set_str_value(step_name.data(), step_name.size());
}
void EventNode::SetIsEager(bool is_eager) {
FindOrAddStatByType(StatType::kIsEager)->set_int64_value(is_eager ? 1 : 0);
}
bool EventNode::IsCompiledFunc() const {
auto is_func = visitor_.GetStat(StatType::kIsFunc);
return !is_func || is_func->IntValue();
}
bool EventNode::IsEager() {
  /* Both eager mode (op-by-op) and non-eager mode (eager functions) of eager
   * execution are now unified and forwarded to the TF1 executor, so we check
   * the following conditions:
   */
auto* node = FindParent(HostEventType::kEagerKernelExecute);
if (node == nullptr) {
    // If the current op is NOT scheduled under an EagerKernelExecute event,
    // it likely comes from TF1 and is therefore not eager.
return false;
}
  // Otherwise, this is an eager-mode execution of an op if and only if it is
  // not an eager-mode execution of a compiled function.
return !node->IsCompiledFunc();
}
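// Illustrative decision table for IsEager() (hypothetical stats): an op whose
// EagerKernelExecute ancestor carries is_func=0 is op-by-op eager execution
// (returns true); if is_func is 1 or absent, the ancestor is treated as a
// compiled function and the op is not eager; with no EagerKernelExecute
// ancestor at all, the op is assumed to come from TF1 (returns false).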
const EventNode* EventNode::FindParent(int64_t event_type) const {
return FindParentWithComparator(
[event_type](const EventNode* node) {
return node->GetEventVisitor().Type() == event_type;
},
this, /*include_self=*/true);
}
void EventForest::ConnectIntraThread(XPlane* plane, XPlaneVisitor* visitor,
ContextGroupMap* context_groups) {
// TODO(b/149095099): avoid string comparison.
bool is_host_plane = (visitor->Name() == kHostThreadsPlaneName);
for (auto& line : *plane->mutable_lines()) {
std::vector<EventNode*> parent_nodes;
for (auto& event : *line.mutable_events()) {
XEventVisitor event_visitor(visitor, &line, &event);
GroupingEventStats stats(event_visitor);
auto cur_node = std::make_unique<EventNode>(std::move(event_visitor));
if (stats.root_level.has_value()) {
cur_node->SetRootLevel(*stats.root_level);
}
// Update `context_groups` for `ConnectInterThread`.
SetContextGroup(stats, cur_node.get(), context_groups);
// Async events are ignored when processing the nesting relationship.
if (!stats.is_async) {
while (!parent_nodes.empty()) {
EventNode* parent_node = parent_nodes.back();
if (parent_node->GetEventVisitor().GetTimespan().Includes(
cur_node->GetEventVisitor().GetTimespan())) {
parent_node->AddChild(cur_node.get());
break;
} else {
parent_nodes.pop_back();
}
}
parent_nodes.push_back(cur_node.get());
}
// event_node_map_ keeps cur_node alive.
event_node_map_[GetEventType(is_host_plane, *cur_node)].push_back(
std::move(cur_node));
}
}
}
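// Illustrative nesting example for ConnectIntraThread (hypothetical
// timespans): for events A=[0,100], B=[10,50], C=[60,90] on one line, the
// `parent_nodes` stack evolves {A} -> {A,B} -> {A,C}: B is popped once C
// falls outside B's timespan, so B and C both become children of A.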
void EventForest::ConnectInterThread(
const std::vector<InterThreadConnectInfo>& connect_info_list) {
for (const auto& connect_info : connect_info_list) {
absl::flat_hash_map<std::vector<uint64>, EventNode*> connect_map;
const std::vector<int64_t>& parent_stat_types =
connect_info.parent_stat_types;
const std::vector<int64_t>* child_stat_types =
&connect_info.child_stat_types;
if (child_stat_types->empty()) {
child_stat_types = &parent_stat_types;
}
if (auto parent_event_node_list =
gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
for (const auto& parent_event_node : *parent_event_node_list) {
std::vector<uint64> stats;
for (auto stat_type : parent_stat_types) {
absl::optional<XStatVisitor> stat =
parent_event_node->GetContextStat(stat_type);
if (!stat) break;
stats.push_back(stat->IntOrUintValue());
}
if (stats.size() == parent_stat_types.size()) {
connect_map[stats] = parent_event_node.get();
}
}
}
if (auto child_event_node_list =
gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
for (const auto& child_event_node : *child_event_node_list) {
std::vector<uint64> stats;
for (auto stat_type : *child_stat_types) {
absl::optional<XStatVisitor> stat =
child_event_node->GetContextStat(stat_type);
if (!stat) break;
stats.push_back(stat->IntOrUintValue());
}
if (stats.size() == child_stat_types->size()) {
if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
parent_event_node->AddChild(child_event_node.get());
}
}
}
}
}
}
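// Illustrative matching example for ConnectInterThread (hypothetical stats):
// with connect_info {kExecutorStateProcess, kIteratorGetNextOp,
// {kStepId, kIterNum}}, a parent event whose context stats resolve to
// (step_id=7, iter_num=3) is keyed as {7, 3} in `connect_map`; any
// kIteratorGetNextOp event whose context stats also resolve to {7, 3}
// becomes its child.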
// Returns whether a root event needs grouping.
bool RootNeedsGrouping(const EventNode* root) {
// No grouping is needed if it is already grouped.
if (root->GetGroupId().has_value()) return false;
// If there is a parent node with the same root level, skip grouping at <root>
// and later apply grouping at the parent node.
// If there is a parent node with a different root level, apply grouping at
// <root>, and later apply grouping at the parent node. Root events with
// different levels are grouped separately.
const EventNode* root_parent = FindParentWithComparator(
[root](const EventNode* parent) {
return parent->RootLevel() == root->RootLevel();
},
root,
/*include_self=*/false);
return root_parent == nullptr;
}
// Sorts root events based on root level and timestamp.
void SortRootEventList(EventList* event_list) {
absl::c_sort(*event_list, [](const EventNode* e1, const EventNode* e2) {
// If two root events have the same root level, the root event with an
// earlier timestamp will be processed first. Otherwise, the event with a
// larger root level will be processed first.
return e1->RootLevel() == e2->RootLevel()
? *e1 < *e2
: e1->RootLevel() > e2->RootLevel();
});
}
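// Illustrative ordering (hypothetical roots): (level=2, t=30),
// (level=1, t=10), (level=1, t=20) sort as
//   (level=2, t=30), (level=1, t=10), (level=1, t=20):
// higher root level first, then earlier timestamp within the same level.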
void EventForest::CreateEventGroups() {
// Create a group for each TF loop iteration in non-JAX profiles.
int64_t group_id = 0;
if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
for (EventNode* root_event : tf_loop_root_events_) {
ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
}
return;
}
// Iterate over all events and collect all root events.
EventList root_events;
for (const auto& typed_events : event_node_map_) {
for (const auto& event : typed_events.second) {
if (!event->RootLevel()) continue;
absl::optional<XStatVisitor> step_id_stat =
event->GetEventVisitor().GetStat(StatType::kStepId);
      // Skip root events associated with tf.data.
if (step_id_stat && tf_data_step_ids_.contains(step_id_stat->IntValue()))
continue;
root_events.push_back(event.get());
}
}
SortRootEventList(&root_events);
for (EventNode* root_event : root_events) {
if (RootNeedsGrouping(root_event) &&
// Ignores legacy TF root events for JAX profiles.
(!HasJaxEvent(event_node_map_) ||
!IsLegacyRootEvent(root_event->GetEventVisitor()))) {
ProcessRootEvent(group_id++, root_event, &group_metadata_map_);
}
}
}
void EventForest::MarkEagerlyExecutedGpuKernels() {
auto kernel_execute_event_node_list =
gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
if (!kernel_execute_event_node_list) return;
for (auto& kernel_execute_event_node : *kernel_execute_event_node_list) {
kernel_execute_event_node->SetIsEager(kernel_execute_event_node->IsEager());
}
}
void EventForest::MarkEagerlyExecutedCpuTfOps() {
auto tf_op_run_event_node_list =
gtl::FindOrNull(event_node_map_, HostEventType::kTfOpRun);
if (!tf_op_run_event_node_list) return;
for (auto& tf_op_run_event_node : *tf_op_run_event_node_list) {
tf_op_run_event_node->SetIsEager(tf_op_run_event_node->IsEager());
}
}
void EventForest::ProcessTfDataSteps() {
const int64_t tf_data_event_types[] = {
HostEventType::kTfDataCapturedFunctionRun,
HostEventType::kTfDataCapturedFunctionRunAsync,
HostEventType::kTfDataCapturedFunctionRunInstantiated,
HostEventType::kTfDataCapturedFunctionRunWithBorrowedArgs};
for (const int64_t tf_data_event_type : tf_data_event_types) {
auto tf_data_events = gtl::FindOrNull(event_node_map_, tf_data_event_type);
if (!tf_data_events) continue;
for (const auto& tf_data_event : *tf_data_events) {
absl::optional<XStatVisitor> step_id_stat =
tf_data_event->GetEventVisitor().GetStat(StatType::kStepId);
if (!step_id_stat) continue;
tf_data_step_ids_.insert(step_id_stat->IntValue());
}
}
}
void EventForest::ProcessTensorFlowLoop() {
struct TensorFlowLoopIteration {
EventNode* first_event = nullptr;
std::vector<EventNode*> events;
};
using TensorFlowLoop =
absl::flat_hash_map<int64_t /*iter_num*/, TensorFlowLoopIteration>;
absl::flat_hash_map<int64_t /*step_id*/, TensorFlowLoop> tf_loops;
  // Bucket the TF executor events by TF function/session (step_id) and
  // iter_num.
auto executor_event_list =
gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
if (!executor_event_list) return;
for (auto& executor_event : *executor_event_list) {
absl::optional<XStatVisitor> step_id_stat =
executor_event->GetEventVisitor().GetStat(StatType::kStepId);
absl::optional<XStatVisitor> iter_num_stat =
executor_event->GetEventVisitor().GetStat(StatType::kIterNum);
if (!step_id_stat || !iter_num_stat) continue;
int64_t step_id = step_id_stat->IntValue();
// Skip tf.data events.
if (tf_data_step_ids_.contains(step_id)) continue;
TensorFlowLoop& tf_loop = tf_loops[step_id];
TensorFlowLoopIteration& iteration = tf_loop[iter_num_stat->IntValue()];
if (!iteration.first_event || *executor_event < *iteration.first_event) {
iteration.first_event = executor_event.get();
}
iteration.events.push_back(executor_event.get());
}
std::vector<const TensorFlowLoopIteration*> iters;
for (const auto& step_id_and_tf_loop : tf_loops) {
const TensorFlowLoop& tf_loop = step_id_and_tf_loop.second;
    // Filter out TF functions/sessions without loops (i.e., whose only
    // iteration is iter_num 0).
if (tf_loop.size() == 1 && tf_loop.contains(0)) continue;
for (const auto& iter_num_and_iter : tf_loop) {
iters.push_back(&iter_num_and_iter.second);
}
}
// Sort iterations based on timestamp of the first event in the iteration.
absl::c_sort(iters, [](const auto& iter1, const auto& iter2) {
return *iter1->first_event < *iter2->first_event;
});
  // Register the first event of each iteration as a root event, and add the
  // iteration's other events as children of that root.
for (const TensorFlowLoopIteration* iter : iters) {
EventNode* root_event = iter->first_event;
tf_loop_root_events_.push_back(root_event);
for (EventNode* event : iter->events) {
if (event == root_event) continue;
root_event->AddChild(event);
}
}
}
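// Illustrative example for ProcessTensorFlowLoop (hypothetical stats): events
// with (step_id=7, iter_num=0) and (step_id=7, iter_num=1) form one
// TensorFlowLoop with two iterations; the earliest event of each iteration
// becomes a root and the iteration's other events become its children. A
// "loop" whose only iteration is iter_num 0 is treated as a loop-free
// function/session and skipped.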
void EventForest::ProcessWorker() {
auto eager_kernel_execute_event_list =
gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
if (!eager_kernel_execute_event_list) return;
// The last EagerKernelExecute with a FunctionRun child.
EventNode* root_event = nullptr;
for (auto& eager_kernel_execute_event : *eager_kernel_execute_event_list) {
if (HasFunctionRun(eager_kernel_execute_event.get())) {
// A function op becomes a new root.
root_event = eager_kernel_execute_event.get();
root_event->SetRootLevel(1);
} else if (root_event) {
      // Add non-function eager ops as children of the current root.
root_event->AddChild(eager_kernel_execute_event.get());
}
}
}
void EventForest::ProcessModelIds() {
const int64_t model_id_event_type_list[] = {HostEventType::kSessionRun,
HostEventType::kTfrtModelRun};
for (const int64_t event_type : model_id_event_type_list) {
auto event_list = gtl::FindOrNull(event_node_map_, event_type);
if (!event_list) continue;
for (const auto& event : *event_list) {
auto group_id = event->GetGroupId();
if (!group_id.has_value()) continue;
absl::optional<XStatVisitor> model_id =
event->GetEventVisitor().GetStat(StatType::kModelId);
if (!model_id.has_value()) continue;
group_metadata_map_[*group_id].model_id = model_id->ToString();
}
}
}
void EventForest::AddPlane(
const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
XPlane* plane) {
CreateStatMetadata(plane);
planes_.push_back({plane, visitor_factory(plane)});
}
void EventForest::AddSpace(
const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
XSpace* space) {
for (XPlane& plane : *space->mutable_planes()) {
AddPlane(visitor_factory, &plane);
}
}
void EventForest::AddPlanes(
const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
const std::vector<XPlane*>& planes) {
for (XPlane* plane : planes) {
AddPlane(visitor_factory, plane);
}
}
void EventForest::ConnectEvents(
const std::vector<InterThreadConnectInfo>& connect_info_list) {
ContextGroupMap context_groups;
for (auto& plane_visitor : planes_) {
ConnectIntraThread(plane_visitor.first, &plane_visitor.second,
&context_groups);
}
ConnectInterThread(connect_info_list);
ConnectContextGroups(context_groups);
}
void EventForest::ConnectTfDataEvents() {
absl::flat_hash_map<
std::pair<int64_t /*iterator_id*/, int64_t /*element_id*/>,
std::vector<EventNode*>>
produce_iterator_map;
uint64 num_producers = 0;
for (HostEventType event_type :
{HostEventType::kPrefetchProduce,
HostEventType::kParallelInterleaveProduce,
HostEventType::kParallelMapProduce, HostEventType::kMapAndBatchProduce,
HostEventType::kParseExampleProduce,
HostEventType::kParallelBatchProduce}) {
auto produce_event_list = gtl::FindOrNull(event_node_map_, event_type);
if (!produce_event_list) continue;
VLOG(1) << produce_event_list->size() << " "
<< GetHostEventTypeStr(event_type) << " events found.";
for (auto& produce_event : *produce_event_list) {
absl::optional<XStatVisitor> element_id =
produce_event->GetEventVisitor().GetStat(StatType::kElementId);
if (!element_id.has_value()) continue;
for (EventNode* produce_iterator : produce_event->GetChildren()) {
if (IsIteratorEventType(produce_iterator->GetEventVisitor().Type())) {
absl::optional<XStatVisitor> iterator_id =
produce_iterator->GetEventVisitor().GetStat(StatType::kParentId);
if (!iterator_id.has_value()) break;
produce_iterator_map[{iterator_id->IntValue(),
element_id->IntValue()}]
.push_back(produce_iterator);
++num_producers;
break;
}
}
}
}
VLOG(1) << num_producers << " producer iterators found.";
uint64 num_matched = 0;
for (HostEventType event_type :
{HostEventType::kPrefetchConsume,
HostEventType::kParallelInterleaveConsume,
HostEventType::kParallelMapConsume, HostEventType::kMapAndBatchConsume,
HostEventType::kParseExampleConsume,
HostEventType::kParallelBatchConsume}) {
auto consume_event_list = gtl::FindOrNull(event_node_map_, event_type);
if (!consume_event_list) continue;
VLOG(1) << consume_event_list->size() << " "
<< GetHostEventTypeStr(event_type) << " events found.";
for (auto& consume_event : *consume_event_list) {
absl::optional<XStatVisitor> element_id =
consume_event->GetEventVisitor().GetStat(StatType::kElementId);
if (!element_id.has_value()) continue;
if (consume_event->GetParents().empty()) continue;
      // consume_event is nested within consume_iterator and has no other
      // parents.
EventNode* consume_iterator = consume_event->GetParents().at(0);
if (!consume_iterator ||
!IsIteratorEventType(consume_iterator->GetEventVisitor().Type())) {
continue;
}
absl::optional<XStatVisitor> iterator_id =
consume_iterator->GetEventVisitor().GetStat(StatType::kStepId);
if (!iterator_id.has_value()) continue;
if (auto produce_iterators = gtl::FindOrNull(
produce_iterator_map, std::make_pair(iterator_id->IntValue(),
element_id->IntValue()))) {
for (EventNode* produce_iterator : *produce_iterators) {
consume_iterator->AddChild(produce_iterator);
++num_matched;
}
}
}
}
VLOG(1) << num_matched << " consumer iterators matched.";
}
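// Illustrative matching example for ConnectTfDataEvents (hypothetical ids): a
// PrefetchProduce event with element_id=5 whose child iterator carries
// parent_id=9 is keyed as {9, 5} in `produce_iterator_map`; a PrefetchConsume
// event with element_id=5 nested under an iterator with step_id=9 then adopts
// that producer iterator as a child, linking consumer to producer across the
// tf.data pipeline.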
void EventForest::GroupEvents() {
ProcessTfDataSteps();
ProcessTensorFlowLoop();
ProcessWorker();
CreateEventGroups();
MarkEagerlyExecutedGpuKernels();
MarkEagerlyExecutedCpuTfOps();
ProcessModelIds();
}
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
std::vector<InterThreadConnectInfo> connect_info_list = {
{HostEventType::kExecutorStateProcess,
HostEventType::kIteratorGetNextOp,
{StatType::kStepId, StatType::kIterNum}},
{HostEventType::kExecutorStateProcess,
HostEventType::kIteratorGetNextAsOptionalOp,
{StatType::kStepId, StatType::kIterNum}},
{HostEventType::kKernelLaunch,
HostEventType::kKernelExecute,
{StatType::kCorrelationId}}};
return connect_info_list;
}
void GroupTfEvents(XSpace* space, EventForest* event_forest) {
if (CheckLoopOp(*space)) {
// TODO(b/154510598): Support TF's loop ops.
return;
}
std::vector<InterThreadConnectInfo> connect_info_list =
CreateInterThreadConnectInfoList();
event_forest->AddSpace(CreateTfXPlaneVisitor, space);
event_forest->ConnectEvents(connect_info_list);
event_forest->GroupEvents();
}
void GroupTfEvents(XSpace* space) {
EventForest event_forest;
GroupTfEvents(space, &event_forest);
}
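// Example usage (a sketch; CollectXSpace is a hypothetical stand-in for
// however the caller obtains a profiler-collected XSpace):
//   XSpace space = CollectXSpace();
//   GroupTfEvents(&space);
// Afterwards, grouped events carry a group_id stat, and user-defined root
// events additionally carry a step_name stat (unless TF loop ops are present,
// in which case grouping is currently skipped; see CheckLoopOp above).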
} // namespace profiler
} // namespace tensorflow