Export the tfprof functions from C++ to Python with pybind11 instead of swig. This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. It will also make exporting C++ ops to Python significantly easier. XLA is using the pybind11 macros already. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information.
PiperOrigin-RevId: 267530586
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index faa9f0f..338a77f 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -380,6 +380,21 @@
)
tf_python_pybind_extension(
+ name = "_pywrap_tfprof",
+ srcs = ["util/tfprof_wrapper.cc"],
+ module_name = "_pywrap_tfprof",
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/profiler/internal:print_model_analysis_hdr",
+ "//third_party/eigen3",
+ "//third_party/python_runtime:headers",
+ "@com_google_absl//absl/strings",
+ "@pybind11",
+ ],
+)
+
+tf_python_pybind_extension(
name = "_pywrap_utils",
srcs = ["util/util_wrapper.cc"],
hdrs = ["util/util.h"],
@@ -5023,7 +5038,6 @@
"util/kernel_registry.i",
"util/py_checkpoint_reader.i",
"util/scoped_annotation.i",
- "util/tfprof.i",
"util/traceme.i",
"util/transform_graph.i",
"//tensorflow/lite/toco/python:toco.i",
@@ -5118,6 +5132,7 @@
srcs = [
":cpp_python_util", # util
"//tensorflow/stream_executor:stream_executor_pimpl", # stat_summarizer
+ "//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
],
outs = ["pybind_symbol_target_libs_file.txt"],
cmd = select({
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 06216f4..d733dc0 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -48,6 +48,7 @@
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_utils
+from tensorflow.python import _pywrap_tfprof
# Protocol buffers
from tensorflow.core.framework.graph_pb2 import *
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index eec7cd2..ef113ae 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -37,8 +37,8 @@
":option_builder",
":tfprof_logger",
"//tensorflow/core/profiler:protos_all_py",
+ "//tensorflow/python:_pywrap_tfprof",
"//tensorflow/python:errors",
- "//tensorflow/python:pywrap_tensorflow",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/profiler/internal/model_analyzer_testlib.py b/tensorflow/python/profiler/internal/model_analyzer_testlib.py
index 8956469..edce43b 100644
--- a/tensorflow/python/profiler/internal/model_analyzer_testlib.py
+++ b/tensorflow/python/profiler/internal/model_analyzer_testlib.py
@@ -19,7 +19,7 @@
import contextlib
-from tensorflow.python import pywrap_tensorflow as print_mdl
+from tensorflow.python import _pywrap_tfprof as print_mdl
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py
index f3b521a..aa876e6 100644
--- a/tensorflow/python/profiler/model_analyzer.py
+++ b/tensorflow/python/profiler/model_analyzer.py
@@ -27,7 +27,7 @@
from google.protobuf import message
from tensorflow.core.profiler import tfprof_options_pb2
from tensorflow.core.profiler import tfprof_output_pb2
-from tensorflow.python import pywrap_tensorflow as print_mdl
+from tensorflow.python import _pywrap_tfprof as print_mdl
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/profiler/profile_context.py b/tensorflow/python/profiler/profile_context.py
index fa4260a..c5c8d66 100644
--- a/tensorflow/python/profiler/profile_context.py
+++ b/tensorflow/python/profiler/profile_context.py
@@ -25,7 +25,7 @@
import threading
from tensorflow.core.protobuf import config_pb2
-from tensorflow.python import pywrap_tensorflow as print_mdl
+from tensorflow.python import _pywrap_tfprof as print_mdl
from tensorflow.python.client import session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 11eba26..9ab8169 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -17,11 +17,10 @@
* The includes are intentionally not alphabetically sorted, as the order of
* includes follows dependency order */
-%include "tensorflow/python/pywrap_tfe.i"
-
-%include "tensorflow/python/util/tfprof.i"
%include "tensorflow/python/util/py_checkpoint_reader.i"
+%include "tensorflow/python/pywrap_tfe.i"
+
%include "tensorflow/python/lib/core/py_func.i"
%include "tensorflow/python/lib/core/py_exception_registry.i"
diff --git a/tensorflow/python/util/tfprof.i b/tensorflow/python/util/tfprof.i
deleted file mode 100644
index 06f1263..0000000
--- a/tensorflow/python/util/tfprof.i
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-%include "tensorflow/python/lib/core/strings.i"
-%include "tensorflow/python/platform/base.i"
-
-%{
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/profiler/internal/print_model_analysis.h"
-
-using tensorflow::int64;
-%}
-
-%typemap(typecheck) const string & = char *;
-%typemap(in) const string& (string temp) {
- if (!_PyObjAs<string>($input, &temp)) return NULL;
- $1 = &temp;
-}
-%typemap(out) const string& {
-%#if PY_MAJOR_VERSION >= 3
- $result = PyUnicode_FromStringAndSize($1->data(), $1->size());
-%#else
- $result = PyString_FromStringAndSize($1->data(), $1->size());
-%#endif
-}
-%apply const string & {string &};
-%apply const string & {string *};
-
-%ignoreall
-
-%unignore tensorflow;
-%unignore tensorflow::tfprof;
-%unignore tensorflow::tfprof::PrintModelAnalysis;
-%unignore tensorflow::tfprof::NewProfiler;
-%unignore tensorflow::tfprof::ProfilerFromFile;
-%unignore tensorflow::tfprof::DeleteProfiler;
-%unignore tensorflow::tfprof::AddStep;
-%unignore tensorflow::tfprof::SerializeToString;
-%unignore tensorflow::tfprof::WriteProfile;
-%unignore tensorflow::tfprof::Profile;
-
-%include "tensorflow/core/profiler/internal/print_model_analysis.h"
-
-%unignoreall
\ No newline at end of file
diff --git a/tensorflow/python/util/tfprof_wrapper.cc b/tensorflow/python/util/tfprof_wrapper.cc
new file mode 100644
index 0000000..0d7b518
--- /dev/null
+++ b/tensorflow/python/util/tfprof_wrapper.cc
@@ -0,0 +1,46 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "include/pybind11/pybind11.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/profiler/internal/print_model_analysis.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_tfprof, m) {
+ m.def("PrintModelAnalysis",
+ [](const std::string* graph, const std::string* run_meta,
+ const std::string* op_log, const std::string* command,
+ const std::string* options) {
+ std::string temp = tensorflow::tfprof::PrintModelAnalysis(
+ graph, run_meta, op_log, command, options);
+ return py::bytes(temp);
+ });
+ m.def("NewProfiler", &tensorflow::tfprof::NewProfiler);
+ m.def("ProfilerFromFile", &tensorflow::tfprof::ProfilerFromFile);
+ m.def("DeleteProfiler", &tensorflow::tfprof::DeleteProfiler);
+ m.def("AddStep", &tensorflow::tfprof::AddStep);
+ m.def("SerializeToString", []() {
+ std::string temp = tensorflow::tfprof::SerializeToString();
+ return py::bytes(temp);
+ });
+ m.def("WriteProfile", &tensorflow::tfprof::WriteProfile);
+ m.def("Profile", [](const std::string* command, const std::string* options) {
+ std::string temp = tensorflow::tfprof::Profile(command, options);
+ return py::bytes(temp);
+ });
+}
diff --git a/tensorflow/tools/def_file_filter/symbols_pybind.txt b/tensorflow/tools/def_file_filter/symbols_pybind.txt
index 7bcb43b..0bfe998 100644
--- a/tensorflow/tools/def_file_filter/symbols_pybind.txt
+++ b/tensorflow/tools/def_file_filter/symbols_pybind.txt
@@ -1,4 +1,4 @@
-[cpp_python_util]
+[cpp_python_util] # util
tensorflow::swig::IsSequence
tensorflow::swig::IsSequenceOrComposite
tensorflow::swig::IsCompositeTensor
@@ -18,6 +18,17 @@
tensorflow::swig::AssertSameStructureForData
tensorflow::swig::RegisterType
-[stream_executor_pimpl]
+[stream_executor_pimpl] # stat_summarizer
stream_executor::StreamExecutor::EnablePeerAccessTo
stream_executor::StreamExecutor::CanEnablePeerAccessTo
+
+[print_model_analysis] # tfprof
+tensorflow::tfprof::NewProfiler
+tensorflow::tfprof::DeleteProfiler
+tensorflow::tfprof::AddStep
+tensorflow::tfprof::WriteProfile
+tensorflow::tfprof::ProfilerFromFile
+tensorflow::tfprof::Profile
+tensorflow::tfprof::PrintModelAnalysis
+tensorflow::tfprof::SerializeToString
+