[Profiler] Optimize `reportMemoryUsage` (#71538)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71538

`reportMemoryUsage` is kind of awful. It does a bunch of string writes and such that makes it VERY expensive. Just moving that work off the hot path reduces the overhead for `profile_memory` from ~6.5 us to ~1.2 us. (85% reduction in the kineto contribution to profiling overhead.)

Test Plan: Ran ubenchmark with `--op empty --stressTestKineto --kinetoProfileMemory`

Reviewed By: swolchok

Differential Revision: D32730167

fbshipit-source-id: fe18e8fa3881967cad8fa1c26c71c805e9b034e5
(cherry picked from commit 0d394cb252e6eac78626b467e0bb497d6d6ae86c)
diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp
index 206bd52..cc42e6d 100644
--- a/torch/csrc/autograd/profiler_kineto.cpp
+++ b/torch/csrc/autograd/profiler_kineto.cpp
@@ -155,6 +155,19 @@
     torch::profiler::impl::CUDAEventStub cuda_event_end_ = nullptr;
 };
 
+struct MemoryEventData {
+  int64_t start_time;
+  void* ptr;
+  int64_t alloc_size;
+  int64_t total_allocated;
+  int64_t total_reserved;
+  uint64_t threadID;
+  torch::profiler::impl::kineto::DeviceAndResource kineto_info;
+  c10::DeviceType device_type;
+  c10::DeviceIndex device_index;
+};
+static_assert(std::is_pod<MemoryEventData>::value, "Non-POD member of MemoryEventData.");
+
 // Assumption: Total threads number will not exceed 2^16-1, and total ops will
 // not exceed 2^48 -1.
 static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
@@ -204,29 +217,16 @@
       int64_t total_reserved,
       c10::Device device) override {
     if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
-      std::lock_guard<std::mutex> guard(state_mutex_);
-      auto start_time = getTimeUs();
-      if (cpu_trace_) {
-        torch::profiler::impl::kineto::recordThreadInfo();
-        cpu_trace_.addMemoryUsageActivity(
-            kMemoryEventName,
-            torch::profiler::impl::kineto::kineto_ids(),
-            start_time,
-            device,
-            ptr,
-            alloc_size,
-            total_allocated,
-            total_reserved);
-      }
-
-      kineto_events_.emplace_back();
-      auto& evt = kineto_events_.back();
-      evt.name(kMemoryEventName)
-          .startUs(start_time)
-          .deviceIndex(device.index())
-          .deviceType(device.type())
-          .nBytes(alloc_size)
-          .startThreadId(at::RecordFunction::currentThreadId());
+      memory_events_.push_back(
+          {getTimeUs(),
+           ptr,
+           alloc_size,
+           total_allocated,
+           total_reserved,
+           at::RecordFunction::currentThreadId(),
+           torch::profiler::impl::kineto::kineto_ids(),
+           device.type(),
+           device.index()});
     }
   }
 
@@ -264,6 +264,28 @@
 
   void materializeOpEvents() {
     std::lock_guard<std::mutex> guard(state_mutex_);
+
+    for (const auto& e : memory_events_) {
+      cpu_trace_.addMemoryUsageActivity(
+          kMemoryEventName,
+          e.kineto_info,
+          e.start_time,
+          c10::Device(e.device_type, e.device_index),
+          e.ptr,
+          e.alloc_size,
+          e.total_allocated,
+          e.total_reserved);
+
+      kineto_events_.emplace_back();
+      auto& evt = kineto_events_.back();
+      evt.name(kMemoryEventName)
+          .startUs(e.start_time)
+          .deviceIndex(e.device_index)
+          .deviceType(e.device_type)
+          .nBytes(e.alloc_size)
+          .startThreadId(e.threadID);
+    }
+
     for (const auto& e : op_events_) {
       if (e.end_us_ < e.start_us_) {
         // We initialize end_us_ to the smallest int64_t, so this means that
@@ -585,6 +607,7 @@
   uint64_t start_time_;
   std::set<torch::profiler::impl::ActivityType> activities_;
   std::deque<OpEventData> op_events_;
+  std::deque<MemoryEventData> memory_events_;
   torch::profiler::impl::kineto::TraceWrapper cpu_trace_;
   std::vector<KinetoEvent> kineto_events_;
   // Optional, if event post-processing is enabled.