[Profiler] Memory profiler part 9: Mark activations (#88924)
This is a fairly straightforward pass: start at the inputs and flood-fill until we reach the backward pass.
Differential Revision: [D40868662](https://our.internmc.facebook.com/intern/diff/D40868662/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88924
Approved by: https://github.com/chaekit
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index 6819109..6e42f33 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -1189,16 +1189,16 @@
aten::ones -> 2 (INPUT)
-- Forward & loss ---------------------------------------------------------------------------------------
- aten::mul.Tensor 1 (INPUT), 3 (PARAMETER) -> 4 (???)
- aten::mul.Tensor 1 (INPUT), 5 (PARAMETER) -> 6 (???)
- aten::cat 4 (???), 6 (???) -> 7 (???)
- aten::binary_cross_entropy_with_logits 7 (???), 2 (INPUT) -> 13 (???)
+ aten::mul.Tensor 1 (INPUT), 3 (PARAMETER) -> 4 (ACTIVATION)
+ aten::mul.Tensor 1 (INPUT), 5 (PARAMETER) -> 6 (ACTIVATION)
+ aten::cat 4 (ACTIVATION), 6 (ACTIVATION) -> 7 (ACTIVATION)
+ aten::binary_cross_entropy_with_logits 7 (ACTIVATION), 2 (INPUT) -> 13 (ACTIVATION)
-- Backward ---------------------------------------------------------------------------------------------
- aten::ones_like 13 (???) -> 16 (???)
- aten::sigmoid 7 (???) -> 17 (TEMPORARY)
+ aten::ones_like 13 (ACTIVATION) -> 16 (ACTIVATION)
+ aten::sigmoid 7 (ACTIVATION) -> 17 (TEMPORARY)
aten::sub.Tensor 17 (TEMPORARY), 2 (INPUT) -> 18 (TEMPORARY)
- aten::mul.Tensor 18 (TEMPORARY), 16 (???) -> 19 (???)
+ aten::mul.Tensor 18 (TEMPORARY), 16 (ACTIVATION) -> 19 (???)
aten::div_.Scalar 19 (???) -> 19 (???)
aten::slice.Tensor 19 (???) -> 19 (???)
aten::slice.Tensor 19 (???) -> 19 (???)
@@ -1227,7 +1227,7 @@
"""\
aten::ones -> 1 (INPUT)
aten::t 2 (PARAMETER) -> 2 (PARAMETER)
- aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (???)""",
+ aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)""",
)
def test_categories_e2e_simple_module_fwd_bwd(self) -> None:
@@ -1247,16 +1247,16 @@
-- Forward & loss ---------------------------------------------------------------------------------------
aten::ones -> 1 (INPUT)
aten::t 2 (PARAMETER) -> 2 (PARAMETER)
- aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (???)
- aten::sum 4 (???) -> 5 (???)
+ aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)
+ aten::sum 4 (ACTIVATION) -> 5 (ACTIVATION)
-- Backward ---------------------------------------------------------------------------------------------
- aten::ones_like 5 (???) -> 6 (???)
- aten::expand 6 (???) -> 6 (???)
- aten::t 6 (???) -> 6 (???)
- aten::mm 6 (???), 1 (INPUT) -> 7 (GRADIENT)
+ aten::ones_like 5 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::expand 6 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::t 6 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::mm 6 (ACTIVATION), 1 (INPUT) -> 7 (GRADIENT)
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
- aten::sum.dim_IntList 6 (???) -> 9 (GRADIENT)
+ aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
aten::view 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> ???
@@ -1287,16 +1287,16 @@
-- Forward & loss ---------------------------------------------------------------------------------------
aten::ones -> 1 (INPUT)
aten::t 2 (PARAMETER) -> 2 (PARAMETER)
- aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (???)
- aten::sum 4 (???) -> 5 (???)
+ aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)
+ aten::sum 4 (ACTIVATION) -> 5 (ACTIVATION)
-- Backward ---------------------------------------------------------------------------------------------
- aten::ones_like 5 (???) -> 6 (???)
- aten::expand 6 (???) -> 6 (???)
- aten::t 6 (???) -> 6 (???)
- aten::mm 6 (???), 1 (INPUT) -> 7 (GRADIENT)
+ aten::ones_like 5 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::expand 6 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::t 6 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::mm 6 (ACTIVATION), 1 (INPUT) -> 7 (GRADIENT)
aten::t 7 (GRADIENT) -> 7 (GRADIENT)
- aten::sum.dim_IntList 6 (???) -> 9 (GRADIENT)
+ aten::sum.dim_IntList 6 (ACTIVATION) -> 9 (GRADIENT)
aten::view 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
aten::detach 9 (GRADIENT) -> 9 (GRADIENT)
@@ -1329,13 +1329,13 @@
"""\
aten::ones -> 1 (INPUT)
aten::t 2 (PARAMETER) -> 2 (PARAMETER)
- aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (???)
- aten::relu 4 (???) -> 5 (???)
- aten::detach 5 (???) -> ???
+ aten::addmm 3 (PARAMETER), 1 (INPUT), 2 (PARAMETER) -> 4 (ACTIVATION)
+ aten::relu 4 (ACTIVATION) -> 5 (ACTIVATION)
+ aten::detach 5 (ACTIVATION) -> ???
aten::t 6 (PARAMETER) -> 6 (PARAMETER)
- aten::mm 5 (???), 6 (PARAMETER) -> 7 (???)
- aten::_softmax 7 (???) -> 8 (???)
- aten::detach 8 (???) -> ???""",
+ aten::mm 5 (ACTIVATION), 6 (PARAMETER) -> 7 (ACTIVATION)
+ aten::_softmax 7 (ACTIVATION) -> 8 (ACTIVATION)
+ aten::detach 8 (ACTIVATION) -> ???""",
)
def test_categories_e2e_sequential_fwd_bwd(self) -> None:
@@ -1367,40 +1367,40 @@
-- Forward ----------------------------------------------------------------------------------------------
aten::t 3 (PARAMETER) -> 3 (PARAMETER)
- aten::addmm 4 (PARAMETER), 1 (INPUT), 3 (PARAMETER) -> 5 (???)
- aten::relu 5 (???) -> 6 (???)
- aten::detach 6 (???) -> 6 (???)
+ aten::addmm 4 (PARAMETER), 1 (INPUT), 3 (PARAMETER) -> 5 (ACTIVATION)
+ aten::relu 5 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION)
aten::t 7 (PARAMETER) -> 7 (PARAMETER)
- aten::mm 6 (???), 7 (PARAMETER) -> 8 (???)
- aten::_softmax 8 (???) -> 9 (???)
- aten::detach 9 (???) -> 9 (???)
+ aten::mm 6 (ACTIVATION), 7 (PARAMETER) -> 8 (ACTIVATION)
+ aten::_softmax 8 (ACTIVATION) -> 9 (ACTIVATION)
+ aten::detach 9 (ACTIVATION) -> 9 (ACTIVATION)
-- Loss -------------------------------------------------------------------------------------------------
- aten::sub.Tensor 9 (???), 2 (INPUT) -> 10 (???)
- aten::pow.Tensor_Scalar 10 (???) -> 11 (???)
- aten::sum 11 (???) -> 12 (???)
- aten::mean 12 (???) -> 13 (???)
+ aten::sub.Tensor 9 (ACTIVATION), 2 (INPUT) -> 10 (ACTIVATION)
+ aten::pow.Tensor_Scalar 10 (ACTIVATION) -> 11 (ACTIVATION)
+ aten::sum 11 (ACTIVATION) -> 12 (ACTIVATION)
+ aten::mean 12 (ACTIVATION) -> 13 (ACTIVATION)
-- Backward ---------------------------------------------------------------------------------------------
- aten::ones_like 13 (???) -> 16 (???)
- aten::expand 16 (???) -> 16 (???)
- aten::div.Scalar 16 (???) -> 19 (???)
+ aten::ones_like 13 (ACTIVATION) -> 16 (ACTIVATION)
+ aten::expand 16 (ACTIVATION) -> 16 (ACTIVATION)
+ aten::div.Scalar 16 (ACTIVATION) -> 19 (???)
aten::expand 19 (???) -> 19 (???)
- aten::pow.Tensor_Scalar 10 (???) -> 20 (TEMPORARY)
+ aten::pow.Tensor_Scalar 10 (ACTIVATION) -> 20 (TEMPORARY)
aten::mul.Scalar 20 (TEMPORARY) -> 23 (TEMPORARY)
aten::mul.Tensor 19 (???), 23 (TEMPORARY) -> 24 (???)
- aten::detach 9 (???) -> 9 (???)
- aten::_softmax_backward_data 24 (???), 9 (???) -> 25 (???)
+ aten::detach 9 (ACTIVATION) -> 9 (ACTIVATION)
+ aten::_softmax_backward_data 24 (???), 9 (ACTIVATION) -> 25 (???)
aten::t 25 (???) -> 25 (???)
- aten::mm 25 (???), 6 (???) -> 26 (GRADIENT)
+ aten::mm 25 (???), 6 (ACTIVATION) -> 26 (GRADIENT)
aten::t 26 (GRADIENT) -> 26 (GRADIENT)
aten::t 7 (PARAMETER) -> 7 (PARAMETER)
aten::mm 25 (???), 7 (PARAMETER) -> 27 (???)
aten::t 26 (GRADIENT) -> 26 (GRADIENT)
aten::detach 26 (GRADIENT) -> 26 (GRADIENT)
aten::detach 26 (GRADIENT) -> ???
- aten::detach 6 (???) -> 6 (???)
- aten::threshold_backward 27 (???), 6 (???) -> 28 (???)
+ aten::detach 6 (ACTIVATION) -> 6 (ACTIVATION)
+ aten::threshold_backward 27 (???), 6 (ACTIVATION) -> 28 (???)
aten::t 28 (???) -> 28 (???)
aten::mm 28 (???), 1 (INPUT) -> 29 (GRADIENT)
aten::t 29 (GRADIENT) -> 29 (GRADIENT)
diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py
index 25107c5..cf06345 100644
--- a/torch/profiler/_memory_profiler.py
+++ b/torch/profiler/_memory_profiler.py
@@ -33,6 +33,7 @@
class Category(enum.Enum):
INPUT = enum.auto()
TEMPORARY = enum.auto()
+ ACTIVATION = enum.auto()
GRADIENT = enum.auto()
PARAMETER = enum.auto()
@@ -559,6 +560,7 @@
self._set_parameters_using_python_tracer()
self._set_inputs()
self._set_parameters_using_data_flow()
+ self._set_activations()
def _is_gradient(self, *args, **kwargs) -> bool:
return self._categories.get(*args, **kwargs) == Category.GRADIENT
@@ -760,3 +762,22 @@
for key, _ in snapshot.keys():
if key.id in parameter_keys:
self._categories.set_by_id(key, Category.PARAMETER)
+
+ def _set_activations(self) -> None:
+ """Flood the graph to identify activations."""
+
+ required = {Category.INPUT, Category.ACTIVATION}
+ also_allowed = {Category.PARAMETER, Category.TEMPORARY}
+ for node in self._data_flow_graph.flow_nodes:
+ inputs = {(key, value) for key, (_, value) in node.inputs.items()}
+ input_categories = {self._categories.get(*i) for i in inputs}
+
+ if (
+ (input_categories & required)
+ and not (input_categories - (required | also_allowed))
+ #
+ # Stop filling when we reach the backward pass.
+ and RecordScope.BACKWARD_FUNCTION not in get_scopes(node._event)
+ ):
+ for i in node.outputs.items():
+ self._categories.setdefault_by_version(*i, Category.ACTIVATION)