Add PlatformProfiler to support op tracing using platform tracing tools.
PiperOrigin-RevId: 295872277
Change-Id: I8c02ec3974cd246bab70b47426778e9dda5938ee
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 4c21278..e9539d4 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -17,6 +17,13 @@
]))
config_setting(
+ name = "enable_default_profiler",
+ values = {
+ "copt": "-DTFLITE_ENABLE_DEFAULT_PROFILER",
+ },
+)
+
+config_setting(
name = "gemmlowp_profiling",
values = {
"copt": "-DGEMMLOWP_PROFILING",
@@ -239,7 +246,12 @@
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/schema:schema_fbs",
- ],
+ ] + select({
+ ":enable_default_profiler": [
+ "//tensorflow/lite/profiling:platform_profiler",
+ ],
+ "//conditions:default": [],
+ }),
alwayslink = 1,
)
diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index b839ffd..d333fa7 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -349,6 +349,18 @@
}
void Interpreter::SetProfiler(Profiler* profiler) {
+ // Release resources occupied by owned_profiler_ which is replaced by
+ // caller-owned profiler.
+ owned_profiler_.reset(nullptr);
+ SetSubgraphProfiler(profiler);
+}
+
+void Interpreter::SetProfiler(std::unique_ptr<Profiler> profiler) {
+ owned_profiler_ = std::move(profiler);
+ SetSubgraphProfiler(owned_profiler_.get());
+}
+
+void Interpreter::SetSubgraphProfiler(Profiler* profiler) {
for (int subgraph_index = 0; subgraph_index < subgraphs_.size();
++subgraph_index) {
subgraphs_[subgraph_index]->SetProfiler(profiler, subgraph_index);
diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h
index 4b4945c..093390a 100644
--- a/tensorflow/lite/interpreter.h
+++ b/tensorflow/lite/interpreter.h
@@ -410,6 +410,11 @@
/// WARNING: This is an experimental API and subject to change.
void SetProfiler(Profiler* profiler);
+ /// Same as SetProfiler except this interpreter takes ownership
+ /// of the provided profiler.
+ /// WARNING: This is an experimental API and subject to change.
+ void SetProfiler(std::unique_ptr<Profiler> profiler);
+
/// Gets the profiler used for op tracing.
/// WARNING: This is an experimental API and subject to change.
Profiler* GetProfiler();
@@ -496,6 +501,9 @@
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
+ // Sets the profiler to all subgraphs.
+ void SetSubgraphProfiler(Profiler* profiler);
+
// A pure C data structure used to communicate with the pure C plugin
// interface. To avoid copying tensor metadata, this is also the definitive
// structure to store tensors.
@@ -511,6 +519,10 @@
// TODO(b/116667551): Use TfLiteExternalContext for storing state.
std::vector<TfLiteDelegatePtr> owned_delegates_;
+ // Profiler that has been installed and is owned by this interpreter instance.
+ // Useful if client profiler ownership is burdensome.
+ std::unique_ptr<Profiler> owned_profiler_;
+
bool allow_buffer_handle_output_ = false;
// List of active external contexts.
diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc
index 46fee7f..22a4cf2 100644
--- a/tensorflow/lite/model.cc
+++ b/tensorflow/lite/model.cc
@@ -29,6 +29,10 @@
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"
+#if defined(TFLITE_ENABLE_DEFAULT_PROFILER)
+#include "tensorflow/lite/profiling/platform_profiler.h"
+#endif
+
namespace tflite {
namespace {
@@ -687,6 +691,10 @@
(*interpreter)->AddSubgraphs(subgraphs->Length() - 1);
}
+#if defined(TFLITE_ENABLE_DEFAULT_PROFILER)
+ (*interpreter)->SetProfiler(tflite::profiling::CreatePlatformProfiler());
+#endif
+
for (int subgraph_index = 0; subgraph_index < subgraphs->Length();
++subgraph_index) {
const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
diff --git a/tensorflow/lite/profiling/BUILD b/tensorflow/lite/profiling/BUILD
index 03dd505..94c6a3c 100644
--- a/tensorflow/lite/profiling/BUILD
+++ b/tensorflow/lite/profiling/BUILD
@@ -23,6 +23,31 @@
],
)
+cc_library(
+ name = "atrace_profiler",
+ srcs = ["atrace_profiler.cc"],
+ hdrs = ["atrace_profiler.h"],
+ copts = common_copts,
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/lite/core/api",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "platform_profiler",
+ srcs = ["platform_profiler.cc"],
+ hdrs = ["platform_profiler.h"],
+ copts = common_copts,
+ deps = [
+ "//tensorflow/lite/core/api",
+ ] + select({
+ "//tensorflow:android": [":atrace_profiler"],
+ "//conditions:default": [],
+ }),
+)
+
cc_test(
name = "profiler_test",
srcs = ["profiler_test.cc"],
diff --git a/tensorflow/lite/profiling/atrace_profiler.cc b/tensorflow/lite/profiling/atrace_profiler.cc
new file mode 100644
index 0000000..8fe3641
--- /dev/null
+++ b/tensorflow/lite/profiling/atrace_profiler.cc
@@ -0,0 +1,72 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/profiling/atrace_profiler.h"
+
+#include <dlfcn.h>
+
+#include "absl/strings/str_cat.h"
+
+namespace tflite {
+namespace profiling {
+
+ATraceProfiler::ATraceProfiler() {
+ handle_ = dlopen("libandroid.so", RTLD_NOW | RTLD_LOCAL);
+ if (handle_) {
+ // Use dlsym() to prevent crashes on devices running Android 5.1
+ // (API level 22) or lower.
+ atrace_is_enabled_ =
+ reinterpret_cast<FpIsEnabled>(dlsym(handle_, "ATrace_isEnabled"));
+ atrace_begin_section_ =
+ reinterpret_cast<FpBeginSection>(dlsym(handle_, "ATrace_beginSection"));
+ atrace_end_section_ =
+ reinterpret_cast<FpEndSection>(dlsym(handle_, "ATrace_endSection"));
+
+ if (!atrace_is_enabled_ || !atrace_begin_section_ || !atrace_end_section_) {
+ dlclose(handle_);
+ handle_ = nullptr;
+ }
+ }
+}
+
+ATraceProfiler::~ATraceProfiler() {
+ if (handle_) {
+ dlclose(handle_);
+ }
+}
+
+uint32_t ATraceProfiler::BeginEvent(const char* tag, EventType event_type,
+ uint32_t event_metadata,
+ uint32_t event_subgraph_index) {
+ if (handle_ && atrace_is_enabled_()) {
+ // Note: When recording an OPERATOR_INVOKE_EVENT, we have recorded the op
+ // name as tag and node index as event_metadata. See the macro
+ // TFLITE_SCOPED_TAGGED_OPERATOR_PROFILE defined in
+ // tensorflow/lite/core/api/profiler.h for details.
+ // op_name@node_index/subgraph_index
+ std::string trace_event_tag =
+ absl::StrCat(tag, "@", event_metadata, "/", event_subgraph_index);
+ atrace_begin_section_(trace_event_tag.c_str());
+ }
+ return 0;
+}
+
+void ATraceProfiler::EndEvent(uint32_t event_handle) {
+ if (handle_) {
+ atrace_end_section_();
+ }
+}
+
+} // namespace profiling
+} // namespace tflite
diff --git a/tensorflow/lite/profiling/atrace_profiler.h b/tensorflow/lite/profiling/atrace_profiler.h
new file mode 100644
index 0000000..fcfb9f8
--- /dev/null
+++ b/tensorflow/lite/profiling/atrace_profiler.h
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_PROFILING_ATRACE_PROFILER_H_
+#define TENSORFLOW_LITE_PROFILING_ATRACE_PROFILER_H_
+
+#include <type_traits>
+
+#include "tensorflow/lite/core/api/profiler.h"
+
+namespace tflite {
+namespace profiling {
+
+// Profiler reporting to ATrace.
+class ATraceProfiler : public tflite::Profiler {
+ public:
+ ATraceProfiler();
+
+ ~ATraceProfiler() override;
+
+ uint32_t BeginEvent(const char* tag, EventType event_type,
+ uint32_t event_metadata,
+ uint32_t event_subgraph_index) override;
+
+ void EndEvent(uint32_t event_handle) override;
+
+ private:
+ using FpIsEnabled = std::add_pointer<bool()>::type;
+ using FpBeginSection = std::add_pointer<void(const char*)>::type;
+ using FpEndSection = std::add_pointer<void()>::type;
+
+ // Handle to libandroid.so library. Null if not supported.
+ void* handle_;
+ FpIsEnabled atrace_is_enabled_;
+ FpBeginSection atrace_begin_section_;
+ FpEndSection atrace_end_section_;
+};
+
+} // namespace profiling
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_PROFILING_ATRACE_PROFILER_H_
diff --git a/tensorflow/lite/profiling/platform_profiler.cc b/tensorflow/lite/profiling/platform_profiler.cc
new file mode 100644
index 0000000..bbf5e17
--- /dev/null
+++ b/tensorflow/lite/profiling/platform_profiler.cc
@@ -0,0 +1,37 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/profiling/platform_profiler.h"
+
+#include <memory>
+
+#include "tensorflow/lite/core/api/profiler.h"
+
+#if defined(__ANDROID__)
+#include "tensorflow/lite/profiling/atrace_profiler.h"
+#endif
+
+namespace tflite {
+namespace profiling {
+
+std::unique_ptr<tflite::Profiler> CreatePlatformProfiler() {
+#if defined(__ANDROID__)
+ return std::unique_ptr<tflite::Profiler>(new ATraceProfiler());
+#else
+ return std::unique_ptr<tflite::Profiler>(nullptr);
+#endif
+}
+
+} // namespace profiling
+} // namespace tflite
diff --git a/tensorflow/lite/profiling/platform_profiler.h b/tensorflow/lite/profiling/platform_profiler.h
new file mode 100644
index 0000000..87361b3
--- /dev/null
+++ b/tensorflow/lite/profiling/platform_profiler.h
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_PROFILING_PLATFORM_PROFILER_H_
+#define TENSORFLOW_LITE_PROFILING_PLATFORM_PROFILER_H_
+
+#include <memory>
+
+#include "tensorflow/lite/core/api/profiler.h"
+
+namespace tflite {
+namespace profiling {
+
+std::unique_ptr<tflite::Profiler> CreatePlatformProfiler();
+
+} // namespace profiling
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_PROFILING_PLATFORM_PROFILER_H_