blob: 4ad717bff04e9bb1fc6f2587671a1e44a44d950a [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
#define TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_
#include <string>
#include <vector>
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace profiler {
// Special op types.
// Op type strings with special meaning to the profiler. They are only declared
// here; the definitions live out of line (NOTE(review): presumably in the
// matching tf_op_utils.cc — confirm). TF_CONST_INIT is a TensorFlow macro from
// tensorflow/core/platform/macros.h applied to each declaration.
TF_CONST_INIT extern const absl::string_view kUnknownOp;
// Compared for equality against an op type by IsDatasetOp(absl::string_view).
TF_CONST_INIT extern const absl::string_view kDatasetOp;
// Compared for equality against an op type by IsMemcpyHToDOp().
TF_CONST_INIT extern const absl::string_view kMemcpyHToDOp;
// Compared for equality against an op type by IsMemcpyDToHOp().
TF_CONST_INIT extern const absl::string_view kMemcpyDToHOp;
// Category of a profiled op, as stored in TfOp::category. Enumerator values
// are spelled out explicitly; they match the implicit 0-based numbering.
enum class Category {
  kTensorFlow = 0,  // TensorFlow op.
  kJax = 1,         // JAX op.
  kTfData = 2,      // tf.data op; checked by IsDatasetOp(const TfOp&).
  kMemcpyHToD = 3,  // Host-to-device memcpy op.
  kMemcpyDToH = 4,  // Device-to-host memcpy op.
  kUnknown = 5      // Presumably: op not recognized as any of the above.
};
// A TensorFlow op fullname broken into its category, name and type, as
// returned by ParseTfOpFullname().
struct TfOp {
  // Default member initializer so that a default-constructed TfOp never holds
  // an indeterminate category. The struct remains an aggregate (C++14+), so
  // brace-initialization such as TfOp{category, name, type} is unaffected.
  Category category = Category::kUnknown;
  // Non-owning views. NOTE(review): presumably they point into the string
  // passed to ParseTfOpFullname(), which must then outlive this TfOp — confirm
  // against the implementation.
  absl::string_view name;
  absl::string_view type;
};
// Breaks `tf_op_fullname` into a TfOp (category, name and type).
TfOp ParseTfOpFullname(absl::string_view tf_op_fullname);
// Returns a vector of TF name scopes extracted from a TF op name.
// NOTE(review): the returned views presumably alias `tf_op_name`, which must
// outlive them — confirm against the implementation.
std::vector<absl::string_view> ParseTfNameScopes(absl::string_view tf_op_name);
// Overload taking an already-parsed TfOp (NOTE(review): presumably uses
// tf_op.name — confirm).
std::vector<absl::string_view> ParseTfNameScopes(const TfOp& tf_op);
// Trace event name for TF ops is the op type so they have the same color in
// trace viewer.
std::string TfOpEventName(const TfOp& tf_op);
// Overload taking the unparsed op fullname.
std::string TfOpEventName(absl::string_view tf_op_fullname);
// Trace event name for dataset ops, built from the dataset op's full name.
std::string DatasetOpEventName(absl::string_view full_name);
// Returns the iterator name without prefix and parent iterator names.
std::string IteratorName(absl::string_view full_name);
// Returns true if the given name is a TensorFlow Dataset Op.
inline bool IsDatasetOp(absl::string_view tf_op_type) {
return tf_op_type == kDatasetOp;
}
inline bool IsDatasetOp(const TfOp& tf_op) {
return tf_op.category == Category::kTfData;
}
// Returns true if `tf_op_type` is an infeed-enqueue op type, i.e. it begins
// with the (case-sensitive) prefix "InfeedEnqueue".
// See: tensorflow/core/tpu/kernels/infeed_ops.h
inline bool IsInfeedEnqueueOp(absl::string_view tf_op_type) {
  constexpr absl::string_view kInfeedEnqueuePrefix = "InfeedEnqueue";
  return absl::StartsWith(tf_op_type, kInfeedEnqueuePrefix);
}
// Returns true if the given op is for outside compilation. That is the case
// when either:
//   * the TF op fullname ends with ":XlaSendToHost", or
//   * the HLO expression contains both "send-done" and
//     "is_host_transfer=true".
// The fullname check runs first; the HLO expression is only searched when it
// fails.
inline bool IsOutsideCompilationOp(absl::string_view tf_op_fullname,
                                   absl::string_view hlo_expression) {
  return absl::EndsWith(tf_op_fullname, ":XlaSendToHost") ||
         (absl::StrContains(hlo_expression, "send-done") &&
          absl::StrContains(hlo_expression, "is_host_transfer=true"));
}
// Returns true if `tf_op_fullname` looks like a TensorFlow embedding op, i.e.
// the (case-sensitive) substring "Embedding" occurs anywhere in it.
inline bool IsEmbeddingOp(absl::string_view tf_op_fullname) {
  return tf_op_fullname.find("Embedding") != absl::string_view::npos;
}
// Returns true if `tf_op_type` is exactly the host-to-device copy op type
// (kMemcpyHToDOp).
inline bool IsMemcpyHToDOp(absl::string_view tf_op_type) {
  return kMemcpyHToDOp == tf_op_type;
}
// Returns true if `tf_op_type` is exactly the device-to-host copy op type
// (kMemcpyDToHOp).
inline bool IsMemcpyDToHOp(absl::string_view tf_op_type) {
  return kMemcpyDToHOp == tf_op_type;
}
// Splits a string of tensor shapes in "(shape1;shape2;...)" format, i.e.,
// delimited by '(' and ')' and separated by ';', into the individual shapes.
// NOTE(review): the returned views presumably alias `tensor_shapes`, which must
// outlive them — confirm against the implementation.
std::vector<absl::string_view> ParseTensorShapes(
absl::string_view tensor_shapes);
// Returns true if the given string matches the NodeDef.name pattern, i.e. it
// is shaped like a TF op (graph node) name.
bool IsTfOpName(absl::string_view op_name);
// Returns true if the given string matches the OpDef.name pattern, i.e. it is
// shaped like a TF op type.
bool IsTfOpType(absl::string_view op_type);
// Returns true if the given string matches the JAX op type pattern.
// NOTE(review): the exact pattern lives in the implementation, not here.
bool IsJaxOpType(absl::string_view op_type);
// Returns true if the given name and type together match the JAX pattern.
// NOTE(review): the exact pattern lives in the implementation, not here.
bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type);
} // namespace profiler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_PROFILER_UTILS_TF_OP_UTILS_H_