Add unit test for XNNPACK Delegate Plugin C API to verify that it sets the thread count correctly.
PiperOrigin-RevId: 408359847
Change-Id: Iae7bde320784dd81b552e95b73f48903ab6e23a0
diff --git a/tensorflow/lite/experimental/acceleration/configuration/c/BUILD b/tensorflow/lite/experimental/acceleration/configuration/c/BUILD
index 4fe7d2b..dc8d144 100644
--- a/tensorflow/lite/experimental/acceleration/configuration/c/BUILD
+++ b/tensorflow/lite/experimental/acceleration/configuration/c/BUILD
@@ -90,7 +90,9 @@
deps = [
":xnnpack_plugin",
"//tensorflow/lite/c:common",
+ "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"//tensorflow/lite/experimental/acceleration/configuration:configuration_fbs",
"@com_google_googletest//:gtest_main",
+ "@pthreadpool",
],
)
diff --git a/tensorflow/lite/experimental/acceleration/configuration/c/xnnpack_plugin_test.cc b/tensorflow/lite/experimental/acceleration/configuration/c/xnnpack_plugin_test.cc
index 24de431..2a04ad8 100644
--- a/tensorflow/lite/experimental/acceleration/configuration/c/xnnpack_plugin_test.cc
+++ b/tensorflow/lite/experimental/acceleration/configuration/c/xnnpack_plugin_test.cc
@@ -20,18 +20,21 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "pthreadpool.h" // from @pthreadpool
#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h"
namespace tflite {
class XnnpackTest : public testing::Test {
public:
+ static constexpr int kNumThreadsForTest = 7;
void SetUp() override {
// Construct a FlatBuffer that contains
- // TFLiteSettings { XNNPackSettings { num_threads: 7 } }.
+ // TFLiteSettings { XNNPackSettings { num_threads: kNumThreadsForTest } }.
XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_);
- xnnpack_settings_builder.add_num_threads(7);
+ xnnpack_settings_builder.add_num_threads(kNumThreadsForTest);
flatbuffers::Offset<XNNPackSettings> xnnpack_settings =
xnnpack_settings_builder.Finish();
TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_);
@@ -50,6 +53,8 @@
const TFLiteSettings *settings_;
};
+constexpr int XnnpackTest::kNumThreadsForTest;
+
TEST_F(XnnpackTest, CanCreateAndDestroyDelegate) {
TfLiteDelegate *delegate =
TfLiteXnnpackDelegatePluginCApi()->create(settings_);
@@ -66,4 +71,13 @@
TfLiteXnnpackDelegatePluginCApi()->destroy(delegate);
}
+TEST_F(XnnpackTest, SetsCorrectThreadCount) {
+ TfLiteDelegate *delegate =
+ TfLiteXnnpackDelegatePluginCApi()->create(settings_);
+ pthreadpool_t threadpool =
+ static_cast<pthreadpool_t>(TfLiteXNNPackDelegateGetThreadPool(delegate));
+ int thread_count = pthreadpool_get_threads_count(threadpool);
+ EXPECT_EQ(thread_count, kNumThreadsForTest);
+ TfLiteXnnpackDelegatePluginCApi()->destroy(delegate);
+}
} // namespace tflite