[pytorch profiler] Add step tracker logic to handle multiple sources of step increments (#90880)
# Summary
Enables multiple step trackers. Currently there is only one place to mark that a step() has occurred in the program: the PyTorch profiler's step() call.
We are now working on adding an Optimizer step hook (https://github.com/pytorch/pytorch/issues/88446). This raises two concerns:
- Programs that already call profiler.step() every iteration could end up double-incrementing the step count.
- If a model uses multiple optimizers, the step could be counted two or more times.
## Solution
We fix this by adding a layer of abstraction before the step() call into the kineto library. The idea is to maintain steps per requester in a dictionary:
```
{
    "ProfilerStep": 100,    # triggered by the profiler step() call
    "Optimizer1Step": 100,  # Optimizer 1 and 2 are just examples; could be SGD, Adam, etc.
    "Optimizer2Step": 100,
}
```
To figure out the global step count, just take the max of the dict values (100 here). If one of the counts increments, the max goes up:
```
{
    "ProfilerStep": 100,
    "Optimizer1Step": 101,  # say Optimizer1 got incremented first
    "Optimizer2Step": 100,
}
```
The global step count is then 101.
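A small illustration of that reduction, using the example values above:
```
# Per-requester step counts from the example above.
step_dict = {
    "ProfilerStep": 100,
    "Optimizer1Step": 101,
    "Optimizer2Step": 100,
}
global_step = max(step_dict.values())  # -> 101
```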
## Calling kineto
We only call the kineto step() function when the global count increments.
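To make that concrete, below is a simplified, self-contained sketch of the tracker logic. The real implementation is the `KinetoStepTracker` class added to `torch/autograd/profiler.py` in this diff; `notify_kineto_step` here is only a stand-in for the internal `_kineto_step()` call.
```
from collections import defaultdict

def notify_kineto_step():
    # Stand-in for the internal _kineto_step() binding that tells kineto
    # an iteration boundary has been reached.
    print("kineto notified of a new step")

_step_dict = defaultdict(int)  # steps per requester
_current_step = 0              # global step already reported to kineto

def increment_step(requester):
    """Simplified version of KinetoStepTracker.increment_step()."""
    global _current_step
    _step_dict[requester] += 1
    new_step = max(_step_dict.values())
    # Only notify kineto when the global max actually advances.
    if new_step > _current_step:
        for _ in range(new_step - _current_step):
            notify_kineto_step()
        _current_step = new_step
    return _current_step

# The profiler and an optimizer hook both report the same iteration,
# but kineto only sees one step increment per iteration.
for _ in range(3):
    increment_step("ProfilerStep")
    increment_step("Optimizer1Step")
```
The real class additionally exposes init_step_count() and erase_step_count() so requesters can register and deregister, which the profiler now does in its __init__ and __exit__ (see the torch/profiler/profiler.py hunk below).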
# Test Plan:
Added a unit test
`buck2 run mode/dev-nosan caffe2/test:profiler`
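Stripped to its essentials, the new API exercised by the test looks like this (a sketch; `"my_optimizer_step"` is just an arbitrary requester name):
```
from torch.autograd.profiler import KinetoStepTracker

KinetoStepTracker.init_step_count("my_optimizer_step")
for _ in range(3):
    # ... run forward / backward / optimizer.step() here ...
    KinetoStepTracker.increment_step("my_optimizer_step")
KinetoStepTracker.erase_step_count("my_optimizer_step")
```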
Differential Revision: D41751157
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90880
Approved by: https://github.com/chaekit
diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py
index acaa1f9..69c75a9 100644
--- a/test/profiler/test_profiler.py
+++ b/test/profiler/test_profiler.py
@@ -22,6 +22,7 @@
_record_function_with_args_exit,
)
from torch.autograd.profiler import profile as _profile
+from torch.autograd.profiler import KinetoStepTracker
from torch.autograd.profiler_legacy import profile as _profile_legacy
from torch.profiler import (
_utils,
@@ -996,6 +997,8 @@
# p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
called_num[0] += 1
+ initial_step = KinetoStepTracker.current_step()
+
with profile(
activities=supported_activities(),
schedule=torch.profiler.schedule(
@@ -1009,6 +1012,7 @@
p.step()
self.assertEqual(called_num[0], 2)
+ self.assertEqual(KinetoStepTracker.current_step(), initial_step + 8)
# case without schedule
with profile(
@@ -1045,6 +1049,49 @@
for step in range(len(test_schedule_expected_outputs)):
self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])
+ def test_kineto_profiler_multiple_steppers(self):
+ niters = 8
+ use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
+ net = SimpleNet()
+ opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
+ opt.zero_grad()
+ inputs = torch.rand(10)
+
+ with profile(activities=supported_activities()):
+ self.payload(use_cuda=use_cuda)
+
+ def optimizer_step():
+ """This simulates a step() hook in the optimizer"""
+ KinetoStepTracker.increment_step("yet_another_step")
+
+ initial_step = KinetoStepTracker.current_step()
+
+ def run_batch():
+ out = net(inputs)
+ loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
+ loss.backward()
+ opt.step()
+ # Manually call the hook. TODO: Remove this once we add the
+ # profiler step hooks in the Optimizer class that will get triggered above.
+ # See https://github.com/pytorch/pytorch/issues/88446
+ optimizer_step()
+
+ for idx in range(niters):
+ run_batch()
+
+ with profile(
+ activities=supported_activities(),
+ schedule=torch.profiler.schedule(
+ wait=1,
+ warmup=1,
+ active=2),
+ ) as p:
+ for idx in range(niters):
+ run_batch()
+ p.step()
+
+ self.assertEqual(KinetoStepTracker.current_step(), initial_step + 2 * niters)
+
def test_export_stacks(self):
with _profile(with_stack=True, use_kineto=kineto_available(), experimental_config=_ExperimentalConfig(verbose=True)) as p:
x = torch.randn(10, 10)
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index e70ec6c..ad87635 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional
+from collections import defaultdict
from warnings import warn
import torch
@@ -31,7 +32,7 @@
from torch.futures import Future
__all__ = ["profile", "record_function", "emit_itt", "emit_nvtx", "load_nvprof", "EnforceUnique",
- "parse_nvprof_trace", "kineto_step", "EventList", "FunctionEvent", "MemRecordsAcc"]
+ "parse_nvprof_trace", "KinetoStepTracker", "EventList", "FunctionEvent", "MemRecordsAcc"]
try:
# Available in Python >= 3.2
@@ -812,8 +813,75 @@
return functions
-def kineto_step():
- """ Notify kineto so it is aware of iteration boundaries for asynchronous
- trace requests.
+class KinetoStepTracker:
+ """Provides an abstraction for incrementing the step count globally.
+ Previously, we only had one place to mark that a step() has occurred
+ in the program via pytorch profiler step(). We will now add step hooks
+ in the Optimizer class https://github.com/pytorch/pytorch/issues/88446
+
+ - This could mean programs that already call profiler.step() every
+ iteration can end up double incrementing step count.
+ - If a model uses multiple optimizers we can also have double or more
+ counting of the step.
+
+ We fix this by adding a layer of abstraction before calling step()
+ to the kineto library. The idea is to maintain steps per requester in a dict:
+ ```
+ {
+ "ProfilerStep": 100, # triggered by profiler step() call
+ "Optimizer1Step": 100, # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
+ "Optimizer2Step": 100,
+ }
+ ```
+ To figure out the global step count, just take the max of the dict values (100).
+
+ If one of the counts increments, the max will go up.
+ ```
+ {
+ "ProfilerStep": 100,
+ "Optimizer1Step": 101, # Optimizer1 got incremented first say
+ "Optimizer2Step": 100,
+ }
+ ```
+ The global step count is then 101.
+ We only call the kineto step() function when the global count increments.
+
+ NOTE: Please do not use the KinetoStepTracker in modules besides the Optimizer
+ for now. The result could be incorrect increments of the step count.
"""
- _kineto_step()
+ _current_step = -1
+ _step_dict: Dict[str, int] = defaultdict(int)
+
+ @classmethod
+ def init_step_count(cls, requester: str):
+ cls._step_dict[requester] = cls._current_step
+
+ @classmethod
+ def erase_step_count(cls, requester: str) -> bool:
+ return cls._step_dict.pop(requester, None) is not None
+
+ @classmethod
+ def increment_step(cls, requester: str) -> int:
+ """Increments the step count for the requester.
+ Additionally, if the max over all step counts has incremented, then
+ trigger _kineto_step().
+ Returns the global step count.
+ """
+ if requester not in cls._step_dict:
+ cls.init_step_count(requester)
+ cls._step_dict[requester] += 1
+
+ new_step = max(cls._step_dict.values())
+ if new_step > cls._current_step:
+ delta = new_step - cls._current_step
+ if delta > 1:
+ warn("Profiler step count has increased more than 1 - "
+ f"current_step = {cls._current_step} step dict = {cls._step_dict}")
+ for _ in range(0, delta):
+ _kineto_step()
+ cls._current_step = new_step
+ return cls._current_step
+
+ @classmethod
+ def current_step(cls) -> int:
+ return cls._current_step
diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py
index 85d9576..9ebbc3d 100644
--- a/torch/profiler/profiler.py
+++ b/torch/profiler/profiler.py
@@ -28,7 +28,7 @@
"profile",
"ExecutionGraphObserver",
]
-
+PROFILER_STEP_NAME = "ProfilerStep"
def supported_activities():
"""
@@ -496,6 +496,9 @@
(ProfilerAction.RECORD, None): [self.stop_trace, self._trace_ready],
(ProfilerAction.RECORD_AND_SAVE, None): [self.stop_trace, self._trace_ready]
}
+ # Start tracking increments to the profiler step; this will be used
+ # by Kineto
+ prof.KinetoStepTracker.init_step_count(PROFILER_STEP_NAME)
def __enter__(self):
self.start()
@@ -503,6 +506,7 @@
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
+ prof.KinetoStepTracker.erase_step_count(PROFILER_STEP_NAME)
def start(self):
self._transit_action(ProfilerAction.NONE, self.current_action)
@@ -527,8 +531,8 @@
self.current_action = self.schedule(self.step_num)
self._transit_action(prev_action, self.current_action)
+ prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)
- prof.kineto_step()
if self.record_steps:
self.step_rec_fn = prof.record_function("ProfilerStep#" + str(cur_step))
self.step_rec_fn.__enter__()