Revert: Refactor helper functions to populate stream executor structs from stream_executor_test to a common file, so test_pluggable_device can use it too.
This change is a preparation for PR #45784. In the PR, the test `LibraryPluggableDeviceLoadFunctions` in //tensorflow/c:c_api_experimental_test eventually verifies that all structs in `test_pluggable_device.so` are populated properly.
PR #45784: https://github.com/tensorflow/tensorflow/pull/45784
PiperOrigin-RevId: 363205807
Change-Id: I903882bf82391ebc3493ecfa3c111dfc0916b9cb
diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD
index 4f6e03a..e017cc4 100644
--- a/tensorflow/c/experimental/stream_executor/BUILD
+++ b/tensorflow/c/experimental/stream_executor/BUILD
@@ -76,7 +76,6 @@
deps = [
":stream_executor",
":stream_executor_internal",
- ":stream_executor_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
@@ -85,11 +84,3 @@
"//tensorflow/stream_executor:stream_executor_pimpl",
],
)
-
-cc_library(
- name = "stream_executor_test_util",
- srcs = ["stream_executor_test_util.cc"],
- hdrs = ["stream_executor_test_util.h"],
- visibility = ["//tensorflow:internal"],
- deps = [":stream_executor_hdrs"],
-)
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
index 2d41820..8c7b6cd 100644
--- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
@@ -15,7 +15,6 @@
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
-#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -25,15 +24,200 @@
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"
+struct SP_Stream_st {
+ explicit SP_Stream_st(int id) : stream_id(id) {}
+ int stream_id;
+};
+
+struct SP_Event_st {
+ explicit SP_Event_st(int id) : event_id(id) {}
+ int event_id;
+};
+
+struct SP_Timer_st {
+ explicit SP_Timer_st(int id) : timer_id(id) {}
+ int timer_id;
+};
+
namespace stream_executor {
namespace {
+constexpr int kDeviceCount = 2;
+constexpr char kDeviceName[] = "MY_DEVICE";
+constexpr char kDeviceType[] = "GPU";
+
+/*** Create SP_StreamExecutor (with empty functions) ***/
+void allocate(const SP_Device* const device, uint64_t size,
+ int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
+void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
+}
+void* host_memory_allocate(const SP_Device* const device, uint64_t size) {
+ return nullptr;
+}
+void host_memory_deallocate(const SP_Device* const device, void* mem) {}
+TF_Bool get_allocator_stats(const SP_Device* const device,
+ SP_AllocatorStats* const stats) {
+ return true;
+}
+TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free,
+ int64_t* const total) {
+ return true;
+}
+void create_stream(const SP_Device* const device, SP_Stream* stream,
+ TF_Status* const status) {
+ stream = nullptr;
+}
+void destroy_stream(const SP_Device* const device, SP_Stream stream) {}
+void create_stream_dependency(const SP_Device* const device,
+ SP_Stream dependent, SP_Stream other,
+ TF_Status* const status) {}
+void get_stream_status(const SP_Device* const device, SP_Stream stream,
+ TF_Status* const status) {}
+void create_event(const SP_Device* const device, SP_Event* event,
+ TF_Status* const status) {
+ event = nullptr;
+}
+void destroy_event(const SP_Device* const device, SP_Event event) {}
+SE_EventStatus get_event_status(const SP_Device* const device, SP_Event event) {
+ return SE_EVENT_UNKNOWN;
+}
+void record_event(const SP_Device* const device, SP_Stream stream,
+ SP_Event event, TF_Status* const status) {}
+void wait_for_event(const SP_Device* const device, SP_Stream stream,
+ SP_Event event, TF_Status* const status) {}
+void create_timer(const SP_Device* const device, SP_Timer* timer,
+ TF_Status* const status) {}
+void destroy_timer(const SP_Device* const device, SP_Timer timer) {}
+void start_timer(const SP_Device* const device, SP_Stream stream,
+ SP_Timer timer, TF_Status* const status) {}
+void stop_timer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
+ TF_Status* const status) {}
+void memcpy_dtoh(const SP_Device* const device, SP_Stream stream,
+ void* host_dst, const SP_DeviceMemoryBase* const device_src,
+ uint64_t size, TF_Status* const status) {}
+void memcpy_htod(const SP_Device* const device, SP_Stream stream,
+ SP_DeviceMemoryBase* const device_dst, const void* host_src,
+ uint64_t size, TF_Status* const status) {}
+void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst,
+ const SP_DeviceMemoryBase* const device_src,
+ uint64_t size, TF_Status* const status) {}
+void sync_memcpy_htod(const SP_Device* const device,
+ SP_DeviceMemoryBase* const device_dst,
+ const void* host_src, uint64_t size,
+ TF_Status* const status) {}
+void block_host_for_event(const SP_Device* const device, SP_Event event,
+ TF_Status* const status) {}
+void synchronize_all_activity(const SP_Device* const device,
+ TF_Status* const status) {}
+TF_Bool host_callback(const SP_Device* const device, SP_Stream stream,
+ SE_StatusCallbackFn const callback_fn,
+ void* const callback_arg) {
+ return true;
+}
+
+void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
+ *se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
+ se->allocate = allocate;
+ se->deallocate = deallocate;
+ se->host_memory_allocate = host_memory_allocate;
+ se->host_memory_deallocate = host_memory_deallocate;
+ se->get_allocator_stats = get_allocator_stats;
+ se->device_memory_usage = device_memory_usage;
+ se->create_stream = create_stream;
+ se->destroy_stream = destroy_stream;
+ se->create_stream_dependency = create_stream_dependency;
+ se->get_stream_status = get_stream_status;
+ se->create_event = create_event;
+ se->destroy_event = destroy_event;
+ se->get_event_status = get_event_status;
+ se->record_event = record_event;
+ se->wait_for_event = wait_for_event;
+ se->create_timer = create_timer;
+ se->destroy_timer = destroy_timer;
+ se->start_timer = start_timer;
+ se->stop_timer = stop_timer;
+ se->memcpy_dtoh = memcpy_dtoh;
+ se->memcpy_htod = memcpy_htod;
+ se->sync_memcpy_dtoh = sync_memcpy_dtoh;
+ se->sync_memcpy_htod = sync_memcpy_htod;
+ se->block_host_for_event = block_host_for_event;
+ se->synchronize_all_activity = synchronize_all_activity;
+ se->host_callback = host_callback;
+}
+
+void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
+ *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
+}
+
+/*** Create SP_TimerFns ***/
+uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; }
+
+void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
+ timer_fns->nanoseconds = nanoseconds;
+}
+
+/*** Create SP_Platform ***/
+void create_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ PopulateDefaultTimerFns(timer_fns);
+}
+void destroy_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
+
+void create_stream_executor(const SP_Platform* platform,
+ SE_CreateStreamExecutorParams* params,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ PopulateDefaultStreamExecutor(params->stream_executor);
+}
+void destroy_stream_executor(const SP_Platform* platform,
+ SP_StreamExecutor* se) {}
+void get_device_count(const SP_Platform* platform, int* device_count,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ *device_count = kDeviceCount;
+}
+void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params,
+ TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
+}
+void destroy_device(const SP_Platform* platform, SP_Device* device) {}
+
+void create_device_fns(const SP_Platform* platform,
+ SE_CreateDeviceFnsParams* params, TF_Status* status) {
+ TF_SetStatus(status, TF_OK, "");
+ params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
+}
+void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) {
+}
+
+void PopulateDefaultPlatform(SP_Platform* platform,
+ SP_PlatformFns* platform_fns) {
+ *platform = {SP_PLATFORM_STRUCT_SIZE};
+ platform->name = kDeviceName;
+ platform->type = kDeviceType;
+ platform_fns->get_device_count = get_device_count;
+ platform_fns->create_device = create_device;
+ platform_fns->destroy_device = destroy_device;
+ platform_fns->create_device_fns = create_device_fns;
+ platform_fns->destroy_device_fns = destroy_device_fns;
+ platform_fns->create_stream_executor = create_stream_executor;
+ platform_fns->destroy_stream_executor = destroy_stream_executor;
+ platform_fns->create_timer_fns = create_timer_fns;
+ platform_fns->destroy_timer_fns = destroy_timer_fns;
+}
+
+void destroy_platform(SP_Platform* const platform) {}
+void destroy_platform_fns(SP_PlatformFns* const platform_fns) {}
/*** Registration tests ***/
TEST(StreamExecutor, SuccessfulRegistration) {
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
- test_util::PopulateDefaultPlatformRegistrationParams(params);
+ PopulateDefaultPlatform(params->platform, params->platform_fns);
+ params->destroy_platform = destroy_platform;
+ params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
TF_ASSERT_OK(status);
@@ -41,8 +225,8 @@
MultiPlatformManager::PlatformWithName("MY_DEVICE");
TF_ASSERT_OK(maybe_platform.status());
Platform* platform = maybe_platform.ConsumeValueOrDie();
- ASSERT_EQ(platform->Name(), test_util::kDeviceName);
- ASSERT_EQ(platform->VisibleDeviceCount(), test_util::kDeviceCount);
+ ASSERT_EQ(platform->Name(), kDeviceName);
+ ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount);
port::StatusOr<StreamExecutor*> maybe_executor =
platform->ExecutorForDevice(0);
@@ -53,8 +237,10 @@
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
- test_util::PopulateDefaultPlatformRegistrationParams(params);
+ PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->name = nullptr;
+ params->destroy_platform = destroy_platform;
+ params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
@@ -66,8 +252,10 @@
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
- test_util::PopulateDefaultPlatformRegistrationParams(params);
+ PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->name = "INVALID:NAME";
+ params->destroy_platform = destroy_platform;
+ params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
@@ -81,8 +269,10 @@
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
- test_util::PopulateDefaultPlatformRegistrationParams(params);
+ PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->name = "INVALID/";
+ params->destroy_platform = destroy_platform;
+ params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
@@ -95,8 +285,10 @@
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
- test_util::PopulateDefaultPlatformRegistrationParams(params);
+ PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform_fns->create_device = nullptr;
+ params->destroy_platform = destroy_platform;
+ params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
@@ -109,8 +301,10 @@
auto plugin_init = [](SE_PlatformRegistrationParams* const params,
TF_Status* const status) -> void {
TF_SetStatus(status, TF_OK, "");
- test_util::PopulateDefaultPlatformRegistrationParams(params);
+ PopulateDefaultPlatform(params->platform, params->platform_fns);
params->platform->supports_unified_memory = true;
+ params->destroy_platform = destroy_platform;
+ params->destroy_platform_fns = destroy_platform_fns;
};
port::Status status = InitStreamExecutorPlugin(plugin_init);
@@ -125,18 +319,18 @@
protected:
StreamExecutorTest() {}
void SetUp() override {
- test_util::PopulateDefaultPlatform(&platform_, &platform_fns_);
- test_util::PopulateDefaultDeviceFns(&device_fns_);
- test_util::PopulateDefaultStreamExecutor(&se_);
- test_util::PopulateDefaultTimerFns(&timer_fns_);
+ PopulateDefaultPlatform(&platform_, &platform_fns_);
+ PopulateDefaultDeviceFns(&device_fns_);
+ PopulateDefaultStreamExecutor(&se_);
+ PopulateDefaultTimerFns(&timer_fns_);
}
void TearDown() override {}
StreamExecutor* GetExecutor(int ordinal) {
if (!cplatform_) {
cplatform_ = absl::make_unique<CPlatform>(
- platform_, test_util::DestroyPlatform, platform_fns_,
- test_util::DestroyPlatformFns, device_fns_, se_, timer_fns_);
+ platform_, destroy_platform, platform_fns_, destroy_platform_fns,
+ device_fns_, se_, timer_fns_);
}
port::StatusOr<StreamExecutor*> maybe_executor =
cplatform_->ExecutorForDevice(ordinal);
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc
deleted file mode 100644
index a352bd7..0000000
--- a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc
+++ /dev/null
@@ -1,195 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
-
-#include <cstring>
-
-#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
-
-namespace stream_executor {
-namespace test_util {
-
-/*** Functions for creating SP_StreamExecutor ***/
-void Allocate(const SP_Device* const device, uint64_t size,
- int64_t memory_space, SP_DeviceMemoryBase* const mem) {}
-void Deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) {
-}
-void* HostMemoryAllocate(const SP_Device* const device, uint64_t size) {
- return nullptr;
-}
-void HostMemoryDeallocate(const SP_Device* const device, void* mem) {}
-TF_Bool GetAllocatorStats(const SP_Device* const device,
- SP_AllocatorStats* const stats) {
- return true;
-}
-TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free,
- int64_t* const total) {
- return true;
-}
-void CreateStream(const SP_Device* const device, SP_Stream* stream,
- TF_Status* const status) {
- stream = nullptr;
-}
-void DestroyStream(const SP_Device* const device, SP_Stream stream) {}
-void CreateStreamDependency(const SP_Device* const device, SP_Stream dependent,
- SP_Stream other, TF_Status* const status) {}
-void GetStreamStatus(const SP_Device* const device, SP_Stream stream,
- TF_Status* const status) {}
-void CreateEvent(const SP_Device* const device, SP_Event* event,
- TF_Status* const status) {
- event = nullptr;
-}
-void DestroyEvent(const SP_Device* const device, SP_Event event) {}
-SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) {
- return SE_EVENT_UNKNOWN;
-}
-void RecordEvent(const SP_Device* const device, SP_Stream stream,
- SP_Event event, TF_Status* const status) {}
-void WaitForEvent(const SP_Device* const device, SP_Stream stream,
- SP_Event event, TF_Status* const status) {}
-void CreateTimer(const SP_Device* const device, SP_Timer* timer,
- TF_Status* const status) {}
-void DestroyTimer(const SP_Device* const device, SP_Timer timer) {}
-void StartTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
- TF_Status* const status) {}
-void StopTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer,
- TF_Status* const status) {}
-void MemcpyDToH(const SP_Device* const device, SP_Stream stream, void* host_dst,
- const SP_DeviceMemoryBase* const device_src, uint64_t size,
- TF_Status* const status) {}
-void MemcpyHToD(const SP_Device* const device, SP_Stream stream,
- SP_DeviceMemoryBase* const device_dst, const void* host_src,
- uint64_t size, TF_Status* const status) {}
-void SyncMemcpyDToH(const SP_Device* const device, void* host_dst,
- const SP_DeviceMemoryBase* const device_src, uint64_t size,
- TF_Status* const status) {}
-void SyncMemcpyHToD(const SP_Device* const device,
- SP_DeviceMemoryBase* const device_dst, const void* host_src,
- uint64_t size, TF_Status* const status) {}
-void BlockHostForEvent(const SP_Device* const device, SP_Event event,
- TF_Status* const status) {}
-void SynchronizeAllActivity(const SP_Device* const device,
- TF_Status* const status) {}
-TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream,
- SE_StatusCallbackFn const callback_fn,
- void* const callback_arg) {
- return true;
-}
-
-void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
- *se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
- se->allocate = Allocate;
- se->deallocate = Deallocate;
- se->host_memory_allocate = HostMemoryAllocate;
- se->host_memory_deallocate = HostMemoryDeallocate;
- se->get_allocator_stats = GetAllocatorStats;
- se->device_memory_usage = DeviceMemoryUsage;
- se->create_stream = CreateStream;
- se->destroy_stream = DestroyStream;
- se->create_stream_dependency = CreateStreamDependency;
- se->get_stream_status = GetStreamStatus;
- se->create_event = CreateEvent;
- se->destroy_event = DestroyEvent;
- se->get_event_status = GetEventStatus;
- se->record_event = RecordEvent;
- se->wait_for_event = WaitForEvent;
- se->create_timer = CreateTimer;
- se->destroy_timer = DestroyTimer;
- se->start_timer = StartTimer;
- se->stop_timer = StopTimer;
- se->memcpy_dtoh = MemcpyDToH;
- se->memcpy_htod = MemcpyHToD;
- se->sync_memcpy_dtoh = SyncMemcpyDToH;
- se->sync_memcpy_htod = SyncMemcpyHToD;
- se->block_host_for_event = BlockHostForEvent;
- se->synchronize_all_activity = SynchronizeAllActivity;
- se->host_callback = HostCallback;
-}
-
-void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
- *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
-}
-
-/*** Functions for creating SP_TimerFns ***/
-uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; }
-
-void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
- timer_fns->nanoseconds = Nanoseconds;
-}
-
-/*** Functions for creating SP_Platform ***/
-void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns,
- TF_Status* status) {
- TF_SetStatus(status, TF_OK, "");
- PopulateDefaultTimerFns(timer_fns);
-}
-void DestroyTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
-
-void CreateStreamExecutor(const SP_Platform* platform,
- SE_CreateStreamExecutorParams* params,
- TF_Status* status) {
- TF_SetStatus(status, TF_OK, "");
- PopulateDefaultStreamExecutor(params->stream_executor);
-}
-void DestroyStreamExecutor(const SP_Platform* platform, SP_StreamExecutor* se) {
-}
-void GetDeviceCount(const SP_Platform* platform, int* device_count,
- TF_Status* status) {
- TF_SetStatus(status, TF_OK, "");
- *device_count = kDeviceCount;
-}
-void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params,
- TF_Status* status) {
- TF_SetStatus(status, TF_OK, "");
- params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
-}
-void DestroyDevice(const SP_Platform* platform, SP_Device* device) {}
-
-void CreateDeviceFns(const SP_Platform* platform,
- SE_CreateDeviceFnsParams* params, TF_Status* status) {
- TF_SetStatus(status, TF_OK, "");
- params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
-}
-void DestroyDeviceFns(const SP_Platform* platform, SP_DeviceFns* device_fns) {}
-
-void PopulateDefaultPlatform(SP_Platform* platform,
- SP_PlatformFns* platform_fns) {
- *platform = {SP_PLATFORM_STRUCT_SIZE};
- platform->name = kDeviceName;
- platform->type = kDeviceType;
- platform_fns->get_device_count = GetDeviceCount;
- platform_fns->create_device = CreateDevice;
- platform_fns->destroy_device = DestroyDevice;
- platform_fns->create_device_fns = CreateDeviceFns;
- platform_fns->destroy_device_fns = DestroyDeviceFns;
- platform_fns->create_stream_executor = CreateStreamExecutor;
- platform_fns->destroy_stream_executor = DestroyStreamExecutor;
- platform_fns->create_timer_fns = CreateTimerFns;
- platform_fns->destroy_timer_fns = DestroyTimerFns;
-}
-
-/*** Functions for creating SE_PlatformRegistrationParams ***/
-void DestroyPlatform(SP_Platform* platform) {}
-void DestroyPlatformFns(SP_PlatformFns* platform_fns) {}
-
-void PopulateDefaultPlatformRegistrationParams(
- SE_PlatformRegistrationParams* const params) {
- PopulateDefaultPlatform(params->platform, params->platform_fns);
- params->destroy_platform = DestroyPlatform;
- params->destroy_platform_fns = DestroyPlatformFns;
-}
-
-} // namespace test_util
-} // namespace stream_executor
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h
deleted file mode 100644
index 0bebf6f..0000000
--- a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_
-#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_
-
-#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
-
-struct SP_Stream_st {
- explicit SP_Stream_st(int id) : stream_id(id) {}
- int stream_id;
-};
-
-struct SP_Event_st {
- explicit SP_Event_st(int id) : event_id(id) {}
- int event_id;
-};
-
-struct SP_Timer_st {
- explicit SP_Timer_st(int id) : timer_id(id) {}
- int timer_id;
-};
-
-namespace stream_executor {
-namespace test_util {
-
-constexpr int kDeviceCount = 2;
-constexpr char kDeviceName[] = "MY_DEVICE";
-constexpr char kDeviceType[] = "GPU";
-
-void PopulateDefaultStreamExecutor(SP_StreamExecutor* se);
-void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns);
-void PopulateDefaultTimerFns(SP_TimerFns* timer_fns);
-void PopulateDefaultPlatform(SP_Platform* platform,
- SP_PlatformFns* platform_fns);
-void PopulateDefaultPlatformRegistrationParams(
- SE_PlatformRegistrationParams* const params);
-
-void DestroyPlatform(SP_Platform* platform);
-void DestroyPlatformFns(SP_PlatformFns* platform_fns);
-
-} // namespace test_util
-} // namespace stream_executor
-
-#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_
diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD
index c13639f..ca8bdaf 100644
--- a/tensorflow/c/experimental/stream_executor/test/BUILD
+++ b/tensorflow/c/experimental/stream_executor/test/BUILD
@@ -13,8 +13,5 @@
name = "test_pluggable_device.so",
srcs = ["test_pluggable_device.cc"],
visibility = ["//tensorflow/c:__subpackages__"],
- deps = [
- "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs",
- "//tensorflow/c/experimental/stream_executor:stream_executor_test_util",
- ],
+ deps = ["//tensorflow/c/experimental/stream_executor:stream_executor_hdrs"],
)
diff --git a/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc
index c242340..d985f3c 100644
--- a/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc
+++ b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc
@@ -14,9 +14,10 @@
==============================================================================*/
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
-#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
void SE_InitPlugin(SE_PlatformRegistrationParams* const params,
TF_Status* const status) {
- stream_executor::test_util::PopulateDefaultPlatformRegistrationParams(params);
+ params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
+ params->platform->name = "GPU";
+ params->platform->type = "XGPU";
}