[Profiler] Memory profiler part 9: Mark activations (#88924)

This is a fairly straightforward pass: start at inputs and flood fill until we reach the backward pass.

Differential Revision: [D40868662](https://our.internmc.facebook.com/intern/diff/D40868662/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88924
Approved by: https://github.com/chaekit
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index 6819109..6e42f33 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -1189,16 +1189,16 @@
             aten::ones                                                                             -> 2 (INPUT)
 
             -- Forward & loss ---------------------------------------------------------------------------------------
-            aten::mul.Tensor                         1 (INPUT), 3 (PARAMETER)                      -> 4 (???)
-            aten::mul.Tensor                         1 (INPUT), 5 (PARAMETER)                      -> 6 (???)
-            aten::cat                                4 (???), 6 (???)                              -> 7 (???)
-            aten::binary_cross_entropy_with_logits   7 (???), 2 (INPUT)                            -> 13 (???)
+            aten::mul.Tensor                         1 (INPUT), 3 (PARAMETER)                      -> 4 (ACTIVATION)
+            aten::mul.Tensor                         1 (INPUT), 5 (PARAMETER)                      -> 6 (ACTIVATION)
+            aten::cat                                4 (ACTIVATION), 6 (ACTIVATION)                -> 7 (ACTIVATION)
+            aten::binary_cross_entropy_with_logits   7 (ACTIVATION), 2 (INPUT)                     -> 13 (ACTIVATION)
 
             -- Backward ---------------------------------------------------------------------------------------------
-            aten::ones_like                          13 (???)                                      -> 16 (???)
-            aten::sigmoid                            7 (???)                                       -> 17 (TEMPORARY)
+            aten::ones_like                          13 (ACTIVATION)                               -> 16 (ACTIVATION)
+            aten::sigmoid                            7 (ACTIVATION)                                -> 17 (TEMPORARY)
             aten::sub.Tensor                         17 (TEMPORARY), 2 (INPUT)                     -> 18 (TEMPORARY)
-            aten::mul.Tensor                         18 (TEMPORARY), 16 (???)                      -> 19 (???)
+            aten::mul.Tensor                         18 (TEMPORARY), 16 (ACTIVATION)               -> 19 (???)
             aten::div_.Scalar                        19 (???)                                      -> 19 (???)
             aten::slice.Tensor                       19 (???)                                      -> 19 (???)
             aten::slice.Tensor                       19 (???)                                      -> 19 (???)
@@ -1227,7 +1227,7 @@
             """\
             aten::ones                                                                             -> 1 (INPUT)
             aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
-            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (???)""",
+            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)""",
         )
 
     def test_categories_e2e_simple_module_fwd_bwd(self) -> None:
@@ -1247,16 +1247,16 @@
             -- Forward & loss ---------------------------------------------------------------------------------------
             aten::ones                                                                             -> 1 (INPUT)
             aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
-            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (???)
-            aten::sum                                4 (???)                                       -> 5 (???)
+            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
+            aten::sum                                4 (ACTIVATION)                                -> 5 (ACTIVATION)
 
             -- Backward ---------------------------------------------------------------------------------------------
-            aten::ones_like                          5 (???)                                       -> 6 (???)
-            aten::expand                             6 (???)                                       -> 6 (???)
-            aten::t                                  6 (???)                                       -> 6 (???)
-            aten::mm                                 6 (???), 1 (INPUT)                            -> 7 (GRADIENT)
+            aten::ones_like                          5 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::expand                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::t                                  6 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::mm                                 6 (ACTIVATION), 1 (INPUT)                     -> 7 (GRADIENT)
             aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
-            aten::sum.dim_IntList                    6 (???)                                       -> 9 (GRADIENT)
+            aten::sum.dim_IntList                    6 (ACTIVATION)                                -> 9 (GRADIENT)
             aten::view                               9 (GRADIENT)                                  -> 9 (GRADIENT)
             aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
             aten::detach                             9 (GRADIENT)                                  -> ???
@@ -1287,16 +1287,16 @@
             -- Forward & loss ---------------------------------------------------------------------------------------
             aten::ones                                                                             -> 1 (INPUT)
             aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
-            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (???)
-            aten::sum                                4 (???)                                       -> 5 (???)
+            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
+            aten::sum                                4 (ACTIVATION)                                -> 5 (ACTIVATION)
 
             -- Backward ---------------------------------------------------------------------------------------------
-            aten::ones_like                          5 (???)                                       -> 6 (???)
-            aten::expand                             6 (???)                                       -> 6 (???)
-            aten::t                                  6 (???)                                       -> 6 (???)
-            aten::mm                                 6 (???), 1 (INPUT)                            -> 7 (GRADIENT)
+            aten::ones_like                          5 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::expand                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::t                                  6 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::mm                                 6 (ACTIVATION), 1 (INPUT)                     -> 7 (GRADIENT)
             aten::t                                  7 (GRADIENT)                                  -> 7 (GRADIENT)
-            aten::sum.dim_IntList                    6 (???)                                       -> 9 (GRADIENT)
+            aten::sum.dim_IntList                    6 (ACTIVATION)                                -> 9 (GRADIENT)
             aten::view                               9 (GRADIENT)                                  -> 9 (GRADIENT)
             aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
             aten::detach                             9 (GRADIENT)                                  -> 9 (GRADIENT)
@@ -1329,13 +1329,13 @@
             """\
             aten::ones                                                                             -> 1 (INPUT)
             aten::t                                  2 (PARAMETER)                                 -> 2 (PARAMETER)
-            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (???)
-            aten::relu                               4 (???)                                       -> 5 (???)
-            aten::detach                             5 (???)                                       -> ???
+            aten::addmm                              3 (PARAMETER), 1 (INPUT), 2 (PARAMETER)       -> 4 (ACTIVATION)
+            aten::relu                               4 (ACTIVATION)                                -> 5 (ACTIVATION)
+            aten::detach                             5 (ACTIVATION)                                -> ???
             aten::t                                  6 (PARAMETER)                                 -> 6 (PARAMETER)
-            aten::mm                                 5 (???), 6 (PARAMETER)                        -> 7 (???)
-            aten::_softmax                           7 (???)                                       -> 8 (???)
-            aten::detach                             8 (???)                                       -> ???""",
+            aten::mm                                 5 (ACTIVATION), 6 (PARAMETER)                 -> 7 (ACTIVATION)
+            aten::_softmax                           7 (ACTIVATION)                                -> 8 (ACTIVATION)
+            aten::detach                             8 (ACTIVATION)                                -> ???""",
         )
 
     def test_categories_e2e_sequential_fwd_bwd(self) -> None:
@@ -1367,40 +1367,40 @@
 
             -- Forward ----------------------------------------------------------------------------------------------
             aten::t                                  3 (PARAMETER)                                 -> 3 (PARAMETER)
-            aten::addmm                              4 (PARAMETER), 1 (INPUT), 3 (PARAMETER)       -> 5 (???)
-            aten::relu                               5 (???)                                       -> 6 (???)
-            aten::detach                             6 (???)                                       -> 6 (???)
+            aten::addmm                              4 (PARAMETER), 1 (INPUT), 3 (PARAMETER)       -> 5 (ACTIVATION)
+            aten::relu                               5 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::detach                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
             aten::t                                  7 (PARAMETER)                                 -> 7 (PARAMETER)
-            aten::mm                                 6 (???), 7 (PARAMETER)                        -> 8 (???)
-            aten::_softmax                           8 (???)                                       -> 9 (???)
-            aten::detach                             9 (???)                                       -> 9 (???)
+            aten::mm                                 6 (ACTIVATION), 7 (PARAMETER)                 -> 8 (ACTIVATION)
+            aten::_softmax                           8 (ACTIVATION)                                -> 9 (ACTIVATION)
+            aten::detach                             9 (ACTIVATION)                                -> 9 (ACTIVATION)
 
             -- Loss -------------------------------------------------------------------------------------------------
-            aten::sub.Tensor                         9 (???), 2 (INPUT)                            -> 10 (???)
-            aten::pow.Tensor_Scalar                  10 (???)                                      -> 11 (???)
-            aten::sum                                11 (???)                                      -> 12 (???)
-            aten::mean                               12 (???)                                      -> 13 (???)
+            aten::sub.Tensor                         9 (ACTIVATION), 2 (INPUT)                     -> 10 (ACTIVATION)
+            aten::pow.Tensor_Scalar                  10 (ACTIVATION)                               -> 11 (ACTIVATION)
+            aten::sum                                11 (ACTIVATION)                               -> 12 (ACTIVATION)
+            aten::mean                               12 (ACTIVATION)                               -> 13 (ACTIVATION)
 
             -- Backward ---------------------------------------------------------------------------------------------
-            aten::ones_like                          13 (???)                                      -> 16 (???)
-            aten::expand                             16 (???)                                      -> 16 (???)
-            aten::div.Scalar                         16 (???)                                      -> 19 (???)
+            aten::ones_like                          13 (ACTIVATION)                               -> 16 (ACTIVATION)
+            aten::expand                             16 (ACTIVATION)                               -> 16 (ACTIVATION)
+            aten::div.Scalar                         16 (ACTIVATION)                               -> 19 (???)
             aten::expand                             19 (???)                                      -> 19 (???)
-            aten::pow.Tensor_Scalar                  10 (???)                                      -> 20 (TEMPORARY)
+            aten::pow.Tensor_Scalar                  10 (ACTIVATION)                               -> 20 (TEMPORARY)
             aten::mul.Scalar                         20 (TEMPORARY)                                -> 23 (TEMPORARY)
             aten::mul.Tensor                         19 (???), 23 (TEMPORARY)                      -> 24 (???)
-            aten::detach                             9 (???)                                       -> 9 (???)
-            aten::_softmax_backward_data             24 (???), 9 (???)                             -> 25 (???)
+            aten::detach                             9 (ACTIVATION)                                -> 9 (ACTIVATION)
+            aten::_softmax_backward_data             24 (???), 9 (ACTIVATION)                      -> 25 (???)
             aten::t                                  25 (???)                                      -> 25 (???)
-            aten::mm                                 25 (???), 6 (???)                             -> 26 (GRADIENT)
+            aten::mm                                 25 (???), 6 (ACTIVATION)                      -> 26 (GRADIENT)
             aten::t                                  26 (GRADIENT)                                 -> 26 (GRADIENT)
             aten::t                                  7 (PARAMETER)                                 -> 7 (PARAMETER)
             aten::mm                                 25 (???), 7 (PARAMETER)                       -> 27 (???)
             aten::t                                  26 (GRADIENT)                                 -> 26 (GRADIENT)
             aten::detach                             26 (GRADIENT)                                 -> 26 (GRADIENT)
             aten::detach                             26 (GRADIENT)                                 -> ???
-            aten::detach                             6 (???)                                       -> 6 (???)
-            aten::threshold_backward                 27 (???), 6 (???)                             -> 28 (???)
+            aten::detach                             6 (ACTIVATION)                                -> 6 (ACTIVATION)
+            aten::threshold_backward                 27 (???), 6 (ACTIVATION)                      -> 28 (???)
             aten::t                                  28 (???)                                      -> 28 (???)
             aten::mm                                 28 (???), 1 (INPUT)                           -> 29 (GRADIENT)
             aten::t                                  29 (GRADIENT)                                 -> 29 (GRADIENT)
diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py
index 25107c5..cf06345 100644
--- a/torch/profiler/_memory_profiler.py
+++ b/torch/profiler/_memory_profiler.py
@@ -33,6 +33,7 @@
 class Category(enum.Enum):
     INPUT = enum.auto()
     TEMPORARY = enum.auto()
+    ACTIVATION = enum.auto()
     GRADIENT = enum.auto()
     PARAMETER = enum.auto()
 
@@ -559,6 +560,7 @@
         self._set_parameters_using_python_tracer()
         self._set_inputs()
         self._set_parameters_using_data_flow()
+        self._set_activations()
 
     def _is_gradient(self, *args, **kwargs) -> bool:
         return self._categories.get(*args, **kwargs) == Category.GRADIENT
@@ -760,3 +762,22 @@
         for key, _ in snapshot.keys():
             if key.id in parameter_keys:
                 self._categories.set_by_id(key, Category.PARAMETER)
+
+    def _set_activations(self) -> None:
+        """Flood the graph to identify activations."""
+
+        required = {Category.INPUT, Category.ACTIVATION}
+        also_allowed = {Category.PARAMETER, Category.TEMPORARY}
+        for node in self._data_flow_graph.flow_nodes:
+            inputs = {(key, value) for key, (_, value) in node.inputs.items()}
+            input_categories = {self._categories.get(*i) for i in inputs}
+
+            if (
+                (input_categories & required)
+                and not (input_categories - (required | also_allowed))
+                #
+                # Stop filling when we reach the backward pass.
+                and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
+            ):
+                for i in node.outputs.items():
+                    self._categories.setdefault_by_version(*i, Category.ACTIVATION)