cupti tracer allow device synchronization before gputracer stop.
PiperOrigin-RevId: 312702777
Change-Id: Ied5df1fb045c6e2c35ae0b63dc73fb04be54104f
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
index 9119c3d..51f89bd 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
@@ -16,6 +16,7 @@
#include "tensorflow/core/profiler/internal/gpu/cupti_tracer.h"
#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
@@ -614,15 +615,42 @@
// Grab timestamp for API exit. API entry timestamp saved in cbdata.
uint64 end_tsc = CuptiTracer::GetTimestamp();
uint64 start_tsc = *cbdata->correlationData;
+ TrackContext(cbid, cbdata->context);
return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id,
start_tsc, end_tsc, domain, cbid, cbdata);
}
- Status Flush() override { return Status::OK(); }
+ Status SyncAndFlush() override {
+ if (option_.sync_devices_before_stop) {
+ CuptiApiTracingDisabler disabler;
+ absl::MutexLock lock(&mutex_);
+ for (auto &ctx : contexts_) {
+ cuCtxPushCurrent(ctx);
+ cuCtxSynchronize(); // Ignore error here for best effort.
+ CUcontext current;
+ cuCtxPopCurrent(¤t);
+ }
+ }
+ return Status::OK();
+ }
private:
+ void TrackContext(CUpti_CallbackId cbid, CUcontext ctx) {
+ if (!option_.sync_devices_before_stop) return;
+ if (ctx == NULL) return;
+ absl::MutexLock lock(&mutex_);
+ if (cbid == CUPTI_DRIVER_TRACE_CBID_cuCtxDestroy_v2 ||
+ cbid == CUPTI_DRIVER_TRACE_CBID_cuCtxDestroy) {
+ contexts_.erase(ctx);
+ } else {
+ contexts_.emplace(ctx);
+ }
+ }
+
const CuptiTracerOptions option_;
CuptiInterface *cupti_interface_;
CuptiTraceCollector *collector_;
+ absl::Mutex mutex_;
+ absl::flat_hash_set<CUcontext> contexts_ TF_GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(CuptiDriverApiHookWithActivityApi);
};
@@ -1158,7 +1186,7 @@
return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id,
start_tsc, end_tsc, domain, cbid, cbdata);
}
- Status Flush() override {
+ Status SyncAndFlush() override {
for (auto &recorder : cuda_event_recorders_) {
TF_RETURN_IF_ERROR(recorder->Stop());
}
@@ -1397,7 +1425,7 @@
}
cupti_interface_->CleanUp();
Finalize().IgnoreError();
- cupti_driver_api_hook_->Flush().IgnoreError();
+ cupti_driver_api_hook_->SyncAndFlush().IgnoreError();
collector_->Flush();
collector_ = nullptr;
option_.reset();
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
index e236afc..a62c080 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
@@ -147,6 +147,8 @@
std::vector<CUpti_ActivityKind> activities_selected;
// Whether to call cuptiFinalize.
bool cupti_finalize = false;
+ // Whether to call cuCtxSynchronize for each device before Stop().
+ bool sync_devices_before_stop = false;
};
struct CuptiTracerCollectorOptions {
@@ -219,7 +221,7 @@
virtual Status OnDriverApiExit(int device_id, CUpti_CallbackDomain domain,
CUpti_CallbackId cbid,
const CUpti_CallbackData* callback_info) = 0;
- virtual Status Flush() = 0;
+ virtual Status SyncAndFlush() = 0;
protected:
static Status AddDriverApiCallbackEvent(