[pytorch profiler] Add step tracker logic to handle multiple sources of step increments (#90880)

# Summary
Enables multiple step trackers. Until now there was only one place to mark that a step() has occurred in the program: the pytorch profiler's step() call.
We are now working on adding an Optimizer step hook (https://github.com/pytorch/pytorch/issues/88446), which means:
- Programs that already call profiler.step() every iteration could end up double incrementing the step count.
- If a model uses multiple optimizers, the step could be counted twice or more per iteration.

## Solution
We fix this by adding a layer of abstraction in front of the kineto library's step() call. The idea is to maintain a per-requester step count in a dictionary:
```
{
   "ProfilerStep": 100,  # triggered by profiler step() call
   "Optimizer1Step": 100,   # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
   "Optimizer2Step": 100,
}
```
To figure out the global step count, just take the max of the dict values (100). If one of the counts increments, the max goes up:
```
{
   "ProfilerStep": 100,
   "Optimizer1Step": 101,   # Optimizer1 got incremented first say
   "Optimizer2Step": 100,
}
```
The global step count is then 101.

## Calling kineto
We only call the kineto step() function when the global count increments.
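
A simplified sketch of that bookkeeping is shown below (the names here are illustrative; the real implementation is the KinetoStepTracker class added to torch/autograd/profiler.py in this diff):
```
# Minimal sketch of the idea, not the actual implementation.
from typing import Dict

_step_dict: Dict[str, int] = {}   # per-requester step counts
_current_step = -1                # last global step reported to kineto

def kineto_step():
    # Placeholder for the real notification into kineto
    # (torch.autograd.profiler._kineto_step in the actual code).
    pass

def increment_step(requester: str) -> int:
    global _current_step
    if requester not in _step_dict:
        # New requesters start at the current global step so they do not
        # retroactively inflate the count.
        _step_dict[requester] = _current_step
    _step_dict[requester] += 1
    new_step = max(_step_dict.values())
    # Only notify kineto when the global max actually advances, so multiple
    # requesters stepping within the same iteration do not double count.
    if new_step > _current_step:
        for _ in range(new_step - _current_step):
            kineto_step()
        _current_step = new_step
    return _current_step
```
With this in place, profiler.step() bumps the "ProfilerStep" entry and each optimizer hook bumps its own entry, so kineto sees at most one step per iteration.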

# Test Plan:
Added a unit test:
```
buck2 run mode/dev-nosan caffe2/test:profiler
```
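In an open-source checkout, something along these lines should exercise the new test as well (assuming pytest is available):
```
python -m pytest test/profiler/test_profiler.py -k kineto_profiler_multiple_steppers
```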

Differential Revision: D41751157

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90880
Approved by: https://github.com/chaekit
diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py
index acaa1f9..69c75a9 100644
--- a/test/profiler/test_profiler.py
+++ b/test/profiler/test_profiler.py
@@ -22,6 +22,7 @@
     _record_function_with_args_exit,
 )
 from torch.autograd.profiler import profile as _profile
+from torch.autograd.profiler import KinetoStepTracker
 from torch.autograd.profiler_legacy import profile as _profile_legacy
 from torch.profiler import (
     _utils,
@@ -996,6 +997,8 @@
             # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
             called_num[0] += 1
 
+        initial_step = KinetoStepTracker.current_step()
+
         with profile(
             activities=supported_activities(),
             schedule=torch.profiler.schedule(
@@ -1009,6 +1012,7 @@
                 p.step()
 
         self.assertEqual(called_num[0], 2)
+        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)
 
         # case without schedule
         with profile(
@@ -1045,6 +1049,49 @@
         for step in range(len(test_schedule_expected_outputs)):
             self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])
 
+    def test_kineto_profiler_multiple_steppers(self):
+        niters = 8
+        use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
+        net = SimpleNet()
+        opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
+        opt.zero_grad()
+        inputs = torch.rand(10)
+
+        with profile(activities=supported_activities()):
+            self.payload(use_cuda=use_cuda)
+
+        def optimizer_step():
+            """This simulates a step() hook in the optimizer"""
+            KinetoStepTracker.increment_step("yet_another_step")
+
+        initial_step = KinetoStepTracker.current_step()
+
+        def run_batch():
+            out = net(inputs)
+            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
+            loss.backward()
+            opt.step()
+            # Manually call the hook. TODO: Remove this once we add the
+            # profiler step hooks in the Optimizer class that will get triggered above.
+            # See https://github.com/pytorch/pytorch/issues/88446
+            optimizer_step()
+
+        for idx in range(niters):
+            run_batch()
+
+        with profile(
+            activities=supported_activities(),
+            schedule=torch.profiler.schedule(
+                wait=1,
+                warmup=1,
+                active=2),
+        ) as p:
+            for idx in range(niters):
+                run_batch()
+                p.step()
+
+        self.assertEqual(KinetoStepTracker.current_step(), initial_step + 2 * niters)
+
     def test_export_stacks(self):
         with _profile(with_stack=True, use_kineto=kineto_available(), experimental_config=_ExperimentalConfig(verbose=True)) as p:
             x = torch.randn(10, 10)
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index e70ec6c..ad87635 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -1,4 +1,5 @@
 from typing import Any, Dict, List, Optional
+from collections import defaultdict
 from warnings import warn
 
 import torch
@@ -31,7 +32,7 @@
 from torch.futures import Future
 
 __all__ = ["profile", "record_function", "emit_itt", "emit_nvtx", "load_nvprof", "EnforceUnique",
-           "parse_nvprof_trace", "kineto_step", "EventList", "FunctionEvent", "MemRecordsAcc"]
+           "parse_nvprof_trace", "KinetoStepTracker", "EventList", "FunctionEvent", "MemRecordsAcc"]
 
 try:
     # Available in Python >= 3.2
@@ -812,8 +813,75 @@
     return functions
 
 
-def kineto_step():
-    """ Notify kineto so it is aware of iteration boundaries for asynchronous
-        trace requests.
+class KinetoStepTracker:
+    """Provides an abstraction for incrementing the step count globally.
+    Previously, we only had one place to mark that a step() has occurred
+    in the program via pytorch profiler step(). We will now add step hooks
+    in the Optimizer class https://github.com/pytorch/pytorch/issues/88446
+
+    - This could mean programs that already call profiler.step() every
+      iteration can end up double incrementing step count.
+    - If a model uses multiple optimizers we can also have double or more
+      counting of the step.
+
+    We fix this by adding a layer of abstraction before calling step()
+    to the kineto library. The idea is to maintain steps per requester in a dict:
+    ```
+    {
+       "ProfilerStep": 100,  # triggered by profiler step() call
+       "Optimizer1Step": 100,   # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
+       "Optimizer2Step": 100,
+    }
+    ```
+    To figure out the global step count just take the max of dict values (100).
+
+    If one of the counts increments, the max will go up.
+    ```
+    {
+       "ProfilerStep": 100,
+       "Optimizer1Step": 101,   # Optimizer1 got incremented first say
+       "Optimizer2Step": 100,
+    }
+    ```
+    Then the global step count is 101.
+    We only call the kineto step() function when global count increments.
+
+    NOTE: Please do not use the KinetoStepTracker in modules besides the Optimizer
+    for now. The result could be incorrect increments of the step count.
     """
-    _kineto_step()
+    _current_step = -1
+    _step_dict: Dict[str, int] = defaultdict(int)
+
+    @classmethod
+    def init_step_count(cls, requester: str):
+        cls._step_dict[requester] = cls._current_step
+
+    @classmethod
+    def erase_step_count(cls, requester: str) -> bool:
+        return cls._step_dict.pop(requester, None) is not None
+
+    @classmethod
+    def increment_step(cls, requester: str) -> int:
+        """Increments the step count for the requester.
+        Additionally if the max over all step counts has incremented then
+        trigger the _kineto_step()
+        returns global step count
+        """
+        if requester not in cls._step_dict:
+            cls.init_step_count(requester)
+        cls._step_dict[requester] += 1
+
+        new_step = max(cls._step_dict.values())
+        if new_step > cls._current_step:
+            delta = new_step - cls._current_step
+            if delta > 1:
+                warn("Profiler step count has increased more than 1 - "
+                     f"current_step = {cls._current_step} step dict =  {cls._step_dict}")
+            for _ in range(0, delta):
+                _kineto_step()
+            cls._current_step = new_step
+        return cls._current_step
+
+    @classmethod
+    def current_step(cls) -> int:
+        return cls._current_step
diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py
index 85d9576..9ebbc3d 100644
--- a/torch/profiler/profiler.py
+++ b/torch/profiler/profiler.py
@@ -28,7 +28,7 @@
     "profile",
     "ExecutionGraphObserver",
 ]
-
+PROFILER_STEP_NAME = "ProfilerStep"
 
 def supported_activities():
     """
@@ -496,6 +496,9 @@
             (ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
             (ProfilerAction.RECORD_AND_SAVE, None): [self.stop_trace, self._trace_ready]
         }
+        # Start tracking increments to profiler step, this will be used
+        # by Kineto
+        prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)
 
     def __enter__(self):
         self.start()
@@ -503,6 +506,7 @@
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop()
+        prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
 
     def start(self):
         self._transit_action(ProfilerAction.NONE, self.current_action)
@@ -527,8 +531,8 @@
         self.current_action = self.schedule(self.step_num)
 
         self._transit_action(prev_action, self.current_action)
+        prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)
 
-        prof.kineto_step()
         if self.record_steps:
             self.step_rec_fn = prof.record_function("ProfilerStep#" + str(cur_step))
             self.step_rec_fn.__enter__()