Make profiler C APIs portable by serializing opaque protobufs into a user provided buffer.
WHAT
* Make profiler APIs in libtpu portable.
* CollectData() takes a caller-allocated buffer that is large enough to hold the serialized XSpace from the TPU driver.
WHY
* Minimize number of APIs to maintain.
* Minimize the number of serializations and deserializations of XSpace.
* Eliminate incompatibilities and crashes as a result of passing protobufs at shared library boundaries.
* Untangle ownership and lifetime of resources: the serialization buffer is allocated and owned by the client, while the collected XSpace is owned by the driver until after serialization.
Misc clean ups:
* Fix mismatch of TpuProfiler_Free definition vs declaration.
* Fix a leak in TpuProfiler_Create() on error conditions.
Note: TPU driver refers to libtpu.
PiperOrigin-RevId: 355299692
Change-Id: Ie37c295c20c29bb511e9e969d785de57fe1446fd
diff --git a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
index a156c17..dbe32ee 100644
--- a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
+++ b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc
@@ -58,16 +58,22 @@
};
TpuTracer::TpuTracer() {
- tpu_profiler_ = tpu::OpsApiFn()->TpuProfiler_CreateFn();
+ StatusHelper status;
+ tpu::OpsApiFn()->TpuProfiler_CreateFn(&tpu_profiler_, status.c_status);
+ if (!status.ok()) {
+ LOG(ERROR) << status.status().error_message();
+ }
}
-TpuTracer::~TpuTracer() { tpu::OpsApiFn()->TpuProfiler_FreeFn(tpu_profiler_); }
+TpuTracer::~TpuTracer() {
+ tpu::OpsApiFn()->TpuProfiler_DestroyFn(tpu_profiler_);
+}
Status TpuTracer::Start() {
StatusHelper status;
tpu::OpsApiFn()->TpuProfiler_StartFn(tpu_profiler_, status.c_status);
if (!status.ok()) {
- VLOG(1) << "Run Start failed.";
+ LOG(ERROR) << "TPU tracer failed to start.";
return status.status();
}
return Status::OK();
@@ -77,7 +83,7 @@
StatusHelper status;
tpu::OpsApiFn()->TpuProfiler_StopFn(tpu_profiler_, status.c_status);
if (!status.ok()) {
- VLOG(1) << "Run Stop failed.";
+ LOG(ERROR) << "TPU tracer failed to stop.";
return status.status();
}
return Status::OK();
@@ -90,10 +96,21 @@
Status TpuTracer::CollectData(XSpace* space) {
StatusHelper status;
+ // Get size of buffer required for TPU driver to serialize XSpace into.
+ size_t size_in_bytes;
tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
- space);
+ /*buffer=*/nullptr,
+ &size_in_bytes);
+ // Prepare an appropriately sized buffer.
+ if (size_in_bytes > 0) {
+ std::vector<uint8_t> buffer(size_in_bytes);
+ tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
+ buffer.data(), &size_in_bytes);
+ // Deserialize XSpace from the buffer and return it.
+ space->ParseFromArray(buffer.data(), buffer.size());
+ }
if (!status.ok()) {
- VLOG(1) << "Run CollectData failed.";
+ LOG(ERROR) << "TPU tracer failed to collect data.";
return status.status();
}
return Status::OK();
diff --git a/tensorflow/core/tpu/tpu_api.cc b/tensorflow/core/tpu/tpu_api.cc
index 690e204..339e8ef 100644
--- a/tensorflow/core/tpu/tpu_api.cc
+++ b/tensorflow/core/tpu/tpu_api.cc
@@ -23,7 +23,7 @@
return &base_fn;
}
-TfTpu_OpsApiFn* OpsApiFn() {
+const TfTpu_OpsApiFn* OpsApiFn() {
static TfTpu_OpsApiFn ops_api_fn;
return &ops_api_fn;
}
diff --git a/tensorflow/core/tpu/tpu_api.h b/tensorflow/core/tpu/tpu_api.h
index b880f4e..45ada40 100644
--- a/tensorflow/core/tpu/tpu_api.h
+++ b/tensorflow/core/tpu/tpu_api.h
@@ -25,7 +25,7 @@
TfTpu_BaseFn* InitializeApiFn();
-TfTpu_OpsApiFn* OpsApiFn();
+const TfTpu_OpsApiFn* OpsApiFn();
} // namespace tpu
} // namespace tensorflow
diff --git a/tensorflow/core/tpu/tpu_library_init_fns.inc b/tensorflow/core/tpu/tpu_library_init_fns.inc
index 0b984fa..340077f 100644
--- a/tensorflow/core/tpu/tpu_library_init_fns.inc
+++ b/tensorflow/core/tpu/tpu_library_init_fns.inc
@@ -7,7 +7,9 @@
namespace {
tensorflow::Status SetTpuOpsStructFns(void* library_handle) {
- auto* ops_api_fn = tensorflow::tpu::OpsApiFn();
+ // Constant cast so that we can initialize the functions. The functions are
+ // mutable here because this is the only place where they are initialized.
+ auto* ops_api_fn = const_cast<TfTpu_OpsApiFn*>(tensorflow::tpu::OpsApiFn());
TFTPU_SET_FN(ops_api_fn, ConfigureDistributedTpuOp_DoWork);
TFTPU_SET_FN(ops_api_fn, WaitForDistributedTpuOp_DoWork);
@@ -70,9 +72,9 @@
TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateCompilationCacheKey);
TFTPU_SET_FN(ops_api_fn, TpuCompile_DestroyCompilationCacheKey);
TFTPU_SET_FN(ops_api_fn, TpuCompile_CreateGuaranteedConstFingerprint);
-
+
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Create);
- TFTPU_SET_FN(ops_api_fn, TpuProfiler_Free);
+ TFTPU_SET_FN(ops_api_fn, TpuProfiler_Destroy);
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Start);
TFTPU_SET_FN(ops_api_fn, TpuProfiler_Stop);
TFTPU_SET_FN(ops_api_fn, TpuProfiler_CollectData);
diff --git a/tensorflow/core/tpu/tpu_ops_c_api.h b/tensorflow/core/tpu/tpu_ops_c_api.h
index f361110..a84579c 100644
--- a/tensorflow/core/tpu/tpu_ops_c_api.h
+++ b/tensorflow/core/tpu/tpu_ops_c_api.h
@@ -19,7 +19,6 @@
#include <cstdint>
-#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/tpu/libtftpu.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
@@ -106,20 +105,40 @@
TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state,
XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
-// Creates a new TPU profiler object.
-TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Create();
-
-TFTPU_CAPI_EXPORT TpuProfiler* TpuProfiler_Free(TpuProfiler* tpu_profiler);
-
+// Creates a TPU profiler that is ready to start profiling.
+TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler,
+ TF_Status* status);
+// Destroys the given TPU profiler.
+TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler);
+// Starts profiling if not already started, returns an error otherwise.
TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler,
TF_Status* status);
-
+// Stops profiling if not already stopped, returns an error otherwise.
TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler,
TF_Status* status);
-
-TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(
- TpuProfiler* tpu_profiler, TF_Status* status,
- tensorflow::profiler::XSpace* space);
+// Serializes profiled data into `buffer` and returns the size of `buffer`. The
+// profile data held by the TPU driver will be cleared after retrieval.
+//
+// Step 1. Query the size of buffer required into `size_in_bytes`.
+//
+// size_t size_in_bytes;
+// TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes);
+//
+// Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`.
+// Subsequently, the TPU driver clears its copy of the profile data.
+//
+// uint8_t* buffer = new uint8_t[size_in_bytes];
+// TpuProfiler_CollectData(profiler, status, buffer, &size_in_bytes);
+//
+// Step 3. Unpack the data into an XSpace.
+//
+// tensorflow::profiler::XSpace space;
+// space.ParseFromArray(buffer, size_in_bytes);
+//
+TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler,
+ TF_Status* status,
+ uint8_t* buffer,
+ size_t* size_in_bytes);
// Creates a new TPU mesh state object.
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
@@ -416,7 +435,7 @@
TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create);
- TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Free);
+ TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy);
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start);
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop);
TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData);