blob: 5a3bb65562170914fadedd8bab53467632435a27 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/utils/group_events.h"
#include <stack>
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
namespace tensorflow {
namespace profiler {
namespace {
// Event types of functional ops (calls, control flow, gradient ops, ...).
// CreateInterThreadConnectInfoList() connects each of them to
// ExecutorStateProcess / ExecutorDoneCallback events by matching the op's
// kFunctionStepId stat against the executor event's kStepId stat.
// (`static` is unnecessary: the enclosing anonymous namespace already gives
// this table internal linkage.)
constexpr int64 kFunctionalOpEventTypes[] = {
    HostEventType::kCallOp,
    HostEventType::kParallelForOp,
    HostEventType::kForeverOp,
    HostEventType::kNumericalGradientOpEvalRight,
    HostEventType::kNumericalGradientOpEvalLeft,
    HostEventType::kSymbolicGradientOp,
    HostEventType::kRemoteCallOp,
    HostEventType::kIfOp,
    HostEventType::kCaseOp,
    HostEventType::kWhileOpEvalCond,
    HostEventType::kWhileOpStartBody,
    HostEventType::kForOp,
    HostEventType::kPartitionedCallOp,
};
// Registers, on `plane`, the stat metadata for the stats that grouping may
// later attach to events (kGroupId, then kStepName). Uses
// GetOrCreateStatMetadata, so calling this on a plane that already has the
// metadata does not register it a second time.
void CreateStatMetadata(XPlane* plane) {
  XPlaneBuilder plane_builder(plane);
  for (auto grouping_stat_type : {StatType::kGroupId, StatType::kStepName}) {
    plane_builder.GetOrCreateStatMetadata(GetStatTypeStr(grouping_stat_type));
  }
}
// Classifies `event` as a kernel event when it carries a kCorrelationId stat.
// On the plane named kHostThreads it is a KernelLaunch; on any other plane it
// is a KernelExecute. Returns nullopt when the event has no correlation id.
absl::optional<int64> GetKernelEventType(const XPlaneVisitor& visitor,
                                         const XEvent& event) {
  bool has_correlation_id = false;
  for (const auto& stat : event.stats()) {
    if (visitor.GetStatType(stat) == StatType::kCorrelationId) {
      has_correlation_id = true;
      break;
    }
  }
  if (!has_correlation_id) return absl::nullopt;
  // TODO(b/149095099): avoid string comparison.
  if (visitor.Name() == kHostThreads) return HostEventType::kKernelLaunch;
  return HostEventType::kKernelExecute;
}
// Returns the event type of `event`.
// - Uses the type XPlaneVisitor resolves, if any.
// - Otherwise falls back to the KernelLaunch/KernelExecute detection.
// - Otherwise returns kUnknownHostEventType.
int64 GetEventType(const XPlaneVisitor& visitor, const XEvent& event) {
  absl::optional<int64> event_type = visitor.GetEventType(event);
  if (!event_type) {
    // KernelLaunch and KernelExecute event types are not supported by
    // XPlaneVisitor and should be checked separately.
    // TODO(148346217): Make XPlaneVisitor support KernelLaunch and
    // KernelExecute event types.
    event_type = GetKernelEventType(visitor, event);
  }
  return event_type.value_or(HostEventType::kUnknownHostEventType);
}
// Returns a pointer to the first stat of `event` whose type (as resolved by
// `visitor`) equals `stat_type`, or nullptr if the event has no such stat.
// The pointer refers into `event` and is only valid while `event` is alive.
const XStat* GetStat(const XPlaneVisitor& visitor, const XEvent& event,
                     int64 stat_type) {
  const XStat* match = nullptr;
  for (const auto& candidate : event.stats()) {
    if (visitor.GetStatType(candidate) != stat_type) continue;
    match = &candidate;
    break;
  }
  return match;
}
// Adds (or overwrites) the kGroupId stat of `event` with `group_id`.
// The metadata id is dereferenced without a check. The EventForest
// constructor calls CreateStatMetadata() on each plane before building its
// visitor, so the lookup presumably always succeeds. Confirm this if the
// function is reused elsewhere.
void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) {
  const auto group_id_metadata_id =
      visitor.GetStatMetadataId(StatType::kGroupId);
  AddOrUpdateIntStat(*group_id_metadata_id, group_id, event);
}
// Two-level lookup from (step_id, iter_num) to the virtual event node created
// for that pair by EventForest::CreateVirtualEventsForHostTrainingLoop().
// The EventNode pointers are non-owning; the nodes are owned by the
// EventForest.
using VirtualEventNodeMap =
absl::flat_hash_map<int64 /*step_id*/,
absl::flat_hash_map<int64 /*iter_num*/, EventNode*>>;
std::unique_ptr<XEvent> CreateVirtualEvent(const XStat& step_id_stat,
const XStat& iter_num_stat) {
auto virtual_event = absl::make_unique<XEvent>();
*virtual_event->add_stats() = step_id_stat;
*virtual_event->add_stats() = iter_num_stat;
return virtual_event;
}
bool NeedsVirtualEventsForHostTrainingLoop(
const std::vector<int64 /*EventType*/>& root_event_types) {
return std::find(root_event_types.begin(), root_event_types.end(),
HostEventType::kHostTrainingLoopIteration) !=
root_event_types.end();
}
bool NeedsVirtualEventsForAsyncExecutor(
const std::vector<int64 /*EventType*/>& root_event_types) {
return std::find(root_event_types.begin(), root_event_types.end(),
HostEventType::kAsyncExecutorTraceContext) !=
root_event_types.end();
}
// Returns true if at least one direct child of `event_node` is a FunctionRun
// event. Grandchildren are not inspected.
bool HasFunctionRun(EventNode* event_node) {
  bool found_function_run = false;
  for (EventNode* child : event_node->GetChildren()) {
    const auto& child_visitor = child->GetPlaneVisitor();
    found_function_run = child_visitor.GetEventType(child->GetEvent()) ==
                         HostEventType::kFunctionRun;
    if (found_function_run) break;
  }
  return found_function_run;
}
} // namespace
// Looks up a stat of type `stat_type` on this node's event. If it is missing,
// the lookup walks up the parent_ chain and returns the stat from the nearest
// ancestor that has it. Returns nullptr if no node on the chain has the stat.
const XStat* EventNode::GetContextStat(int64 stat_type) const {
  for (const EventNode* node = this; node != nullptr; node = node->parent_) {
    if (const XStat* stat = GetStat(*node->visitor_, *node->event_, stat_type)) {
      return stat;
    }
  }
  return nullptr;
}
std::string EventNode::GetGroupName() const {
std::vector<std::string> name_parts;
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
XStatVisitor stat(visitor_, graph_type_stat);
name_parts.push_back(stat.ToString());
}
int64 step_num = group_id_.value_or(0);
if (const XStat* step_num_stat = GetContextStat(StatType::kStepNum)) {
step_num = step_num_stat->int64_value();
}
if (const XStat* iter_num_stat = GetContextStat(StatType::kIterNum)) {
step_num = iter_num_stat->int64_value();
}
name_parts.push_back(absl::StrCat(step_num));
return absl::StrJoin(name_parts, " ");
}
// Assigns `group_id` to this node, both in group_id_ and as the kGroupId stat
// on its event. It then assigns the same id to every descendant reachable
// through nodes that did not already belong to a group when the traversal
// reached them.
//
// The traversal uses an explicit worklist instead of recursion. Event graphs
// built by ConnectIntraThread/ConnectInterThread can be very deep (e.g. long
// chains of connected events), and one stack frame per level could overflow
// the call stack. The set of nodes updated matches the recursive formulation;
// only the order in which they are updated differs.
void EventNode::PropagateGroupId(int64 group_id) {
  // The root is always assigned, even if it already had a group id.
  group_id_ = group_id;
  SetGroupId(*visitor_, group_id, event_);
  std::vector<EventNode*> worklist = {this};
  while (!worklist.empty()) {
    EventNode* node = worklist.back();
    worklist.pop_back();
    for (const auto& child : node->children_) {
      // Skip if it already belongs to a group. Some nodes may be added multiple
      // times as child (e.g., sometimes async ops are executed synchronously
      // and their nodes are added as child both in ConnectIntraThread and
      // ConnectInterThread). Assigning the id before pushing also guarantees
      // that each node is expanded at most once, even across cycles.
      if (child->GetGroupId()) continue;
      child->group_id_ = group_id;
      SetGroupId(*child->visitor_, group_id, child->event_);
      worklist.push_back(child);
    }
  }
}
// Adds (or overwrites) the kStepName string stat of this node's event with
// `step_name`. The kStepName metadata is expected to have been created by
// CreateStatMetadata().
void EventNode::AddStepName(absl::string_view step_name) {
  const auto step_name_metadata_id =
      visitor_->GetStatMetadataId(StatType::kStepName);
  AddOrUpdateStrStat(*step_name_metadata_id, step_name, event_);
}
// Returns whether this node's event is nested in `parent`'s event, as decided
// by IsNested(). A null `parent` never contains anything.
bool EventNode::IsNestedIn(EventNode* parent) {
  if (parent == nullptr) return false;
  return IsNested(GetEvent(), parent->GetEvent());
}
void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
XPlane* plane) {
for (auto& line : *plane->mutable_lines()) {
std::vector<EventNode*> parent_nodes;
for (auto& event : *line.mutable_events()) {
auto cur_node = absl::make_unique<EventNode>(&visitor, &event);
while (!parent_nodes.empty()) {
EventNode* parent_node = parent_nodes.back();
if (cur_node->IsNestedIn(parent_node)) {
parent_node->AddChild(cur_node.get());
break;
} else {
parent_nodes.pop_back();
}
}
parent_nodes.push_back(cur_node.get());
// event_node_map_ keeps cur_node alive.
event_node_map_[GetEventType(visitor, event)].push_back(
std::move(cur_node));
}
}
}
void EventForest::ConnectInterThread(
const std::vector<InterThreadConnectInfo>& connect_info_list) {
for (const auto& connect_info : connect_info_list) {
absl::flat_hash_map<std::vector<int64>, EventNode*> connect_map;
const std::vector<int64>& parent_stat_types =
connect_info.parent_stat_types;
const std::vector<int64>* child_stat_types = &connect_info.child_stat_types;
if (child_stat_types->empty()) {
child_stat_types = &parent_stat_types;
}
if (auto parent_event_node_list =
gtl::FindOrNull(event_node_map_, connect_info.parent_event_type)) {
for (const auto& parent_event_node : *parent_event_node_list) {
std::vector<int64> stats;
for (auto stat_type : parent_stat_types) {
const XStat* stat = parent_event_node->GetContextStat(stat_type);
if (!stat) break;
stats.push_back(stat->value_case() == stat->kInt64Value
? stat->int64_value()
: stat->uint64_value());
}
if (stats.size() == parent_stat_types.size()) {
connect_map[stats] = parent_event_node.get();
}
}
}
if (auto child_event_node_list =
gtl::FindOrNull(event_node_map_, connect_info.child_event_type)) {
for (const auto& child_event_node : *child_event_node_list) {
std::vector<int64> stats;
for (auto stat_type : *child_stat_types) {
const XStat* stat = child_event_node->GetContextStat(stat_type);
if (!stat) break;
stats.push_back(stat->value_case() == stat->kInt64Value
? stat->int64_value()
: stat->uint64_value());
}
if (stats.size() == child_stat_types->size()) {
if (auto parent_event_node = gtl::FindPtrOrNull(connect_map, stats)) {
parent_event_node->AddChild(child_event_node.get());
}
}
}
}
}
}
// Assigns group ids to the events of the first type in `root_event_types` that
// has at least one event node. Later types are ignored.
// Each root node that is not yet grouped gets the next sequential id (from 0).
// The id is propagated to its descendants, and the root's group name is
// stored both as a kStepName stat and in event_group_name_map_.
void EventForest::CreateEventGroup(
    const std::vector<int64 /*EventType*/>& root_event_types) {
  int64 next_group_id = 0;
  for (int64 root_event_type : root_event_types) {
    auto* root_nodes = gtl::FindOrNull(event_node_map_, root_event_type);
    // Fall through to the next candidate type if this one has no events.
    if (root_nodes == nullptr || root_nodes->empty()) continue;
    for (const auto& root_node : *root_nodes) {
      // Skip if it already belongs to a group.
      if (root_node->GetGroupId()) continue;
      const int64 group_id = next_group_id;
      ++next_group_id;
      root_node->PropagateGroupId(group_id);
      std::string group_name = root_node->GetGroupName();
      // TODO(jihochoi): change event name instead.
      root_node->AddStepName(group_name);
      event_group_name_map_[group_id] = std::move(group_name);
    }
    // Only use the first root event type found.
    break;
  }
}
void EventForest::CreateVirtualEventsForHostTrainingLoop() {
VirtualEventNodeMap virtual_event_node_map;
auto executor_event_node_list =
gtl::FindOrNull(event_node_map_, HostEventType::kExecutorStateProcess);
if (!executor_event_node_list) return;
for (auto& executor_event_node : *executor_event_node_list) {
const XStat* step_id_stat =
executor_event_node->GetContextStat(StatType::kStepId);
const XStat* iter_num_stat =
executor_event_node->GetContextStat(StatType::kIterNum);
if (!step_id_stat || !iter_num_stat) continue;
int64 step_id = step_id_stat->int64_value();
int64 iter_num = iter_num_stat->int64_value();
// Process the event with nonzero iter_num only to filter out the events
// related to tf.data.
// TODO(jihochoi): Filter out tf.data events more reliably.
if (!iter_num) continue;
EventNode*& virtual_event_node = virtual_event_node_map[step_id][iter_num];
if (!virtual_event_node) {
std::unique_ptr<XEvent> new_virtual_event =
CreateVirtualEvent(*step_id_stat, *iter_num_stat);
auto new_virtual_event_node = absl::make_unique<EventNode>(
&executor_event_node->GetPlaneVisitor(), new_virtual_event.get());
// virtual_event_container_ keeps new_virtual_event alive.
virtual_event_container_.push_back(std::move(new_virtual_event));
virtual_event_node = new_virtual_event_node.get();
// event_node_map_ keeps new_virtual_event_node alive.
event_node_map_[HostEventType::kHostTrainingLoopIteration].push_back(
std::move(new_virtual_event_node));
}
virtual_event_node->AddChild(executor_event_node.get());
}
}
void EventForest::CreateVirtualEventsForAsyncExecutor() {
auto eager_kernel_execute_event_node_list =
gtl::FindOrNull(event_node_map_, HostEventType::kEagerKernelExecute);
if (!eager_kernel_execute_event_node_list) return;
EventNode* virtual_event_node = nullptr;
for (auto& eager_kernel_execute_event_node :
*eager_kernel_execute_event_node_list) {
if (HasFunctionRun(eager_kernel_execute_event_node.get())) {
auto new_virtual_event = absl::make_unique<XEvent>();
auto new_virtual_event_node = absl::make_unique<EventNode>(
&eager_kernel_execute_event_node->GetPlaneVisitor(),
new_virtual_event.get());
// virtual_event_container_ keeps new_virtual_event alive.
virtual_event_container_.push_back(std::move(new_virtual_event));
virtual_event_node = new_virtual_event_node.get();
// event_node_map_ keeps new_virtual_event_node alive.
event_node_map_[HostEventType::kAsyncExecutorTraceContext].push_back(
std::move(new_virtual_event_node));
}
if (virtual_event_node) {
virtual_event_node->AddChild(eager_kernel_execute_event_node.get());
}
}
}
// Builds the event forest for `space`. The steps are:
// 1. For each plane: register grouping stat metadata, create its visitor with
//    `visitor_factory`, and link nested events on each line.
// 2. Link events across threads according to `connect_info_list`.
// 3. Synthesize any virtual root events that `root_event_types` asks for.
// 4. Assign group ids starting from the first available root event type.
EventForest::EventForest(
    const std::vector<InterThreadConnectInfo>& connect_info_list,
    const std::vector<int64>& root_event_types,
    const std::function<XPlaneVisitor(const XPlane*)> visitor_factory,
    XSpace* space) {
  // ConnectIntraThread() stores &visitors_.back() inside every EventNode, so
  // visitors_ must not reallocate while planes are added. Reserving one slot
  // per plane up front guarantees that.
  visitors_.reserve(space->planes_size());
  for (XPlane& plane : *space->mutable_planes()) {
    CreateStatMetadata(&plane);
    visitors_.push_back(visitor_factory(&plane));
    const XPlaneVisitor& plane_visitor = visitors_.back();
    ConnectIntraThread(plane_visitor, &plane);
  }
  ConnectInterThread(connect_info_list);
  const bool wants_host_training_loop =
      NeedsVirtualEventsForHostTrainingLoop(root_event_types);
  const bool wants_async_executor =
      NeedsVirtualEventsForAsyncExecutor(root_event_types);
  if (wants_host_training_loop) CreateVirtualEventsForHostTrainingLoop();
  if (wants_async_executor) CreateVirtualEventsForAsyncExecutor();
  CreateEventGroup(root_event_types);
}
// Returns the rules ConnectInterThread() uses to link events across threads.
// Each entry is {parent event type, child event type, parent stat types
// [, child stat types]}. When child stat types are omitted, the child is
// matched on the parent stat types.
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
  std::vector<InterThreadConnectInfo> connect_info_list = {
      // Function/session runs own the executor events with the same step id.
      {HostEventType::kFunctionRun, HostEventType::kExecutorStateProcess,
       {StatType::kStepId}},
      {HostEventType::kFunctionRun, HostEventType::kExecutorDoneCallback,
       {StatType::kStepId}},
      {HostEventType::kSessionRun, HostEventType::kExecutorStateProcess,
       {StatType::kStepId}},
      {HostEventType::kSessionRun, HostEventType::kExecutorDoneCallback,
       {StatType::kStepId}},
      // Matched on both step id and iteration number.
      {HostEventType::kExecutorStateProcess, HostEventType::kIteratorGetNextOp,
       {StatType::kStepId, StatType::kIterNum}},
      // Kernel launches and their executions share a correlation id.
      {HostEventType::kKernelLaunch, HostEventType::kKernelExecute,
       {StatType::kCorrelationId}},
      {HostEventType::kLocalExecutableExecuteOnLocalDevice,
       HostEventType::kLocalExecutableExecute, {StatType::kRunId}}};
  // A functional op records the id as kFunctionStepId, while the executor
  // events it is linked to record the same value as kStepId.
  for (int64 functional_op_type : kFunctionalOpEventTypes) {
    for (auto executor_event_type : {HostEventType::kExecutorStateProcess,
                                     HostEventType::kExecutorDoneCallback}) {
      connect_info_list.push_back({functional_op_type,
                                   executor_event_type,
                                   {StatType::kFunctionStepId},
                                   {StatType::kStepId}});
    }
  }
  return connect_info_list;
}
// Groups the TensorFlow events in `space` (in place, by adding kGroupId and
// kStepName stats). Root candidates are TraceContext, FunctionRun, SessionRun
// and HostTrainingLoopIteration, in that priority order. If
// `event_group_name_map` is non-null, it receives the group-id-to-name mapping.
// A null `space` is a no-op.
void GroupTfEvents(XSpace* space, EventGroupNameMap* event_group_name_map) {
  if (space == nullptr) return;
  const std::vector<int64 /*EventType*/> root_event_types = {
      HostEventType::kTraceContext, HostEventType::kFunctionRun,
      HostEventType::kSessionRun, HostEventType::kHostTrainingLoopIteration};
  // The connect-info list is only read during construction, so a temporary
  // is sufficient.
  EventForest event_forest(CreateInterThreadConnectInfoList(),
                           root_event_types, CreateTfXPlaneVisitor, space);
  if (event_group_name_map == nullptr) return;
  *event_group_name_map = event_forest.GetEventGroupNameMap();
}
} // namespace profiler
} // namespace tensorflow