paves the datapath for xprof tool options, from `xspace_to_tools_data` to `ConvertMultiXSpacesToToolData`.
the tool options parameter is defaulted (to the empty dictionary), so python callers do not need to change. this CL does *not* add a 'graph_viewer' path to `raw_to_tool_data.py`.
PiperOrigin-RevId: 466793694
diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD
index 0b51cde..9b4d488 100644
--- a/tensorflow/core/profiler/convert/BUILD
+++ b/tensorflow/core/profiler/convert/BUILD
@@ -728,8 +728,10 @@
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:variant",
],
)
diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
index 34778b1..6187352 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.cc
@@ -16,6 +16,7 @@
#include "tensorflow/core/profiler/convert/xplane_to_tools_data.h"
#include <utility>
+#include <variant>
#include <vector>
#include "absl/strings/str_format.h"
@@ -211,7 +212,9 @@
std::pair<std::string, bool> ConvertMultiXSpacesToToolData(
const std::vector<XSpace>& xspaces,
const std::vector<std::string>& filenames,
- const absl::string_view tool_name) {
+ const absl::string_view tool_name,
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options) {
if (tool_name == "trace_viewer") {
return ConvertXSpaceToTraceEvents(xspaces);
} else if (tool_name == "overview_page") {
diff --git a/tensorflow/core/profiler/convert/xplane_to_tools_data.h b/tensorflow/core/profiler/convert/xplane_to_tools_data.h
index c1c4918..41c961b 100644
--- a/tensorflow/core/profiler/convert/xplane_to_tools_data.h
+++ b/tensorflow/core/profiler/convert/xplane_to_tools_data.h
@@ -18,9 +18,12 @@
#include <string>
#include <utility>
+#include <variant>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
+#include "absl/types/variant.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
namespace tensorflow {
@@ -33,7 +36,9 @@
std::pair<std::string, bool> ConvertMultiXSpacesToToolData(
const std::vector<XSpace>& xspaces,
const std::vector<std::string>& filenames,
- const absl::string_view tool_name);
+ const absl::string_view tool_name,
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options);
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index 4e4a361..3eb7d4a 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -70,6 +70,19 @@
],
)
+py_test(
+ name = "profiler_wrapper_test",
+ srcs = ["profiler_wrapper_test.py"],
+ python_version = "PY3",
+ tags = [
+ "no_pip",
+ ],
+ deps = [
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/profiler/internal:_pywrap_profiler",
+ ],
+)
+
py_library(
name = "option_builder",
srcs = ["option_builder.py"],
diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc
index f6a9348..a63e0e5 100644
--- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc
+++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc
@@ -15,6 +15,9 @@
#include "tensorflow/python/profiler/internal/profiler_pywrap_impl.h"
+#include <string>
+#include <variant>
+
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
@@ -123,7 +126,8 @@
// RemoteProfilerSessionManagerOptions.
RemoteProfilerSessionManagerOptions GetOptionsLocked(
absl::string_view logdir,
- const absl::flat_hash_map<std::string, absl::variant<int>>& opts) {
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ opts) {
RemoteProfilerSessionManagerOptions options;
*options.mutable_profiler_options() =
tensorflow::ProfilerSession::DefaultOptions();
@@ -143,19 +147,19 @@
for (const auto& kw : opts) {
absl::string_view key = kw.first;
if (key == "host_tracer_level") {
- int value = absl::get<int>(kw.second);
+ int value = std::get<int>(kw.second);
options.mutable_profiler_options()->set_host_tracer_level(value);
VLOG(1) << "host_tracer_level set to " << value;
} else if (key == "device_tracer_level") {
- int value = absl::get<int>(kw.second);
+ int value = std::get<int>(kw.second);
options.mutable_profiler_options()->set_device_tracer_level(value);
VLOG(1) << "device_tracer_level set to " << value;
} else if (key == "python_tracer_level") {
- int value = absl::get<int>(kw.second);
+ int value = std::get<int>(kw.second);
options.mutable_profiler_options()->set_python_tracer_level(value);
VLOG(1) << "python_tracer_level set to " << value;
} else if (key == "delay_ms") {
- int value = absl::get<int>(kw.second);
+ int value = std::get<int>(kw.second);
options.set_delay_ms(value);
VLOG(1) << "delay_ms was set to " << value;
} else {
@@ -170,7 +174,8 @@
absl::string_view service_addresses, absl::string_view logdir,
absl::string_view worker_list, bool include_dataset_ops,
int32_t duration_ms,
- const absl::flat_hash_map<std::string, absl::variant<int>>& opts,
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ opts,
bool* is_cloud_tpu_session) {
auto options = GetOptionsLocked(logdir, opts);
@@ -213,7 +218,8 @@
tensorflow::Status Trace(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
- const absl::flat_hash_map<std::string, absl::variant<int>>& options) {
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options) {
// TPU capture is true if the user sets worker_list.
bool is_cloud_tpu_session = false;
RemoteProfilerSessionManagerOptions opts =
@@ -242,7 +248,8 @@
tensorflow::Status ProfilerSessionWrapper::Start(
const char* logdir,
- const absl::flat_hash_map<std::string, absl::variant<int>>& options) {
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options) {
auto opts = GetOptionsLocked(logdir, options);
session_ = tensorflow::ProfilerSession::Create(opts.profiler_options());
logdir_ = logdir;
diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.h b/tensorflow/python/profiler/internal/profiler_pywrap_impl.h
index fcdc676..3f75c81 100644
--- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.h
+++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.h
@@ -16,6 +16,7 @@
#define TENSORFLOW_PYTHON_PROFILER_INTERNAL_PROFILER_PYWRAP_IMPL_H_
#include <string>
+#include <variant>
#include "absl/container/flat_hash_map.h"
#include "absl/types/variant.h"
@@ -29,7 +30,8 @@
tensorflow::Status Trace(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
- const absl::flat_hash_map<std::string, absl::variant<int>>& options);
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options);
tensorflow::Status Monitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
@@ -39,7 +41,8 @@
public:
tensorflow::Status Start(
const char* logdir,
- const absl::flat_hash_map<std::string, absl::variant<int>>& options);
+ const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
+ options);
tensorflow::Status Stop(tensorflow::string* result);
tensorflow::Status ExportToTensorBoard();
diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc
index bf2f87f..71677ad 100644
--- a/tensorflow/python/profiler/internal/profiler_wrapper.cc
+++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc
@@ -15,6 +15,7 @@
#include <memory>
#include <string>
+#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
@@ -32,17 +33,25 @@
namespace {
-// This must be called under GIL because it reads Python objects. Reading Python
-// objects require GIL because the objects can be mutated by other Python
+// These must be called under GIL because it reads Python objects. Reading
+// Python objects require GIL because the objects can be mutated by other Python
// threads. In addition, Python objects are reference counted; reading py::dict
// will increase its reference count.
-absl::flat_hash_map<std::string, absl::variant<int>> ConvertDictToMap(
- const py::dict& dict) {
- absl::flat_hash_map<std::string, absl::variant<int>> map;
- for (const auto& kw : dict) {
- if (!kw.second.is_none()) {
- map.emplace(kw.first.cast<std::string>(), kw.second.cast<int>());
+absl::flat_hash_map<std::string, std::variant<int, std::string>>
+ConvertDictToMap(const py::dict& dictionary) {
+ absl::flat_hash_map<std::string, std::variant<int, std::string>> map;
+ for (const auto& item : dictionary) {
+ std::variant<int, std::string> value;
+ try {
+ value = item.second.cast<int>();
+ } catch (...) {
+ try {
+ value = item.second.cast<std::string>();
+ } catch (...) {
+ continue;
+ }
}
+ map.emplace(item.first.cast<std::string>(), value);
}
return map;
}
@@ -57,11 +66,11 @@
[](ProfilerSessionWrapper& wrapper, const char* logdir,
const py::dict& options) {
tensorflow::Status status;
- absl::flat_hash_map<std::string, absl::variant<int>> opts =
- ConvertDictToMap(options);
+ absl::flat_hash_map<std::string, std::variant<int, std::string>>
+ cxx_options = ConvertDictToMap(options);
{
py::gil_scoped_release release;
- status = wrapper.Start(logdir, opts);
+ status = wrapper.Start(logdir, cxx_options);
}
// Py_INCREF and Py_DECREF must be called holding the GIL.
tensorflow::MaybeRaiseRegisteredFromStatus(status);
@@ -91,7 +100,7 @@
m.def("start_server", [](int port) {
auto profiler_server =
- absl::make_unique<tensorflow::profiler::ProfilerServer>();
+ std::make_unique<tensorflow::profiler::ProfilerServer>();
profiler_server->StartProfilerServer(port);
// Intentionally release profiler server. Should transfer ownership to
// caller instead.
@@ -103,13 +112,13 @@
const char* worker_list, bool include_dataset_ops, int duration_ms,
int num_tracing_attempts, py::dict options) {
tensorflow::Status status;
- absl::flat_hash_map<std::string, absl::variant<int>> opts =
- ConvertDictToMap(options);
+ absl::flat_hash_map<std::string, std::variant<int, std::string>>
+ cxx_options = ConvertDictToMap(options);
{
py::gil_scoped_release release;
status = tensorflow::profiler::pywrap::Trace(
service_addr, logdir, worker_list, include_dataset_ops,
- duration_ms, num_tracing_attempts, opts);
+ duration_ms, num_tracing_attempts, cxx_options);
}
// Py_INCREF and Py_DECREF must be called holding the GIL.
tensorflow::MaybeRaiseRegisteredFromStatus(status);
@@ -130,35 +139,42 @@
return content;
});
- m.def("xspace_to_tools_data",
- [](const py::list& xspace_path_list, const py::str& py_tool_name) {
- std::vector<tensorflow::profiler::XSpace> xspaces;
- xspaces.reserve(xspace_path_list.size());
- std::vector<std::string> filenames;
- filenames.reserve(xspace_path_list.size());
- for (py::handle obj : xspace_path_list) {
- std::string filename = std::string(py::cast<py::str>(obj));
+ m.def(
+ "xspace_to_tools_data",
+ [](const py::list& xspace_path_list, const py::str& py_tool_name,
+ const py::dict options = py::dict()) {
+ std::vector<tensorflow::profiler::XSpace> xspaces;
+ xspaces.reserve(xspace_path_list.size());
+ std::vector<std::string> filenames;
+ filenames.reserve(xspace_path_list.size());
+ for (py::handle obj : xspace_path_list) {
+ std::string filename = std::string(py::cast<py::str>(obj));
- tensorflow::profiler::XSpace xspace;
- tensorflow::Status status;
+ tensorflow::profiler::XSpace xspace;
+ tensorflow::Status status;
- status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
- filename, &xspace);
+ status = tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
+ filename, &xspace);
- if (!status.ok()) {
- return py::make_tuple(py::bytes(""), py::bool_(false));
- }
-
- xspaces.push_back(xspace);
- filenames.push_back(filename);
+ if (!status.ok()) {
+ return py::make_tuple(py::bytes(""), py::bool_(false));
}
- std::string tool_name = std::string(py_tool_name);
- auto tool_data_and_success =
- tensorflow::profiler::ConvertMultiXSpacesToToolData(
- xspaces, filenames, tool_name);
- return py::make_tuple(py::bytes(tool_data_and_success.first),
- py::bool_(tool_data_and_success.second));
- });
+
+ xspaces.push_back(xspace);
+ filenames.push_back(filename);
+ }
+ std::string tool_name = std::string(py_tool_name);
+ absl::flat_hash_map<std::string, std::variant<int, std::string>>
+ cxx_options = ConvertDictToMap(options);
+ auto tool_data_and_success =
+ tensorflow::profiler::ConvertMultiXSpacesToToolData(
+ xspaces, filenames, tool_name, cxx_options);
+ return py::make_tuple(py::bytes(tool_data_and_success.first),
+ py::bool_(tool_data_and_success.second));
+ },
+ // TODO: consider defaulting `xspace_path_list` to empty list, since
+ // this parameter is only used for two of the tools...
+ py::arg(), py::arg(), py::arg() = py::dict());
m.def("xspace_to_tools_data_from_byte_string",
[](const py::list& xspace_string_list, const py::list& filenames_list,
@@ -185,11 +201,10 @@
for (py::handle obj : filenames_list) {
filenames.push_back(std::string(py::cast<py::str>(obj)));
}
-
std::string tool_name = std::string(py_tool_name);
auto tool_data_and_success =
tensorflow::profiler::ConvertMultiXSpacesToToolData(
- xspaces, filenames, tool_name);
+ xspaces, filenames, tool_name, {});
return py::make_tuple(py::bytes(tool_data_and_success.first),
py::bool_(tool_data_and_success.second));
});
diff --git a/tensorflow/python/profiler/profiler_wrapper_test.py b/tensorflow/python/profiler/profiler_wrapper_test.py
new file mode 100644
index 0000000..fb07ccc
--- /dev/null
+++ b/tensorflow/python/profiler/profiler_wrapper_test.py
@@ -0,0 +1,42 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for profiler_wrapper.cc pybind methods."""
+
+from tensorflow.python.eager import test
+from tensorflow.python.framework import test_util
+from tensorflow.python.profiler.internal import _pywrap_profiler as profiler_wrapper
+
+
+class ProfilerSessionTest(test_util.TensorFlowTestCase):
+
+ def test_xspace_to_tools_data_default_options(self):
+ # filenames only used for `tf_data_bottleneck_analysis` and
+ # `hlo_proto` tools.
+ profiler_wrapper.xspace_to_tools_data([], 'trace_viewer')
+
+ def _test_xspace_to_tools_data_options(self, options):
+ profiler_wrapper.xspace_to_tools_data([], 'trace_viewer', options)
+
+ def test_xspace_to_tools_data_empty_options(self):
+ self._test_xspace_to_tools_data_options({})
+
+ def test_xspace_to_tools_data_int_options(self):
+ self._test_xspace_to_tools_data_options({'example_option': 0})
+
+ def test_xspace_to_tools_data_str_options(self):
+ self._test_xspace_to_tools_data_options({'example_option': 'example'})
+
+if __name__ == '__main__':
+ test.main()