Revert "Revert "[Profiler] Move python tracing to unified event type (Part 2)""

This reverts commit 4305f8e9bda34f18eb7aacab51c63651cfc61802.

Re-land the original change, replacing TEST_CUDA with torch.has_cuda.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79173

Approved by: https://github.com/ezyang
diff --git a/test/test_profiler.py b/test/test_profiler.py
index 203f478..89fb62b 100644
--- a/test/test_profiler.py
+++ b/test/test_profiler.py
@@ -1,15 +1,19 @@
 # Owner(s): ["oncall: profiler"]
 
 import collections
+import functools
 import gc
 import io
 import json
 import os
+import re
 import tempfile
 import textwrap
 import typing
 import unittest
 
+import expecttest
+
 import torch
 import torch.nn as nn
 import torch.optim
@@ -97,10 +101,32 @@
     OP_TEMPLATE = "[ {}]"
     PAD_LENGTH = len(OP_TEMPLATE.format(""))
 
+    @staticmethod
+    def test(f):
+        """Mark unit test that will be using IcicleNode to test traces.
+
+        This decorator serves two purposes. First, it provides a method name
+        that `format` can use to tell where the test runner (which is
+        environment specific) ends and the unit test begins. Second, it runs
+        `f` once per replicate, exposing the index as `self.icicle_replicate`
+        so that `assertTreesMatch` can adjust to which replicate is running.
+        """
+
+        @functools.wraps(f)
+        def begin_unit_test_marker(self, replicates=5):
+            try:
+                for i in range(replicates):
+                    self.icicle_replicate = i
+                    f(self)  # No `return` here: it would skip replicates 1..N-1.
+            finally:
+                delattr(self, "icicle_replicate")
+        return begin_unit_test_marker
+
     @classmethod
     def format(cls, profiler, indent: int = 0):
         tree = profiler.kineto_results.experimental_event_tree()
         lines = cls.cat([cls(i).materialize() for i in tree])
+        lines = lines[min([i + 1 for i, l in enumerate(lines) if "begin_unit_test_marker" in l] or [0]):]
         out = "\n".join([textwrap.indent(l.rstrip(), " " * indent) for l in lines])
         return f"{out}\n{' ' * indent}"
 
@@ -113,6 +139,34 @@
         inputs = [[j.ljust(w) for j in i] for i, w in zip(inputs, widths)]
         return [join_str.join(i) for i in zip(*inputs)]
 
+    @staticmethod
+    def fmt_name(name: str) -> str:
+        """Coerce an event name into a platform independent form."""
+        # Windows demangles torch::autograd::Node names (via c10::demangle)
+        # with a leading `struct`; strip it so names agree across platforms.
+        if IS_WINDOWS:
+            name = name.replace('struct torch::autograd::AccumulateGrad', 'torch::autograd::AccumulateGrad')
+
+        py_frame = re.match(r"(.*)\.py\(([0-9]+)\): (.*)$", name)
+        if py_frame is None:
+            # Not a `file.py(line): fn` name: only mask object addresses.
+            return re.sub("object at 0x[0-9a-fA-F]+>", "object at 0xXXXXXXXXXXXX>", name)
+
+        path, line, func = py_frame.groups()
+
+        # Depending on where it is run from, this test can show up with a
+        # longer path such as `test/test_profiler.py`; keep only the stem.
+        this_file = os.path.splitext(__file__)[0]
+        if path.endswith(this_file):
+            path = os.path.split(this_file)[1]
+
+        # Expected values are POSIX-style string literals, and line numbers
+        # outside this file are elided since they change whenever PyTorch does.
+        path = path.replace(os.sep, "/")
+        if os.path.split(path.strip())[1] != "test_profiler":
+            line = "..."
+        return f"{path}.py({line}): {func}"
+
     def __init__(self, event) -> None:
         self.width = 0
         self.children : typing.List[IcicleNode] = []
@@ -120,13 +174,7 @@
             self.children.append(IcicleNode(child))
             self.width += self.children[-1].width
 
-        self.name = f"{event.name()} "
-
-        # torch::autograd::Node relies on c10::demangle to generate names, and
-        # Windows demangles to include `struct` in the name.
-        if IS_WINDOWS:
-            self.name = self.name.replace('struct torch::autograd::AccumulateGrad', 'torch::autograd::AccumulateGrad')
-
+        self.name = self.fmt_name(event.name() + " ")
         self.width = max(self.width, len(self.name) + self.PAD_LENGTH)
 
     def materialize(self) -> typing.List[str]:
@@ -1109,6 +1157,64 @@
         with profile():
             self.assertEqual(profiler_type(), ActiveProfilerType.KINETO)
 
+    def assertTreesMatch(self, actual: str, expected: str):
+        # Warning: Here be dragons
+        #   Python tracing behaves subtly differently across platforms.
+        #   Differences observed so far:
+        #     1) Windows symbolicates names differently from posix
+        #     2) On certain platforms the profile callback for c_call does
+        #        not fire for Tensor.__pow__. cPython itself causes this,
+        #        not the function tracer.
+        #
+        # These unit tests exist to ensure that the profiler is doing
+        # reasonable things. Platform dependent variations should simply be
+        # coerced into a platform independent form. If a change in the
+        # codebase changes the trace produced, simply use
+        # EXPECTTEST_ACCEPT=1 to update the tests to reflect the new structure.
+
+        replicate = getattr(self, "icicle_replicate", None)
+        self.assertIsNotNone(replicate, "Please annotate test with `@IcicleNode.test`")
+
+        def split(line):
+            # Cut `line` after each `]` that closes a top level `[`, and
+            # collapse the dash padding in front of that `]` to one space.
+            depth, start, ops = 0, 0, []
+            for end, char in enumerate(line, 1):
+                depth += (char == "[") - (char == "]")
+                if char == "]" and not depth:
+                    ops.append(re.sub(r"\s-*]$", " ]", line[start:end]))
+                    start = end
+            return ops
+
+        def print_op_diff(line_a, line_e):
+            # Print the ops of one pair of lines which are not identical.
+            ops_a, ops_e = split(line_a), split(line_e)
+            if " ".join(ops_a) == " ".join(ops_e):
+                return
+            print(f"  Ops: {len(ops_a)} vs. {len(ops_e)}")
+            for a, e in zip(ops_a, ops_e):
+                if a != e:
+                    print(f"    {a}\n    {e}\n")
+
+        # Best effort attempt at a human comprehensible summary of how
+        # actual and expected differ.
+        if actual != expected and not (expecttest.ACCEPT and replicate <= 0):
+            print(f"Replicate: {replicate}")
+            lines_a, lines_e = actual.splitlines(False), expected.splitlines(False)
+            print(f"Lines: {len(lines_a)} vs. {len(lines_e)}")
+            for line_a, line_e in zip(lines_a, lines_e):
+                print_op_diff(line_a, line_e)
+
+        # The profiler is expected to be deterministic and to return to a
+        # clean state after each run, so only the first replicate is allowed
+        # to update `expected`. A mismatch in a later replicate is a bug in
+        # the profiler.
+        if not replicate:
+            self.assertExpectedInline(actual, expected, skip=1)
+        else:
+            self.assertEqual(actual, expected)
+
+    @IcicleNode.test
     def test_profiler_experimental_tree(self):
         t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
         with profile() as p:
@@ -1117,7 +1223,7 @@
             loss = (y - z) ** 2
             loss.backward()
 
-        self.assertExpectedInline(
+        self.assertTreesMatch(
             IcicleNode.format(p.profiler, 12),
             """\
             [ aten::add ][ aten::ones ----------------][ aten::sub ][ aten::pow --------------------][ aten::ones_like -------------------][ autograd::engine::evaluate_function: PowBackward0 ----------------------------------------------][ autograd::engine::evaluate_function: SubBackward0 ][ autograd::engine::evaluate_function: AddBackward0 ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ]
@@ -1130,6 +1236,7 @@
             """  # noqa: B950
         )
 
+    @IcicleNode.test
     def test_profiler_experimental_tree_with_record_function(self):
         with profile() as p:
             with torch.autograd.profiler.record_function("Top level Annotation"):
@@ -1147,7 +1254,7 @@
         # NB: The `aten::zeros` before the record function annotations are due to
         # `at::cpp_custom_type_hack`. When we switch to `torch::CustomClassHolder`
         # they will disappear.
-        self.assertExpectedInline(
+        self.assertTreesMatch(
             IcicleNode.format(p.profiler, 12),
             """\
             [ aten::zeros ---------------][ Top level Annotation ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------]
@@ -1160,6 +1267,7 @@
             """  # noqa: B950
         )
 
+    @IcicleNode.test
     def test_profiler_experimental_tree_with_memory(self):
         t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
         with profile(profile_memory=True) as p:
@@ -1168,7 +1276,7 @@
             loss = (y - z) ** 2
             loss.backward()
 
-        self.assertExpectedInline(
+        self.assertTreesMatch(
             IcicleNode.format(p.profiler, 12),
             """\
             [ aten::add ][ aten::ones ----------------][ aten::sub ][ aten::pow --------------------------------][ aten::ones_like -------------------][ autograd::engine::evaluate_function: PowBackward0 ----------------------------------------------------------------------------------------------------------------------------------------------][ autograd::engine::evaluate_function: SubBackward0 ][ autograd::engine::evaluate_function: AddBackward0 ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ][ [memory] ]
@@ -1182,6 +1290,88 @@
             """  # noqa: B950
         )
 
+        self.assertTreesMatch(
+            IcicleNode.format(p.profiler, 12),
+            """\
+            [ aten::add ][ aten::ones ----------------][ aten::sub ][ aten::pow --------------------------------][ aten::ones_like -------------------][ autograd::engine::evaluate_function: PowBackward0 ----------------------------------------------------------------------------------------------------------------------------------------------][ autograd::engine::evaluate_function: SubBackward0 ][ autograd::engine::evaluate_function: AddBackward0 ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ][ [memory] ]
+            [ [memory] ] [ aten::empty ][ aten::fill_ ][ [memory] ] [ aten::result_type ][ aten::to ][ [memory] ][ aten::empty_like ---][ aten::fill_ ][ PowBackward0 -----------------------------------------------------------------------------------------------------------------------------------------------------------------------][ [memory] ][ SubBackward0 ][ [memory] ]                         [ AddBackward0 ]                                     [ torch::autograd::AccumulateGrad -------]                              [ torch::autograd::AccumulateGrad ]
+                         [ [memory] ]                                                                            [ aten::empty_strided ]               [ aten::pow -----------------------------------------------][ aten::mul -------------------------------------------------------------------------][ aten::mul ][ [memory] ][ [memory] ]            [ aten::neg ]                                                                                             [ aten::new_empty_strided ][ aten::copy_ ]                              [ aten::detach ]
+                                                                                                                 [ [memory] ]                          [ aten::result_type ][ aten::to ][ [memory] ][ aten::copy_ ][ [memory] ][ aten::mul -------------------------------------------------][ [memory] ][ [memory] ]                                     [ [memory] ]                                                                                              [ aten::empty_strided ]                                                 [ detach ]
+                                                                                                                                                                                                                               [ aten::to --------------------------][ [memory] ][ [memory] ]                                                                                                                                                                       [ [memory] ]
+                                                                                                                                                                                                                               [ aten::_to_copy --------------------]
+                                                                                                                                                                                                                               [ aten::empty_strided ][ aten::copy_ ]
+                                                                                                                                                                                                                               [ [memory] ]
+            """  # noqa: B950
+        )
+
+    @unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.")
+    @unittest.skipIf(torch.has_cuda, "CUDA invokes extra Python functions.")
+    @IcicleNode.test
+    def test_profiler_experimental_tree_with_memory_and_stack(self):
+        t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
+        with profile(with_stack=True, profile_memory=True) as p:
+            z = torch.add(t1, t2)
+            y = torch.ones(1)
+            loss = torch.pow(y - z, 2)
+            loss.backward()
+
+        self.assertTreesMatch(
+            IcicleNode.format(p.profiler, 12),
+            """\
+            [ test_profiler.py(1312): test_profiler_experimental_tree_with_memory_and_stack ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): __enter__ -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ <built-in method add of type object at 0xXXXXXXXXXXXX> ][ <built-in method ones of type object at 0xXXXXXXXXXXXX> ][ aten::sub ][ <built-in method pow of type object at 0xXXXXXXXXXXXX> ][ torch/_tensor.py(...): backward ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ torch/profiler/profiler.py(...): __exit__ ----------------------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): start -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ aten::add ]                                             [ aten::ones ----------------]                             [ [memory] ] [ aten::pow --------------------------------]             [ <built-in function _has_torch_function_unary> ][ torch/autograd/__init__.py(...): backward -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ [memory] ][ torch/profiler/profiler.py(...): stop --------------------------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): _transit_action -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ [memory] ]                                              [ aten::empty ][ aten::fill_ ]                                          [ aten::result_type ][ aten::to ][ [memory] ]                                                              [ <built-in function isinstance> ][ <built-in function isinstance> ][ <built-in function len> ][ torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple ][ torch/autograd/__init__.py(...): _make_grads ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ <built-in method numel of Tensor object at 0xXXXXXXXXXXXX> -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------]            [ torch/profiler/profiler.py(...): _transit_action ---------------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): start_trace -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------]                                                          [ [memory] ]                                                                                                                                                                                                                                                                                                                                      [ <built-in function isinstance> ][ <built-in method numel of Tensor object at 0xXXXXXXXXXXXX> ][ <built-in method ones_like of type object at 0xXXXXXXXXXXXX> ][ <built-in method numel of Tensor object at 0xXXXXXXXXXXXX> ][ autograd::engine::evaluate_function: PowBackward0 ----------------------------------------------------------------------------------------------------------------------------------------------][ autograd::engine::evaluate_function: SubBackward0 ][ autograd::engine::evaluate_function: AddBackward0 ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ][ autograd::engine::evaluate_function: torch::autograd::AccumulateGrad ]            [ <built-in method numel of Tensor object at 0xXXXXXXXXXXXX> ][ torch/profiler/profiler.py(...): stop_trace ------------------------------]
+            [ torch/autograd/profiler.py(...): _start_trace ][ <built-in method kineto_available of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/profiler/profiler.py(...): _get_distributed_info --------------------------------------------------------]                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            [ aten::ones_like -------------------]                                                                                        [ PowBackward0 -----------------------------------------------------------------------------------------------------------------------------------------------------------------------][ [memory] ][ SubBackward0 ][ [memory] ]                         [ AddBackward0 ]                                     [ torch::autograd::AccumulateGrad -------]                              [ torch::autograd::AccumulateGrad ]                                                 [ enum.py(...): __hash__ --]                                  [ torch/autograd/profiler.py(...): __exit__ --------------------------------]
+                                                                                                                                         [ torch/distributed/__init__.py(...): is_available ][ torch/distributed/distributed_c10d.py(...): is_initialized ]                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            [ aten::empty_like ---][ aten::fill_ ]                                                                                        [ aten::pow -----------------------------------------------][ aten::mul -------------------------------------------------------------------------][ aten::mul ][ [memory] ][ [memory] ]            [ aten::neg ]                                                                                             [ aten::new_empty_strided ][ aten::copy_ ]                              [ aten::detach ]                                                                    [ <built-in function hash> ]                                  [ <built-in method _disable_profiler of PyCapsule object at 0xXXXXXXXXXXXX> ]
+                                                                                                                                         [ <built-in function hasattr> ]                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               [ aten::empty_strided ]                                                                                                       [ aten::result_type ][ aten::to ][ [memory] ][ aten::copy_ ][ [memory] ][ aten::mul -------------------------------------------------][ [memory] ][ [memory] ]                                     [ [memory] ]                                                                                              [ aten::empty_strided ]                                                 [ detach ]
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       [ [memory] ]                                                                                                                                                                                          [ aten::to --------------------------][ [memory] ][ [memory] ]                                                                                                                                                                       [ [memory] ]
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             [ aten::_to_copy --------------------]
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             [ aten::empty_strided ][ aten::copy_ ]
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             [ [memory] ]
+            """  # noqa: B950
+        )
+
+    @unittest.skipIf(TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite.")
+    @unittest.skipIf(torch.has_cuda, "CUDA invokes extra Python functions.")
+    @IcicleNode.test
+    def test_profiler_experimental_tree_with_stack_and_modules(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layers = [
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(1, 1),
+                    torch.nn.ReLU(),
+                ]
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                for l in self.layers:
+                    x = l(x)
+                return x
+
+        model = MyModule()
+        with profile(with_stack=True) as p:
+            for _ in range(2):
+                model(torch.ones((1,)))
+
+        self.assertTreesMatch(
+            IcicleNode.format(p.profiler, 12),
+            """\
+            [ test_profiler.py(1355): test_profiler_experimental_tree_with_stack_and_modules ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): __enter__ -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ <built-in method ones of type object at 0xXXXXXXXXXXXX> ][ nn.Module: MyModule_0 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ <built-in method ones of type object at 0xXXXXXXXXXXXX> ][ nn.Module: MyModule_0 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ torch/profiler/profiler.py(...): __exit__ ------------------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): start -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ aten::ones ----------------]                             [ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ test_profiler.py(1349): forward --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ aten::ones ----------------]                             [ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ test_profiler.py(1349): forward --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ torch/profiler/profiler.py(...): stop ----------------------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): _transit_action -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ aten::empty ][ aten::fill_ ]                                                                                                           [ nn.Module: ReLU_0 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------][ nn.Module: Linear_0 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ nn.Module: ReLU_1 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------][ aten::empty ][ aten::fill_ ]                                                                                                           [ nn.Module: ReLU_0 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------][ nn.Module: Linear_0 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------][ nn.Module: ReLU_1 ---------------------------------------------------------------------------------------------------------------------------------------------------------------------][ torch/profiler/profiler.py(...): _transit_action -----------------------------------------------------------------------------------]
+            [ torch/profiler/profiler.py(...): start_trace -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------]                                                                                                                                         [ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/nn/modules/activation.py(...): forward ------------------------------------------------------------][ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/nn/modules/linear.py(...): forward -----------------------------------------------------------------------------------------------------------------------------------------][ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/nn/modules/activation.py(...): forward ------------------------------------------------------------]                                                                                                                                         [ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/nn/modules/activation.py(...): forward ------------------------------------------------------------][ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/nn/modules/linear.py(...): forward -----------------------------------------------------------------------------------------------------------------------------------------][ <built-in method _get_tracing_state of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/nn/modules/activation.py(...): forward ------------------------------------------------------------][ <built-in method get of dict object at 0xXXXXXXXXXXXX> ][ torch/profiler/profiler.py(...): stop_trace ------------------------------]
+            [ torch/autograd/profiler.py(...): _start_trace ][ <built-in method kineto_available of PyCapsule object at 0xXXXXXXXXXXXX> ][ torch/profiler/profiler.py(...): _get_distributed_info --------------------------------------------------------]                                                                                                                                                                                                                       [ torch/nn/functional.py(...): relu -----------------------------------------------------------------------]                                                                              [ torch/nn/modules/module.py(...): __getattr__ ][ torch/nn/modules/module.py(...): __getattr__ ][ <built-in function linear> -------------------------------------------------------]                                                                              [ torch/nn/functional.py(...): relu -----------------------------------------------------------------------]                                                                                                                                                                                                                       [ torch/nn/functional.py(...): relu -----------------------------------------------------------------------]                                                                              [ torch/nn/modules/module.py(...): __getattr__ ][ torch/nn/modules/module.py(...): __getattr__ ][ <built-in function linear> -------------------------------------------------------]                                                                              [ torch/nn/functional.py(...): relu -----------------------------------------------------------------------][ enum.py(...): __hash__ --]                              [ torch/autograd/profiler.py(...): __exit__ --------------------------------]
+                                                                                                                                         [ torch/distributed/__init__.py(...): is_available ][ torch/distributed/distributed_c10d.py(...): is_initialized ]                                                                                                                                                                                                                       [ <built-in function _has_torch_function_unary> ][ <built-in method relu of type object at 0xXXXXXXXXXXXX> ]                                                                                                                                                                              [ aten::linear ---------------------------------------------------------------------]                                                                              [ <built-in function _has_torch_function_unary> ][ <built-in method relu of type object at 0xXXXXXXXXXXXX> ]                                                                                                                                                                                                                       [ <built-in function _has_torch_function_unary> ][ <built-in method relu of type object at 0xXXXXXXXXXXXX> ]                                                                                                                                                                              [ aten::linear ---------------------------------------------------------------------]                                                                              [ <built-in function _has_torch_function_unary> ][ <built-in method relu of type object at 0xXXXXXXXXXXXX> ][ <built-in function hash> ]                              [ <built-in method _disable_profiler of PyCapsule object at 0xXXXXXXXXXXXX> ]
+                                                                                                                                         [ <built-in function hasattr> ]                                                                                                                                                                                                                                                                                                                                                           [ aten::relu -----]                                                                                                                                                                                                                      [ aten::t ---------][ aten::matmul -----------------------------------][ aten::add_ ]                                                                                                                               [ aten::relu -----]                                                                                                                                                                                                                                                                                                                [ aten::relu -----]                                                                                                                                                                                                                      [ aten::t ---------][ aten::matmul -----------------------------------][ aten::add_ ]                                                                                                                               [ aten::relu -----]
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   [ aten::clamp_min ]                                                                                                                                                                                                                      [ aten::transpose -][ aten::t ---------][ aten::mv -------------------]                                                                                                                                             [ aten::clamp_min ]                                                                                                                                                                                                                                                                                                                [ aten::clamp_min ]                                                                                                                                                                                                                      [ aten::transpose -][ aten::t ---------][ aten::mv -------------------]                                                                                                                                             [ aten::clamp_min ]
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            [ aten::as_strided ][ aten::transpose -][ aten::empty ][ aten::addmv_ ]                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         [ aten::as_strided ][ aten::transpose -][ aten::empty ][ aten::addmv_ ]
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                [ aten::as_strided ]                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            [ aten::as_strided ]
+            """  # noqa: B950
+        )
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp
index 6565ae4..9294d45 100644
--- a/torch/csrc/autograd/profiler_kineto.cpp
+++ b/torch/csrc/autograd/profiler_kineto.cpp
@@ -45,7 +45,6 @@
 namespace profiler {
 
 namespace {
-const std::string kMemoryEventName = "[memory]";
 // TODO: consider TLS (tid + tls counter)
 uint64_t next_correlation_id() {
   static std::atomic<uint64_t> corr_id_{1};
@@ -59,44 +58,7 @@
   return torch::profiler::impl::getTime() / 1000;
 #endif // USE_KINETO
 }
-} // namespace
 
-namespace python_tracer {
-using torch::profiler::impl::python_tracer::PyTraceEvent;
-using torch::profiler::impl::python_tracer::PythonTracerBase;
-
-// We do not want `getTimeUs` to be directly visible, but we need a way for
-// the python tracer to use the same timing convention as the profiler.
-int64_t now() {
-  return getTimeUs();
-}
-
-struct Replay {
-  PyTraceEvent* frame_;
-  bool enter_;
-
-  C10_NODISCARD int64_t t() const {
-    return enter_ ? frame_->startTime_ : frame_->endTime_;
-  }
-
-  C10_NODISCARD size_t idx() const {
-    return enter_ ? frame_->call_idx_ : frame_->return_idx_;
-  }
-
-  bool operator<(const Replay& other) const {
-    return idx() < other.idx();
-  }
-};
-
-void _push_reverse_order(PyTraceEvent* e, std::vector<std::string>& names) {
-  if (e != nullptr) {
-    _push_reverse_order(e->parent_, names);
-    names.push_back(e->name_);
-  }
-}
-} // namespace python_tracer
-
-namespace {
 using torch::profiler::impl::ProfilerThreadLocalStateBase;
 using torch::profiler::impl::ActiveProfilerType;
 using torch::profiler::impl::EventType;
@@ -106,14 +68,18 @@
 using torch::profiler::impl::shapesToStr;
 using torch::profiler::impl::dtypesToStr;
 using torch::profiler::impl::stacksToStr;
+using torch::profiler::impl::kineto::KinetoActivityType;
 
 struct EventFieldsVisitor {
   EventFieldsVisitor(
       std::shared_ptr<Result>& result,
       KinetoEvent& kineto_event,
       const post_process_t& post_process)
-      : kineto_event_{kineto_event}, post_process_{post_process} {
+      : kineto_event_{kineto_event},
+        post_process_{post_process} {
+    pushPythonMetadata(result->parent_.lock());  // Fill py_metadata_ first.
     c10::visit(*this, result->extra_fields_);
+    handleStack(result->parent_);  // py_metadata_ is the fallback stack.
   }
 
   void operator()(ExtraFields<EventType::TorchOp>& op_event) {
@@ -214,8 +180,69 @@
     }
   }
 
+  // Python call: shared annotations, plus the module id when `module_` is set.
+  void operator()(const ExtraFields<EventType::PyCall>& py_call) {
+    addPythonAnnotations(py_call);
+    if (!py_call.module_.has_value()) return;
+    const auto module_id = std::to_string(py_call.module_->id_);
+    annotations_.emplace_back("Python module id", module_id);
+  }
+
+  void operator()(const ExtraFields<EventType::PyCCall>& py_call) {
+    addPythonAnnotations(py_call);  // Only the shared Python annotations.
+  }
+
+  // Records id, thread and name of each Python ancestor, nearest one first.
+  void pushPythonMetadata(std::shared_ptr<Result> parent) {
+    using torch::profiler::impl::PyExtraFieldsBase;
+    auto record_if_python = [&](const auto& fields) {
+      using fields_t = typename std::remove_reference<decltype(fields)>::type;
+      c10::guts::if_constexpr<
+          std::is_base_of<PyExtraFieldsBase, fields_t>::value>([&](auto _) {
+        py_metadata_.push_back(
+            {_(fields).id_, _(fields).python_tid_, parent->name()});
+      });
+    };
+    for (; parent != nullptr; parent = parent->parent_.lock()) {
+      c10::visit(record_if_python, parent->extra_fields_);
+    }
+  }
+
+  // Adds Python id / parent id / thread; parent = nearest Python ancestor.
+  template <typename T>
+  void addPythonAnnotations(T& t) {
+    annotations_.emplace_back("Python id", std::to_string(t.id_));
+    annotations_.emplace_back("Python parent id", py_metadata_.empty()
+        ? "null" : std::to_string(py_metadata_.at(0).id_));
+    annotations_.emplace_back("Python thread", std::to_string(t.python_tid_));
+  }
+
+  // A stack already on the event (the JIT stack) takes precedence; otherwise
+  // it is built from the Python ancestors, outermost first. Then annotate it.
+  void handleStack(std::weak_ptr<Result> parent) {
+    KinetoEvent& event = kineto_event_.get();
+    if (!event.hasStack() && !py_metadata_.empty()) {
+      std::vector<std::string> frames;
+      for (auto it = py_metadata_.rbegin(); it != py_metadata_.rend(); ++it) {
+        frames.push_back(it->name_);
+      }
+      event.stack(std::move(frames));
+    }
+    if (event.hasStack()) {
+      annotations_.emplace_back(
+          "Call stack", torch::profiler::impl::stacksToStr(event.stack(), ";"));
+    }
+  }
+
+  struct PythonMetadata {  // One entry per ancestor seen by pushPythonMetadata.
+    size_t id_;  // `id_` of the ancestor's Python extra fields.
+    size_t python_tid_;  // `python_tid_` of the same extra fields.
+    std::string name_;  // `name()` of the ancestor Result.
+  };
+
   std::reference_wrapper<KinetoEvent> kineto_event_;
   std::reference_wrapper<const post_process_t> post_process_;
+  std::vector<PythonMetadata> py_metadata_;
   annotation_t annotations_;
 };
 
@@ -231,8 +258,7 @@
       std::set<torch::profiler::impl::ActivityType> activities)
       : ProfilerThreadLocalStateBase(config),
         start_time_(getTimeUs()),
-        activities_(std::move(activities)),
-        record_queue_(config),
+        record_queue_(config, activities),
         cpu_trace_(start_time_, "PyTorch Profiler") {}
   ~KinetoThreadLocalState() override = default;
 
@@ -247,10 +273,6 @@
     return ActiveProfilerType::KINETO;
   }
 
-  bool tracePython() {
-    return config().with_stack && activities_.count(ActivityType::CPU);
-  }
-
   void reportMemoryUsage(
       void* ptr,
       int64_t alloc_size,
@@ -279,9 +301,20 @@
 
   torch::profiler::impl::kineto::ActivityTraceWrapper finalizeTrace() {
     auto end_time = getTimeUs();
+    record_queue_.stop();
     materializeOpEvents();
 
     finalizeCPUTrace(cpu_trace_.get());
+
+    // `kineto_events_` does not include Python events. Instead it exposes them
+    // via the `stacks` property.
+    kineto_events_.erase(
+        std::remove_if(
+            kineto_events_.begin(),
+            kineto_events_.end(),
+            [](const auto& i) { return i.is_python_function_; }),
+        kineto_events_.end());
+
     {
       std::lock_guard<std::mutex> guard(state_mutex_);
       cpu_trace_.transferCpuTrace(end_time);
@@ -309,7 +342,8 @@
       if (e->finished_) {
         int64_t start_us = e->start_time_ns_ / 1000;
         int64_t end_us = e->endTimeNS() / 1000;
-        kineto_events_.emplace_back();
+        kineto_events_.emplace_back(
+            e->kinetoType() == KinetoActivityType::PYTHON_FUNCTION);
         kineto_events_.back()
             .name(e->name())
             .startUs(start_us)
@@ -366,129 +400,6 @@
       }
     }
     */
-
-    addPythonEvents(cpu_trace);
-  }
-
-  void addPythonEvents(std::unique_ptr<torch::profiler::impl::kineto::trace_t>& cpu_trace) {
-    if (!tracePython()) {
-      return;
-    }
-
-    auto py_events = python_tracer::PythonTracerBase::get().getEvents();
-    for (const auto& e : py_events) {
-      TORCH_INTERNAL_ASSERT(
-          !e->thread_id_,
-          "Profiler expects only single threaded Python tracing.");
-    }
-
-    // The remainder of this function merges the Python and Kineto event
-    // streams into a single stream. If Python tracing is not enabled, we want
-    // to avoid this process altogether to cut down on processing time.
-    if (!py_events.size()) {
-      return;
-    }
-
-    // Kineto event times
-    std::vector<int64_t> op_start_times;
-    for (const auto& a : cpu_trace->activities) {
-      op_start_times.push_back(a.startTime);
-    }
-    std::sort(op_start_times.begin(), op_start_times.end());
-
-    // Map PyTraceEvent* to sequential integers for JSON export.
-    ska::flat_hash_map<python_tracer::PyTraceEvent*, std::string>
-        py_event_indices_{
-            { nullptr,
-              std::string("null") }};
-    for (const auto i : c10::irange(py_events.size())) {
-      py_event_indices_.insert({py_events[i].get(), std::to_string(i)});
-    }
-
-    // Python events
-    std::vector<python_tracer::Replay> py_replay;
-    for (const auto& e : py_events) {
-      py_replay.push_back({e.get(), true});
-      py_replay.push_back({e.get(), false});
-    }
-    std::sort(py_replay.begin(), py_replay.end());
-
-    // In order to determine the state of the python interpreter when a
-    // particular op is called, we have to replay the python events and note
-    // timestamps which are associated with op start times.
-    std::vector<python_tracer::PyTraceEvent*> py_stack;
-    ska::flat_hash_map<int64_t, python_tracer::PyTraceEvent*> op_py_map;
-    auto replay_it = py_replay.begin();
-    for (auto t : op_start_times) {
-      while (replay_it != py_replay.end() && replay_it->t() <= t) {
-        if (replay_it->enter_) {
-          py_stack.push_back(replay_it->frame_);
-        } else {
-          TORCH_INTERNAL_ASSERT(py_stack.size());
-          TORCH_INTERNAL_ASSERT(py_stack.back() == replay_it->frame_);
-          py_stack.pop_back();
-        }
-        replay_it++;
-      }
-      op_py_map.insert({t, py_stack.size() ? py_stack.back() : nullptr});
-    }
-
-    std::vector<libkineto::GenericTraceActivity> py_activities;
-    auto py_events_it = py_events.begin();
-    auto py_device = libkineto::processId();
-    auto main_thread = libkineto::systemThreadId();
-    auto push_py_event = [&]() {
-      auto e = (*py_events_it).get();
-      libkineto::GenericTraceActivity op(
-          cpu_trace->span, libkineto::ActivityType::PYTHON_FUNCTION, e->name_);
-
-      op.device = py_device;
-      op.resource = main_thread;
-      op.startTime = e->startTime_;
-      op.endTime = e->endTime_;
-
-      op.addMetadata("Python id", py_event_indices_.at(e));
-      op.addMetadata("Python parent id", py_event_indices_.at(e->parent_));
-      op.addMetadata("Python thread", std::to_string(e->thread_id_));
-      if (e->module_id_.has_value()) {
-        op.addMetadata("Python module id", *e->module_id_);
-      }
-
-      py_activities.push_back(op);
-      py_events_it++;
-    };
-
-    TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
-    for (const auto idx : c10::irange(cpu_trace->activities.size())) {
-      auto& activity = cpu_trace->activities[idx];
-
-      // Add any python events that occurred between this Kineto event and the
-      // previous Kineto event.
-      while (py_events_it != py_events.end() &&
-             (*py_events_it)->endTime_ <= activity.endTime) {
-        push_py_event();
-      }
-
-      auto python_caller = op_py_map.at(activity.startTime);
-      activity.addMetadata(
-          "python_caller_id", py_event_indices_.at(python_caller));
-
-      // If the kineto event has a stack that means the JIT model has a stack
-      // associated with it that we need to respect.
-      if (!kineto_events_[idx].hasStack()) {
-        std::vector<std::string> py_names;
-        python_tracer::_push_reverse_order(python_caller, py_names);
-        kineto_events_[idx].stack(py_names);
-        activity.addMetadata("Call stack", torch::profiler::impl::stacksToStr(py_names, ";"));
-      }
-    }
-
-    // Add any Python events which finish after the last Kineto event.
-    while (py_events_it != py_events.end()) {
-      push_py_event();
-    }
-
-    cpu_trace->activities.insert(cpu_trace->activities.end(), py_activities.begin(), py_activities.end());
   }
 
   void generateForwardBackwardLink(
@@ -565,7 +476,6 @@
 
   uint64_t start_time_;
   torch::profiler::impl::ApproximateClockToUnixTimeConverter clock_converter_;
-  std::set<torch::profiler::impl::ActivityType> activities_;
   torch::profiler::impl::RecordQueue record_queue_;
   torch::profiler::impl::kineto::TraceWrapper cpu_trace_;
   std::vector<KinetoEvent> kineto_events_;
@@ -768,10 +678,6 @@
     auto state = std::make_shared<KinetoThreadLocalState>(config, activities);
     c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
 
-    if (state->tracePython()) {
-      python_tracer::PythonTracerBase::get().start();
-    }
-
     if (activities.count(ActivityType::CPU)) {
       pushProfilingCallbacks<false>(scopes);
     }
@@ -821,15 +727,7 @@
   if (config.state == ProfilerState::KINETO ||
       config.state == ProfilerState::KINETO_GPU_FALLBACK) {
     auto kineto_state_ptr = std::static_pointer_cast<KinetoThreadLocalState>(state_ptr);
-    if (kineto_state_ptr->tracePython()) {
-      python_tracer::PythonTracerBase::get().stop();
-    }
-
     auto trace = kineto_state_ptr->finalizeTrace();
-    if (kineto_state_ptr->tracePython()) {
-      python_tracer::PythonTracerBase::get().clear();
-    }
-
     result = std::make_unique<ProfilerResult>(
         kineto_state_ptr->start_time_,
         std::move(kineto_state_ptr->kineto_events_),
diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h
index 2f62036..6525c52 100644
--- a/torch/csrc/autograd/profiler_kineto.h
+++ b/torch/csrc/autograd/profiler_kineto.h
@@ -14,6 +14,9 @@
 using experimental_event_t = std::shared_ptr<torch::profiler::impl::Result>;
 
 struct TORCH_API KinetoEvent {
+  explicit KinetoEvent(bool is_python_function = false)  // Default: not Python.
+      : is_python_function_{is_python_function} {}
+
   uint64_t startThreadId() const {
     return start_thread_id_;
   }
@@ -237,6 +240,10 @@
     return *this;
   }
 
+  bool isPythonFunction() const {  // Accessor for `is_python_function_`.
+    return is_python_function_;
+  }
+
   int64_t cudaElapsedUs() const;
 
   uint64_t start_thread_id_ = 0;
@@ -267,6 +274,7 @@
 
   torch::profiler::impl::CUDAEventStub cuda_event_start_ = nullptr;
   torch::profiler::impl::CUDAEventStub cuda_event_end_ = nullptr;
+  bool is_python_function_;
 };
 
 // Consolidating events returned directly from Kineto
@@ -368,11 +376,5 @@
     const torch::profiler::impl::ProfilerConfig& config,
     const std::set<torch::profiler::impl::ActivityType>& activities);
 
-namespace python_tracer {
-// Because we are interleaving events, the Python tracer should use the same
-// timer as the profiler.
-TORCH_API int64_t now();
-}  // namespace python_tracer
-
 } // namespace profiler
 }} // namespace torch::autograd
diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp
index 8ba7f89..e4d4dfa 100644
--- a/torch/csrc/autograd/profiler_python.cpp
+++ b/torch/csrc/autograd/profiler_python.cpp
@@ -5,11 +5,11 @@
 #include <iostream>
 #include <limits>
 #include <memory>
+#include <queue>
 #include <string>
 #include <utility>
 #include <vector>
 
-#include <fmt/format.h>
 #include <Python.h>
 #include <frameobject.h>
 
@@ -17,43 +17,24 @@
 #include <c10/util/C++17.h>
 #include <c10/util/flat_hash_map.h>
 #include <c10/util/irange.h>
-#include <c10/util/strong_type.h>
-#include <torch/csrc/autograd/profiler_kineto.h>
 #include <torch/csrc/profiler/collection.h>
 #include <torch/csrc/profiler/containers.h>
+#include <torch/csrc/profiler/util.h>
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/pybind.h>
 
 namespace py = pybind11;
 
 namespace torch {
-namespace autograd {
 namespace profiler {
+namespace impl {
 namespace {
 enum CallType { PyCall = 0, PyModuleCall, PyCCall };
 static constexpr size_t CallTypeSize = 3;
 
-using torch::profiler::impl::AppendOnlyList;
-using torch::profiler::impl::python_tracer::PythonTracerBase;
-using torch::profiler::impl::python_tracer::PyTraceEvent;
-
 // ============================================================================
 // == Miscellaneous structs and utils =========================================
 // ============================================================================
-struct PyFrameState {
-  int line_no_;
-  at::StringView filename_;
-  at::StringView funcname_;
-};
-
-template <typename T, typename Tag>
-using strong_t = strong::
-    type<T, Tag, strong::regular, strong::convertible_to<T>, strong::hashable>;
-
-using PyModuleSelf = strong_t<PyObject*, struct PyModuleSelf_>;
-using PyModuleCls = strong_t<PyObject*, struct PyModuleCls_>;
-using PyCFunction = strong_t<PyObject*, struct PyCFunction_>;
-
 struct CodeLocation {
   CodeLocation() = default;
   explicit CodeLocation(const PyFrameObject* frame)
@@ -67,14 +48,6 @@
   int lasti_{0};
 };
 
-// Temporary struct. This will be replaced by ExtraFields<EventType>.
-struct FrameArgs {
-  std::string name_;
-  CallType call_type_;
-  c10::optional<std::pair<PyModuleSelf, PyModuleCls>> module_;
-  c10::optional<size_t> module_id_;
-};
-
 PyObject* nnModuleCode() {
   static auto module_call_code = []() {
     pybind11::gil_scoped_acquire gil;
@@ -88,23 +61,21 @@
 }
 
 } // namespace
+} // namespace impl
 } // namespace profiler
-} // namespace autograd
 } // namespace torch
 
 template <>
-struct std::hash<torch::autograd::profiler::CodeLocation> {
-  size_t operator()(const torch::autograd::profiler::CodeLocation& x) {
+struct std::hash<torch::profiler::impl::CodeLocation> {  // hasher for CodeLocation keys (used as a flat_hash_map key_t)
+  size_t operator()(const torch::profiler::impl::CodeLocation& x) const {  // const: Hash objects must be callable when const
     return c10::get_hash(x.code_, x.lasti_);
   }
 };
 
 namespace torch {
-namespace autograd {
 namespace profiler {
-namespace python_tracer {
+namespace impl {
 namespace {
-
 // ============================================================================
 // == CallTypeHelper: Tools for generic programming on specializations. =======
 // ============================================================================
@@ -124,8 +95,7 @@
   static void map(T& t, FunctorT& f, Args... args) {
     f(std::get<C>(t), args...);
     c10::guts::if_constexpr<C + 1 < End>(
-        [&](auto _) { map<C + 1>(_(t), f, std::forward<Args>(args)...); },
-        [&](auto) {});
+        [&](auto _) { map<C + 1>(_(t), f, std::forward<Args>(args)...); });
   }
 
  public:
@@ -181,6 +151,7 @@
 struct Config<CallType::PyCall> {
   using key_t = CodeLocation;
   using cache_t = ska::flat_hash_map<key_t, PyFrameState>;
+  static constexpr EventType event_type = EventType::PyCall;
 };
 
 template <>
@@ -191,12 +162,14 @@
     ska::flat_hash_map<PyModuleSelf, PyModuleCls> modules_;
     ska::flat_hash_map<PyModuleCls, at::StringView> module_cls_names_;
   };
+  static constexpr EventType event_type = EventType::PyCall;
 };
 
 template<>
 struct Config<CallType::PyCCall> {
-  using key_t = PyCFunction;
+  using key_t = torch::profiler::impl::PyCFunction;
   using cache_t = ska::flat_hash_map<key_t, at::StringView>;
+  static constexpr EventType event_type = EventType::PyCCall;
 };
 
 // ============================================================================
@@ -230,16 +203,22 @@
   void store(const typename Config<C>::key_t&);
 
   template <CallType C>
-  auto load(const Callsite<C>& callsite) {
-    // NB: For now caller is dropped. It will be used in the next PR.
-    return load<C>(callsite.value_);
+  auto load(const Callsite<C>& callsite, size_t python_tid) const {
+    auto caller = load<CallType::PyCall>(callsite.caller_);
+    TORCH_INTERNAL_ASSERT(!caller.second.has_value());
+    return ExtraFields<Config<C>::event_type>{
+        /*end_time_ns=*/std::numeric_limits<time_t>::min(),
+        python_tid,
+        caller.first,
+        load<C>(callsite.value_)};
   }
 
   void trimPrefixes();
 
  private:
   template <CallType C>
-  FrameArgs load(const typename Config<C>::key_t&) const;
+  typename ExtraFields<Config<C>::event_type>::args_t load(
+      const typename Config<C>::key_t&) const;
 
   template <CallType C>
   using State = typename Config<C>::cache_t;
@@ -267,16 +246,9 @@
 }
 
 template <>
-FrameArgs ValueCache::load<CallType::PyCall>(const PyCallKey& key) const {
-  auto frame_state = std::get<CallType::PyCall>(state_).at(key);
-  return {
-      fmt::format(
-          "{}({}): {}",
-          frame_state.filename_.str(),
-          frame_state.line_no_,
-          frame_state.funcname_.str()),
-      CallType::PyCall,
-      /*module_=*/c10::nullopt};
+ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyCall>(
+    const PyCallKey& key) const {
+  return {std::get<CallType::PyCall>(state_).at(key), c10::nullopt};
 }
 
 template <>
@@ -301,20 +273,13 @@
 }
 
 template <>
-FrameArgs ValueCache::load<CallType::PyModuleCall>(
+ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
     const PyModuleCallKey& key) const {
   auto& cache = std::get<CallType::PyModuleCall>(state_);
+  TORCH_INTERNAL_ASSERT(cache.module_forward_.has_value());
   auto cls = cache.modules_.at(key);
-
-  // NB: For now fwd is not used.
-  // TORCH_INTERNAL_ASSERT(cache.module_forward_.has_value());
-  // auto fwd = std::get<CallType::PyCall>(state_).at(*cache.module_forward_);
-
-  return {
-      fmt::format("nn.Module: {}", cache.module_cls_names_.at(cls).str()),
-      CallType::PyModuleCall,
-      std::make_pair(key, cls),
-      /*module_id_=*/c10::nullopt};
+  auto fwd = std::get<CallType::PyCall>(state_).at(*cache.module_forward_);
+  return {fwd, NNModuleInfo{key, cls, cache.module_cls_names_.at(cls)}};
 }
 
 template <>
@@ -326,11 +291,9 @@
 }
 
 template <>
-FrameArgs ValueCache::load<CallType::PyCCall>(const PyCCallKey& key) const {
-  return {
-      std::get<CallType::PyCCall>(state_).at(key).str(),
-      CallType::PyCCall,
-      /*module_=*/c10::nullopt};
+ExtraFields<EventType::PyCCall>::args_t ValueCache::load<CallType::PyCCall>(
+    const PyCCallKey& key) const {
+  return std::get<CallType::PyCCall>(state_).at(key);
 }
 
 // TODO: Use re2.
@@ -353,8 +316,7 @@
 // ============================================================================
 // == TraceKey cache ==========================================================
 // ============================================================================
-using TraceKey =
-    strong::type<uint64_t, struct TraceKey_, strong::regular, strong::hashable>;
+using python_tracer::TraceKey;
 
 TraceKey nextKey() {
   static std::atomic<uint64_t> key{0};
@@ -375,7 +337,6 @@
       value_cache.store<C>(callsite.value_);
       value_cache.store<CallType::PyCall>(callsite.caller_);
       it = state_.insert({callsite, nextKey()}).first;
-
     }
     return it->second;
   }
@@ -465,8 +426,11 @@
     Py_DECREF((PyObject*)ctx_);
   }
 
-  template <CallType C, typename... Args>
+  template <CallType C, EventType E, typename... Args>
   TraceKey intern(Args... args) {
+    static_assert(
+        Config<C>::event_type == E,
+        "ThreadLocalResults.intern called from the wrong typed context.");
     return std::get<C>(trace_keys_)
         .intern(Callsite<C>(std::forward<Args>(args)...), *value_cache_);
   }
@@ -477,15 +441,14 @@
   TraceContext* ctx_;
   ValueCache* value_cache_;
   CallTypeHelper<TraceKeyCacheState>::tuple_type trace_keys_;
-  AppendOnlyList<std::pair<TraceKey, int64_t>, BLOCK_SIZE> enters_;
-  AppendOnlyList<int64_t, BLOCK_SIZE> exit_times_;
-  AppendOnlyList<int64_t, BLOCK_SIZE> c_exit_times_;
+  AppendOnlyList<approx_time_t, BLOCK_SIZE> exit_times_;
+  AppendOnlyList<approx_time_t, BLOCK_SIZE> c_exit_times_;
 };
 
 // ============================================================================
 // == Tracing implementation ==================================================
 // ============================================================================
-class PythonTracer final : public PythonTracerBase {
+class PythonTracer final : public python_tracer::PythonTracerBase {
  public:
   static int pyProfileFn(
       PyObject* obj,
@@ -494,22 +457,22 @@
       PyObject* arg);
 
   static PythonTracer& singleton();
-  void start() override;
+  void start(torch::profiler::impl::RecordQueue* queue) override;
   void stop() override;
-  std::vector<std::unique_ptr<PyTraceEvent>> getEvents() override;
+  std::vector<std::shared_ptr<Result>> getEvents(
+      std::function<time_t(approx_time_t)> time_converter,
+      std::vector<python_tracer::CompressedEvent>& enters) override;
   void clear() override;
 
  private:
   PythonTracer();
-  friend class PyTraceReplay;
 
   void recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame);
   void recordCCall(ThreadLocalResults& tls, PyFrameObject* frame, PyObject* arg);
 
-  bool active_;
+  torch::profiler::impl::RecordQueue* queue_;
   PyObject* module_call_code_;
 
-  // TODO: Move to RecordQueue
   std::deque<ThreadLocalResults> thread_local_results_;
   ValueCache value_cache_;
 };
@@ -520,16 +483,16 @@
 }
 
 PythonTracer::PythonTracer()
-    : active_(false), module_call_code_(nnModuleCode()) {}
+    : queue_(nullptr), module_call_code_(nnModuleCode()) {}
 
-void PythonTracer::start() {
-  TORCH_CHECK(!active_, "PythonTracer is already active")
+void PythonTracer::start(torch::profiler::impl::RecordQueue* queue) {
+  TORCH_CHECK(queue_ == nullptr, "PythonTracer is already active")
   TORCH_CHECK(
       !thread_local_results_.size(),
       "PythonTracer should not have active contexts");
+  queue_ = queue;
 
   pybind11::gil_scoped_acquire gil;
-  auto t0 = now();
 
   // Loop over all threads within the current interpreter. We will need to
   // register a trace function with each thread. We set the current thread to
@@ -580,12 +543,11 @@
 
   // Restore the thread state to its initial value.
   PyThreadState_Swap(thread_states[0]);
-
-  active_ = true;
 };
 
 void PythonTracer::stop() {
-  TORCH_INTERNAL_ASSERT(active_, "PythonTracer is not running.")
+  TORCH_INTERNAL_ASSERT(queue_ != nullptr, "PythonTracer is not running.")
+  queue_ = nullptr;
 
   pybind11::gil_scoped_acquire gil;
 
@@ -595,16 +557,16 @@
     PyEval_SetProfile(nullptr, nullptr);
   }
   PyThreadState_Swap(initial_thread_state);
-  active_ = false;
 }
 
 void PythonTracer::clear() {
-  TORCH_CHECK(!active_, "Cannot clear state while PythonTracer is active.");
+  TORCH_CHECK(queue_ == nullptr, "Cannot clear state while PythonTracer is active.");
   thread_local_results_.clear();
   value_cache_ = ValueCache();
 }
 
 void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) {
+  static constexpr auto E = EventType::PyCall;
   auto get_key = [&]() -> TraceKey {
     if ((PyObject*)(frame->f_code) == module_call_code_) {
       // By default, CPython stores locals in a "fast" format, with an array
@@ -620,14 +582,14 @@
       auto self = PyDict_GetItemString(frame->f_locals, "self");
       PyFrame_LocalsToFast(frame, 0);
       TORCH_INTERNAL_ASSERT(frame->f_back != nullptr);
-      return tls.intern<CallType::PyModuleCall>(self, frame->f_back);
+      return tls.intern<CallType::PyModuleCall, E>(self, frame->f_back);
 
     } else {
       auto f_back = frame->f_back != nullptr ? frame->f_back : frame;
-      return tls.intern<CallType::PyCall>(frame, f_back);
+      return tls.intern<CallType::PyCall, E>(frame, f_back);
     }
   };
-  tls.enters_.emplace_back(get_key(), now());
+  queue_->getSubqueue()->emplace_py_call(get_key(), getApproximateTime());
 }
 
 void PythonTracer::recordCCall(
@@ -636,167 +598,159 @@
     PyObject* arg) {
   // NB: For C calls a new frame is not created, so we use `frame` rather than
   //     `frame->f_back`.
-  tls.enters_.emplace_back(tls.intern<CallType::PyCCall>(arg, frame), now());
+  auto key = tls.intern<CallType::PyCCall, EventType::PyCCall>(arg, frame);
+  queue_->getSubqueue()->emplace_py_call(key, getApproximateTime());
 }
 
 // ============================================================================
 // == Post processing =========================================================
 // ============================================================================
+struct Exit {
+  bool operator>(const Exit& other) const {
+    return t_ > other.t_;
+  }
 
-class PyTraceReplay {
+  time_t t_;
+  size_t python_tid_;
+};
+
+class PostProcess {
  public:
-  static std::vector<std::unique_ptr<PyTraceEvent>> getEvents() {
-    return PyTraceReplay().replayStack();
+  PostProcess(  // Gathers per-thread call metadata and exit times for replay.
+      std::function<time_t(approx_time_t)> time_converter,
+      std::deque<ThreadLocalResults>& tls,
+      const ValueCache& value_cache)
+      : time_converter_{std::move(time_converter)} {  // taken by value: move, don't copy
+    for (size_t python_tid : c10::irange(tls.size())) {  // index in `tls` serves as the python tid
+      CallTypeHelper<TraceKeyCacheState>::map(  // invokes operator() once per CallType cache
+          tls[python_tid].trace_keys_, *this, value_cache, python_tid);
+
+      addExits<EventType::PyCall>(tls[python_tid].exit_times_, python_tid);
+      addExits<EventType::PyCCall>(tls[python_tid].c_exit_times_, python_tid);
+    }
+  }
+
+  template <CallType C>
+  void operator()(
+      const TraceKeyCacheState<C>& trace_cache,
+      const ValueCache& value_cache,
+      size_t python_tid) {
+    for (const auto& it : trace_cache.state_) {
+      const auto inserted = get_state<Config<C>::event_type>().fields_.insert(
+          {it.second, value_cache.load(it.first, python_tid)});
+      TORCH_INTERNAL_ASSERT(inserted.second, "Duplicate key: ", it.second);
+    }
+  }
+
+  template <EventType E, size_t N>
+  void addExits(
+      AppendOnlyList<approx_time_t, N>& exits,
+      size_t python_tid) {
+    for (const auto i : exits) {
+      get_state<E>().exits_.push({time_converter_(i), python_tid});
+    }
+  }
+
+  std::vector<std::shared_ptr<Result>> run(  // Sorts `enters` in place, then replays them per event type.
+      std::vector<python_tracer::CompressedEvent>& enters) {
+    std::stable_sort(  // ascending enter time; equal timestamps keep their input order
+        enters.begin(), enters.end(), [](const auto& a, const auto& b) {
+          return a.enter_t_ < b.enter_t_;
+        });
+    std::vector<std::shared_ptr<Result>> out;
+    populate<EventType::PyCall>(enters, out);  // each pass only consumes enters whose key belongs to its type
+    populate<EventType::PyCCall>(enters, out);
+    return out;  // PyCall results first, then PyCCall; not globally time-ordered
   }
 
  private:
-  PyTraceReplay();
-  std::vector<std::unique_ptr<PyTraceEvent>> replayStack() const;
+  template <EventType E>
+  void populate(  // Replays enters/exits of event type E and appends one Result per enter to `out`.
+      std::vector<python_tracer::CompressedEvent>& enters,
+      std::vector<std::shared_ptr<Result>>& out) {
+    using stack_t = std::vector<std::shared_ptr<Result>>;
+    ska::flat_hash_map<size_t, stack_t> stacks;  // python_tid -> frames that have not been closed yet
+    auto& state = get_state<E>();
+    for (const auto& enter : enters) {
+      auto fields_it = state.fields_.find(enter.key_);
+      if (fields_it != state.fields_.end()) {  // keys not registered for E are skipped
+        while (!state.exits_.empty() &&
+               state.exits_.top().t_ < enter.enter_t_) {  // earliest exit precedes this enter
+          auto& stack = stacks[state.exits_.top().python_tid_];
+          TORCH_INTERNAL_ASSERT(stack.size(), "Python replay stack is empty.");
+          c10::get<ExtraFields<E>>(stack.back()->extra_fields_).end_time_ns_ =
+              state.exits_.top().t_;  // the exit closes the innermost open frame of its thread
+          state.exits_.pop();
+          stack.pop_back();
+        }
+        out.push_back(Result::create(
+            enter.enter_t_,
+            enter.system_tid_,
+            enter.kineto_info_,
+            fields_it->second));
 
-  struct RawEvent {
-    int64_t t_;
-    size_t thread_id_;
-    TraceKey key_;
-    int what_;  // cPython uses integers to tag event types.
+        stacks[fields_it->second.python_tid_].push_back(out.back());
+      }
+    }  // NOTE(review): exits not older than the last enter are never drained; frames left open keep their initial end_time_ns_ - confirm this is handled downstream
+  }
+
+  template <EventType E>
+  struct State {
+    ska::flat_hash_map<TraceKey, ExtraFields<E>> fields_;
+    std::priority_queue<Exit, std::vector<Exit>, std::greater<Exit>> exits_;
   };
 
-  struct ReplayFrame {
-    std::unique_ptr<PyTraceEvent> event_;
-    size_t id_;
-    size_t parent_id_;
-  };
+  template <EventType E>
+  auto& get_state() {
+    return std::get<E == EventType::PyCall ? 0 : 1>(state_);
+  }
 
-  ska::flat_hash_map<TraceKey, FrameArgs> frame_args_;
-  std::vector<RawEvent> raw_events_;
+  std::function<time_t(approx_time_t)> time_converter_;
+  std::tuple<State<EventType::PyCall>, State<EventType::PyCCall>> state_;
 };
 
-PyTraceReplay::PyTraceReplay() {
-  auto& tracer = PythonTracer::singleton();
-  tracer.value_cache_.trimPrefixes();
-
-  ska::flat_hash_map<PyModuleCallKey, size_t> self_to_id;
-  ska::flat_hash_map<PyModuleCls, size_t> cls_id_counter;
-
-  for (auto& local_results : tracer.thread_local_results_) {
-    auto f = [&](auto& cache) {
-      for (const auto& it : cache.state_) {
-        auto frame = tracer.value_cache_.load(it.first);
-        if (frame.module_.has_value()) {
-          auto id_it = self_to_id.find(frame.module_->first);
-          if (id_it == self_to_id.end()) {
-            auto id = cls_id_counter[frame.module_->second]++;
-            id_it = self_to_id.insert({frame.module_->first, id}).first;
-          }
-          frame.module_id_ = id_it->second;
-        }
-        auto inserted = frame_args_.insert({it.second, frame});
-        TORCH_INTERNAL_ASSERT(inserted.second);
-      }
-    };
-    CallTypeHelper<TraceKeyCacheState>::map(local_results.trace_keys_, f);
-  }
-
-  for (const auto py_tid : c10::irange(tracer.thread_local_results_.size())) {
-    auto& local_results = tracer.thread_local_results_[py_tid];
-    for (const auto& i : local_results.exit_times_) {
-      raw_events_.push_back({i, py_tid, TraceKey(), PyTrace_RETURN});
-    }
-    for (const auto& i : local_results.c_exit_times_) {
-      raw_events_.push_back({i, py_tid, TraceKey(), PyTrace_C_RETURN});
-    }
-
-    for (const auto& it : local_results.enters_) {
-      auto call_type = frame_args_.at(it.first).call_type_;
-      auto what =
-          call_type == CallType::PyCCall ? PyTrace_C_CALL : PyTrace_CALL;
-      raw_events_.push_back({it.second, py_tid, it.first, what});
+struct PythonIDVisitor {  // Variant visitor that numbers Python events and nn.Module instances.
+  void operator()(ExtraFields<EventType::PyCall>& py_call) {
+    py_call.id_ = ++current_python_id_;  // pre-increment from 0, so the first id is 1
+    if (py_call.module_.has_value()) {  // module calls additionally get a per-class instance index
+      auto& m = py_call.module_;
+      auto& module_ids = module_ids_[m->cls_];  // instance -> index map for this class
+      m->id_ = module_ids.insert({m->self_, module_ids.size()}).first->second;
     }
   }
+
+  void operator()(ExtraFields<EventType::PyCCall>& py_call) {
+    py_call.id_ = ++current_python_id_;  // shares the counter with PyCall events
+  }
+
+  template <typename T>
+  void operator()(T&) {}  // every other alternative is left untouched
+
+  size_t current_python_id_{0};
+  ska::flat_hash_map<PyModuleCls, ska::flat_hash_map<PyModuleSelf, size_t>>
+      module_ids_;  // class -> (instance -> index in order of first visit)
+};
+
+std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
+    std::function<time_t(approx_time_t)> time_converter,
+    std::vector<python_tracer::CompressedEvent>& enters) {
+  value_cache_.trimPrefixes();
+  PostProcess post_process(time_converter, thread_local_results_, value_cache_);
+  auto out = post_process.run(enters);
+
   std::stable_sort(
-      raw_events_.begin(), raw_events_.end(), [](const auto& a, const auto& b) {
-        return a.t_ < b.t_;
-      });
-}
+    out.begin(), out.end(), [](const auto& a, const auto& b) {
+      return a->start_time_ns_ < b->start_time_ns_;
+    });
 
-std::vector<std::unique_ptr<PyTraceEvent>> PyTraceReplay::replayStack() const {
-  auto& tracer = PythonTracer::singleton();
-  size_t id_counter = 0;
-  std::vector<std::vector<ReplayFrame>> stacks(tracer.thread_local_results_.size());
-  std::vector<ReplayFrame> results;
-
-  // Match calls and returns.
-  size_t event_idx = 0;
-  for (auto& raw_event : raw_events_) {
-    auto& stack = stacks[raw_event.thread_id_];
-    auto push_frame =
-        [&]() {
-          auto& args = frame_args_.at(raw_event.key_);
-          stack.push_back(ReplayFrame{
-              /*event_=*/std::make_unique<PyTraceEvent>(PyTraceEvent{
-                  /*startTime_=*/raw_event.t_,
-                  /*endTime_=*/-1, // Placeholder
-                  /*name_=*/args.name_,
-                  /*thread_id_=*/raw_event.thread_id_,
-                  /*parent_=*/nullptr, // Placeholder
-                  /*module_id_=*/args.module_id_,
-                  /*call_idx_=*/event_idx,
-                  /*return_idx_=*/0 // Placeholder
-              }),
-              /*id_=*/id_counter++,
-              /*parent_id_=*/stack.size() ? stack.back().id_ : 0,
-          });
-        };
-
-    switch (raw_event.what_) {
-      case PyTrace_CALL:
-      case PyTrace_C_CALL:
-        push_frame();
-        break;
-
-      case PyTrace_RETURN:
-      case PyTrace_C_RETURN:
-        TORCH_INTERNAL_ASSERT(stack.size(), "Python replay stack is empty.")
-        stack.back().event_->endTime_ = raw_event.t_;
-        stack.back().event_->return_idx_ = event_idx;
-        results.push_back(std::move(stack.back()));
-        stack.pop_back();
-        break;
-    }
-    event_idx++;
+  PythonIDVisitor id_visitor;
+  for (auto& i : out) {
+    c10::visit(id_visitor, i->extra_fields_);
   }
 
-  // Cleanup by feining return to close out the stack. This is needed so
-  // frames above the one that called the profiler still appear in the trace.
-  const auto t_final = now();
-  for (auto& stack : stacks) {
-    while (stack.size()) {
-      stack.back().event_->endTime_ = t_final;
-      stack.back().event_->return_idx_ = event_idx;
-      results.push_back(std::move(stack.back()));
-      stack.pop_back();
-      event_idx++;
-    }
-  }
-
-  // Convert to `PyTraceEvent`, and map id to pointer.
-  ska::flat_hash_map<size_t, PyTraceEvent*> event_id_map{{0, nullptr}};
-  std::vector<std::unique_ptr<PyTraceEvent>> out;
-  for (auto& r : results) {
-    out.push_back(std::move(r.event_));
-    event_id_map.insert({r.id_, out.back().get()});
-  }
-
-  // Link parents to children.
-  for (const auto i : c10::irange(results.size())) {
-    out[i]->parent_ = event_id_map.at(results[i].parent_id_);
-  }
   return out;
 }
 
-std::vector<std::unique_ptr<PyTraceEvent>> PythonTracer::getEvents() {
-  return PyTraceReplay::getEvents();
-}
-
 // ============================================================================
 // == API =====================================================================
 // ============================================================================
@@ -818,26 +772,35 @@
 
     case PyTrace_EXCEPTION:
     case PyTrace_RETURN:
-      local_results.exit_times_.emplace_back(now());
+      local_results.exit_times_.emplace_back(getApproximateTime());
       break;
 
     case PyTrace_C_EXCEPTION:
     case PyTrace_C_RETURN:
-      local_results.c_exit_times_.emplace_back(now());
+      local_results.c_exit_times_.emplace_back(getApproximateTime());
       break;
   }
   return 0;
 }
 
-PythonTracerBase& getTracer() {
+python_tracer::PythonTracerBase& getTracer() {
   return PythonTracer::singleton();
 }
 } // namespace
+} // namespace impl
+} // namespace profiler
+} // namespace torch
+
+namespace torch {
+namespace autograd {
+namespace profiler {
+namespace python_tracer {
 
 void init() {
   pybind11::gil_scoped_acquire gil;
-  TORCH_CHECK(PyType_Ready(&TraceContextType) == 0);
-  torch::profiler::impl::python_tracer::registerTracer(&getTracer);
+  TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0);
+  torch::profiler::impl::python_tracer::registerTracer(
+      &torch::profiler::impl::getTracer);
 }
 } // namespace python_tracer
 } // namespace profiler
diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp
index e26745a..9eec2c9 100644
--- a/torch/csrc/profiler/collection.cpp
+++ b/torch/csrc/profiler/collection.cpp
@@ -3,6 +3,8 @@
 #include <algorithm>
 #include <queue>
 
+#include <fmt/format.h>
+
 #include <ATen/record_function.h>
 #include <c10/core/ScalarTypeToTypeMeta.h>
 #include <c10/util/flat_hash_map.h>
@@ -138,10 +140,12 @@
     static NoOpPythonTracer singleton_;
     return singleton_;
   }
-  void start() override {}
+  void start(RecordQueue*) override {}
   void stop() override {}
   void clear() override {}
-  std::vector<std::unique_ptr<PyTraceEvent>> getEvents() override {
+  std::vector<std::shared_ptr<Result>> getEvents(
+      std::function<time_t(approx_time_t)>,
+      std::vector<CompressedEvent>&) override {
     return {};
   }
   ~NoOpPythonTracer() = default;
@@ -162,7 +166,12 @@
 
 #define OUT_T(method_name) decltype(std::declval<Result>().method_name())
 #define DEFINE_VISITOR(                                                 \
-    method_name, torch_op_field, backend_field, allocation_field)       \
+    method_name,                                                        \
+    torch_op_field,                                                     \
+    backend_field,                                                      \
+    allocation_field,                                                   \
+    py_field,                                                           \
+    py_c_field)                                                         \
   OUT_T(method_name) Result::method_name() const {                      \
     using out_t = OUT_T(method_name);                                   \
     return c10::visit(                                                  \
@@ -178,10 +187,30 @@
             [&](const ExtraFields<EventType::Allocation>& e) -> out_t { \
               (void)e;                                                  \
               return allocation_field;                                  \
+            },                                                          \
+            [&](const ExtraFields<EventType::PyCall>& e) -> out_t {     \
+              (void)e;                                                  \
+              return py_field;                                          \
+            },                                                          \
+            [&](const ExtraFields<EventType::PyCCall>& e) -> out_t {    \
+              (void)e;                                                  \
+              return py_c_field;                                        \
             }),                                                         \
         extra_fields_);                                                 \
   }
 
+std::string toString(const ExtraFields<EventType::PyCall>& e) {
+  if (e.module_.has_value()) {
+    return fmt::format(
+        "nn.Module: {}_{}", e.module_->cls_name_.str(), e.module_->id_);
+  }
+  return fmt::format(
+      "{}({}): {}",
+      e.callsite_.filename_.str(),
+      e.callsite_.line_no_,
+      e.callsite_.funcname_.str());
+}
+
 using torch::profiler::impl::kineto::KinetoActivityType;
 namespace {
 KinetoActivityType scopeToType(at::RecordScope scope) {
@@ -191,20 +220,42 @@
 }
 } // namespace
 
-DEFINE_VISITOR(name, e.name_, e.name_, "[memory]");
+DEFINE_VISITOR(
+    name,
+    e.name_,
+    e.name_,
+    "[memory]",
+    toString(e),
+    e.function_name_.str());
 DEFINE_VISITOR(
     kinetoType,
     scopeToType(e.scope_),
     scopeToType(e.scope_),
-    KinetoActivityType::CPU_INSTANT_EVENT);
-DEFINE_VISITOR(correlationID, e.correlation_id_, 0, 0);
-DEFINE_VISITOR(endTimeNS, e.end_time_ns_, e.end_time_us_ * 1000, start_time_ns_);
-DEFINE_VISITOR(endTID, e.end_tid_, start_tid_, start_tid_);
+    KinetoActivityType::CPU_INSTANT_EVENT,
+    KinetoActivityType::PYTHON_FUNCTION,
+    KinetoActivityType::PYTHON_FUNCTION);
+DEFINE_VISITOR(correlationID, e.correlation_id_, 0, 0, 0, 0);
+DEFINE_VISITOR(
+    endTimeNS,
+    e.end_time_ns_,
+    e.end_time_us_ * 1000,
+    start_time_ns_,
+    e.end_time_ns_,
+    e.end_time_ns_);
+DEFINE_VISITOR(
+    endTID,
+    e.end_tid_,
+    start_tid_,
+    start_tid_,
+    start_tid_,
+    start_tid_);
 DEFINE_VISITOR(
     deviceType,
     c10::DeviceType::CPU,
     c10::DeviceType::CPU,
-    e.device_type_);
+    e.device_type_,
+    c10::DeviceType::CPU,
+    c10::DeviceType::CPU);
 #undef DEFINE_VISITOR
 #undef OUT_T
 
@@ -262,8 +313,18 @@
   return out;
 }
 
-RecordQueue::RecordQueue(const ProfilerConfig& config)
-    : id_(++queue_id_), config_{config} {}
+RecordQueue::RecordQueue(
+    const ProfilerConfig& config,
+    std::set<ActivityType> activities)
+    : id_(++queue_id_), config_{config}, activities_{activities} {
+  if (tracePython()) {
+    python_tracer::PythonTracerBase::get().start(this);
+  }
+}
+
+bool RecordQueue::tracePython() const {
+  return config_.with_stack && activities_.count(ActivityType::CPU);
+}
 
 ThreadLocalSubqueue* RecordQueue::getSubqueue() {
   // In the most common case, a thread will want to write to the same sub-queue
@@ -290,6 +351,12 @@
   return it->second.get();
 }
 
+void RecordQueue::stop() {
+  if (tracePython()) {
+    python_tracer::PythonTracerBase::get().stop();
+  }
+}
+
 namespace {
 template <typename T>
 auto steal_or_default(T& it) {
@@ -433,6 +500,7 @@
         : time_converter(t);
   };
   std::vector<std::shared_ptr<Result>> out;
+  std::vector<python_tracer::CompressedEvent> python_enters;
   for (auto& subqueue_it : sub_queues_) {
     auto& queue = *subqueue_it.second;
     for (auto& i : queue.backend_events_) {
@@ -482,6 +550,19 @@
           /*extra_fields_=*/std::move(i)));
     }
     queue.allocations_.clear();
+
+    for (auto& i : queue.py_calls_) {
+      python_enters.push_back(
+          {i.first, queue.tid(), queue.kineto_info(), converter(i.second)});
+    }
+  }
+
+  if (tracePython()) {
+    auto& tracer = python_tracer::PythonTracerBase::get();
+    for (auto i : tracer.getEvents(converter, python_enters)) {
+      out.push_back(i);
+    }
+    tracer.clear();
   }
 
   build_tree(out);
diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h
index 68aa837..9509c65 100644
--- a/torch/csrc/profiler/collection.h
+++ b/torch/csrc/profiler/collection.h
@@ -10,10 +10,12 @@
 #include <c10/core/DeviceType.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/flat_hash_map.h>
+#include <c10/util/strong_type.h>
 #include <c10/util/variant.h>
 #include <torch/csrc/profiler/containers.h>
 #include <torch/csrc/profiler/kineto_shim.h>
 #include <torch/csrc/profiler/util.h>
+#include <torch/csrc/utils/python_stub.h>
 
 namespace torch {
 namespace profiler {
@@ -22,7 +24,9 @@
 enum class EventType : uint8_t {
   TorchOp = 0,
   Backend,
-  Allocation
+  Allocation,
+  PyCall,
+  PyCCall
 };
 
 template <EventType>
@@ -109,6 +113,73 @@
     std::is_pod<ExtraFields<EventType::Allocation>>::value,
     "Non-POD member of ExtraFields<EventType::Allocation>.");
 
+struct PyFrameState { // Line number, file name and function name of a Python frame.
+  int line_no_;
+  at::StringView filename_;
+  at::StringView funcname_;
+};
+
+template <typename T, typename Tag> // Tagged wrapper around T; convertible back to T and hashable.
+using strong_t = strong::
+    type<T, Tag, strong::regular, strong::convertible_to<T>, strong::hashable>;
+// Each alias below has its own tag, so they are distinct types even though all wrap a PyObject*.
+using PyModuleSelf = strong_t<PyObject*, struct PyModuleSelf_>;
+using PyModuleCls = strong_t<PyObject*, struct PyModuleCls_>;
+using PyCFunction = strong_t<PyObject*, struct PyCFunction_>;
+
+struct NNModuleInfo { // Module instance, its class, and the class name.
+  PyModuleSelf self_;
+  PyModuleCls cls_;
+  at::StringView cls_name_;
+
+  // Indicates that `self_` is the kth instance of `cls_` observed.
+  size_t id_{std::numeric_limits<size_t>::max()}; // Defaults to max size_t; presumably "not yet assigned" (confirm).
+};
+
+struct PyExtraFieldsBase { // Fields shared by the PyCall and PyCCall payloads.
+  PyExtraFieldsBase(time_t end_time_ns, size_t python_tid, PyFrameState caller)
+      : end_time_ns_{end_time_ns}, python_tid_{python_tid}, caller_{caller} {}
+
+  time_t end_time_ns_;
+  size_t python_tid_;
+  PyFrameState caller_;
+
+  // kth python event observed. (Used by TensorBoard)
+  size_t id_{std::numeric_limits<size_t>::max()}; // Defaults to max size_t; presumably "not yet assigned" (confirm).
+};
+
+template <> // Payload for EventType::PyCall events.
+struct ExtraFields<EventType::PyCall> : public PyExtraFieldsBase {
+  using args_t = std::pair<PyFrameState, c10::optional<NNModuleInfo>>;
+
+  ExtraFields(
+      time_t end_time_ns,
+      size_t python_tid,
+      PyFrameState caller,
+      args_t args)
+      : PyExtraFieldsBase(end_time_ns, python_tid, caller),
+        callsite_{args.first}, // args.first: the callsite frame.
+        module_{args.second} {} // args.second: optional module info.
+
+  PyFrameState callsite_;
+  c10::optional<NNModuleInfo> module_;
+};
+
+template <> // Payload for EventType::PyCCall events.
+struct ExtraFields<EventType::PyCCall> : public PyExtraFieldsBase {
+  using args_t = at::StringView; // The only extra argument is the function name.
+
+  ExtraFields(
+      time_t end_time_ns,
+      size_t python_tid,
+      PyFrameState caller,
+      args_t args)
+      : PyExtraFieldsBase(end_time_ns, python_tid, caller),
+        function_name_{args} {}
+
+  at::StringView function_name_;
+};
+
 struct TORCH_API Result : public std::enable_shared_from_this<Result> {
   template <typename... Args>
   [[nodiscard]] static std::shared_ptr<Result> create(Args... args) {
@@ -128,7 +199,9 @@
   c10::variant<
       ExtraFields<EventType::TorchOp>,
       ExtraFields<EventType::Backend>,
-      ExtraFields<EventType::Allocation>>
+      ExtraFields<EventType::Allocation>,
+      ExtraFields<EventType::PyCall>,
+      ExtraFields<EventType::PyCCall>>
       extra_fields_;
 
   std::weak_ptr<Result> parent_;
@@ -201,6 +274,7 @@
   AppendOnlyList<int64_t, IO_ENCODER_DEFAULT_BLOCK_SIZE> tensor_sizes_;
 };
 
+class RecordQueue;
 namespace python_tracer {
 /*
 Libtorch does not depend on Python (e.g. cannot #include <Python.h>); however
@@ -214,29 +288,29 @@
 in the PyTorch codebase.
 */
 
-struct TORCH_API PyTraceEvent {
-  int64_t startTime_;
-  int64_t endTime_;
-  std::string name_;
+using TraceKey = strong::type< // Strongly typed uint64_t key stored with each Python call.
+    uint64_t,
+    struct TraceKey_,
+    strong::regular,
+    strong::hashable,
+    strong::ostreamable>;
 
-  uint64_t thread_id_;
-  PyTraceEvent* parent_;
-  c10::optional<size_t> module_id_;
-
-  // Index in the list of raw call and return events. This allows one to
-  // convert a vector of PyTraceEvents back into the constituent call and
-  // return events, even when events share the same timestamp.
-  size_t call_idx_;
-  size_t return_idx_;
+struct CompressedEvent { // One recorded Python call, as handed to PythonTracerBase::getEvents.
+  TraceKey key_; // Key stored for the call in the subqueue's py_calls_.
+  uint64_t system_tid_; // tid() of the subqueue that recorded the call.
+  kineto::DeviceAndResource kineto_info_; // kineto_info() of that subqueue.
+  time_t enter_t_; // Call time after conversion from approx_time_t.
 };
 
 struct TORCH_API PythonTracerBase {
   static PythonTracerBase& get();
   virtual ~PythonTracerBase() = default;
 
-  virtual void start() = 0;
+  virtual void start(RecordQueue* queue) = 0; // NOTE(review): presumably the queue that receives call records; confirm in the implementation.
   virtual void stop() = 0;
-  virtual std::vector<std::unique_ptr<PyTraceEvent>> getEvents() = 0;
+  virtual std::vector<std::shared_ptr<Result>> getEvents( // Results are appended to the output of RecordQueue::getRecords.
+      std::function<time_t(approx_time_t)> time_converter,
+      std::vector<CompressedEvent>& enters) = 0; // Non-const reference: implementations may modify `enters`.
   virtual void clear() = 0;
 };
 
@@ -260,6 +334,11 @@
     allocations_.emplace_back(std::forward<Args>(args)...);
   }
 
+  template <class... Args>
+  void emplace_py_call(Args&&... args) { // Constructs a new py_calls_ entry in place from args.
+    py_calls_.emplace_back(std::forward<Args>(args)...);
+  }
+
   uint64_t tid() const {
     return tid_;
   }
@@ -283,6 +362,7 @@
 
   // with_stack
   AppendOnlyList<jit_stack_t, BlockSize> jit_stack_;
+  AppendOnlyList<std::pair<python_tracer::TraceKey, approx_time_t>, BlockSize> py_calls_;
 
   // with_modules
   AppendOnlyList<jit_modules_t, BlockSize> jit_modules_;
@@ -302,9 +382,11 @@
 
 class TORCH_API RecordQueue {
  public:
-  explicit RecordQueue(const ProfilerConfig& config);
+  RecordQueue(const ProfilerConfig& config, std::set<ActivityType> activities);
 
+  bool tracePython() const; // True iff with_stack is set and the CPU activity is present.
   ThreadLocalSubqueue* getSubqueue();
+  void stop(); // Stops the Python tracer when tracePython() is true.
 
   // NB: This is a destructive operation.
   std::vector<std::shared_ptr<Result>> getRecords(
@@ -313,6 +395,7 @@
  private:
   uint32_t id_;
   ProfilerConfig config_;
+  std::set<ActivityType> activities_;
   ska::flat_hash_map<uint64_t, std::unique_ptr<ThreadLocalSubqueue>> sub_queues_;
   std::mutex sub_queue_mutex_;
 };
diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp
index df2098f..38fb008 100644
--- a/torch/csrc/profiler/kineto_shim.cpp
+++ b/torch/csrc/profiler/kineto_shim.cpp
@@ -68,6 +68,8 @@
       return libkineto::ActivityType::CPU_OP;
     case KinetoActivityType::CPU_INSTANT_EVENT:
       return libkineto::ActivityType::CPU_INSTANT_EVENT;
+    case KinetoActivityType::PYTHON_FUNCTION:
+      return libkineto::ActivityType::PYTHON_FUNCTION;
     default:
       TORCH_INTERNAL_ASSERT(
           type == KinetoActivityType::USER_ANNOTATION,
diff --git a/torch/csrc/profiler/kineto_shim.h b/torch/csrc/profiler/kineto_shim.h
index 59ff529..ce3b33d 100644
--- a/torch/csrc/profiler/kineto_shim.h
+++ b/torch/csrc/profiler/kineto_shim.h
@@ -62,7 +62,8 @@
 enum class KinetoActivityType : uint8_t {
   CPU_OP = 0,
   CPU_INSTANT_EVENT,
-  USER_ANNOTATION
+  USER_ANNOTATION,
+  PYTHON_FUNCTION // Mapped to libkineto::ActivityType::PYTHON_FUNCTION in kineto_shim.cpp.
 };
 
 using annotation_t = std::vector<std::pair<std::string, std::string>>;
diff --git a/torch/profiler/python_tracer.py b/torch/profiler/python_tracer.py
index 73e5dba..f803b64 100644
--- a/torch/profiler/python_tracer.py
+++ b/torch/profiler/python_tracer.py
@@ -15,6 +15,6 @@
         [os.path.dirname(os.path.dirname(torch.__file__))]
     )
 
-    path_prefixes = sorted({os.path.abspath(i) for i in raw_paths})
+    path_prefixes = sorted({os.path.abspath(i) for i in raw_paths}, reverse=True)
     assert all(isinstance(i, str) for i in path_prefixes)
     return [i + os.sep for i in path_prefixes]