cupti tracer allow device synchronization before gputracer stop.

PiperOrigin-RevId: 312702777
Change-Id: Ied5df1fb045c6e2c35ae0b63dc73fb04be54104f
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
index 9119c3d..51f89bd 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/core/profiler/internal/gpu/cupti_tracer.h"
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/container/node_hash_map.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/errors.h"
@@ -614,15 +615,42 @@
     // Grab timestamp for API exit. API entry timestamp saved in cbdata.
     uint64 end_tsc = CuptiTracer::GetTimestamp();
     uint64 start_tsc = *cbdata->correlationData;
+    TrackContext(cbid, cbdata->context);
     return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id,
                                      start_tsc, end_tsc, domain, cbid, cbdata);
   }
-  Status Flush() override { return Status::OK(); }
+  Status SyncAndFlush() override {
+    if (option_.sync_devices_before_stop) {
+      CuptiApiTracingDisabler disabler;
+      absl::MutexLock lock(&mutex_);
+      for (auto &ctx : contexts_) {
+        cuCtxPushCurrent(ctx);
+        cuCtxSynchronize();  // Ignore error here for best effort.
+        CUcontext current;
+        cuCtxPopCurrent(&current);
+      }
+    }
+    return Status::OK();
+  }
 
  private:
+  void TrackContext(CUpti_CallbackId cbid, CUcontext ctx) {
+    if (!option_.sync_devices_before_stop) return;
+    if (ctx == NULL) return;
+    absl::MutexLock lock(&mutex_);
+    if (cbid == CUPTI_DRIVER_TRACE_CBID_cuCtxDestroy_v2 ||
+        cbid == CUPTI_DRIVER_TRACE_CBID_cuCtxDestroy) {
+      contexts_.erase(ctx);
+    } else {
+      contexts_.emplace(ctx);
+    }
+  }
+
   const CuptiTracerOptions option_;
   CuptiInterface *cupti_interface_;
   CuptiTraceCollector *collector_;
+  absl::Mutex mutex_;
+  absl::flat_hash_set<CUcontext> contexts_ TF_GUARDED_BY(mutex_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(CuptiDriverApiHookWithActivityApi);
 };
@@ -1158,7 +1186,7 @@
     return AddDriverApiCallbackEvent(collector_, cupti_interface_, device_id,
                                      start_tsc, end_tsc, domain, cbid, cbdata);
   }
-  Status Flush() override {
+  Status SyncAndFlush() override {
     for (auto &recorder : cuda_event_recorders_) {
       TF_RETURN_IF_ERROR(recorder->Stop());
     }
@@ -1397,7 +1425,7 @@
   }
   cupti_interface_->CleanUp();
   Finalize().IgnoreError();
-  cupti_driver_api_hook_->Flush().IgnoreError();
+  cupti_driver_api_hook_->SyncAndFlush().IgnoreError();
   collector_->Flush();
   collector_ = nullptr;
   option_.reset();
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
index e236afc..a62c080 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.h
@@ -147,6 +147,8 @@
   std::vector<CUpti_ActivityKind> activities_selected;
   // Whether to call cuptiFinalize.
   bool cupti_finalize = false;
+  // Whether to call cuCtxSynchronize for each device before Stop().
+  bool sync_devices_before_stop = false;
 };
 
 struct CuptiTracerCollectorOptions {
@@ -219,7 +221,7 @@
   virtual Status OnDriverApiExit(int device_id, CUpti_CallbackDomain domain,
                                  CUpti_CallbackId cbid,
                                  const CUpti_CallbackData* callback_info) = 0;
-  virtual Status Flush() = 0;
+  virtual Status SyncAndFlush() = 0;
 
  protected:
   static Status AddDriverApiCallbackEvent(