blob: f5df82176ee9e916a4ab946453dbc0d0f4ef21f5 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""End-to-end profiler tests.
This must be built and run with `buck2 -c executorch.prof_enabled=true`.
"""
import unittest
import torch
from executorch.exir import to_edge
from executorch.extension.pybindings.portable_lib import (
_create_profile_block,
_dump_profile_results,
_load_for_executorch_from_buffer,
_reset_profile_results,
)
from executorch.extension.pytree import tree_flatten
from executorch.profiler.fb.parse_profiler_results import profile_table
from executorch.profiler.parse_profiler_results import (
deserialize_profile_results,
profile_aggregate_framework_tax,
profile_framework_tax_table,
)
from torch.export import export
class Module(torch.nn.Module):
    """Tiny model computing ``a * x + b`` with two constant 2x2 float buffers.

    ``a`` holds 3.0 everywhere and ``b`` holds 2.0 everywhere. ``forward``
    issues exactly one ``torch.mul`` followed by one ``torch.add``.
    """

    def __init__(self):
        super().__init__()
        # Registered as buffers (not parameters) so they are captured as
        # constants of the exported program; registration order is a, then b.
        self.register_buffer("a", torch.full((2, 2), 3.0, dtype=torch.float))
        self.register_buffer("b", torch.full((2, 2), 2.0, dtype=torch.float))

    def forward(self, x):
        """Return ``self.a * x + self.b`` (element-wise)."""
        return torch.add(torch.mul(self.a, x), self.b)
class TestCustomOps(unittest.TestCase):
    """End-to-end checks of the ExecuTorch profiler pybindings.

    ``setUpClass`` exports ``Module``, lowers it to an ExecuTorch program,
    loads it through the portable pybindings and runs ``forward`` once. The
    deserialized profiler dump from that single run is shared by every test
    except ``test_profiler_new_block``, which records its own dump.
    """

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()
        model = Module()
        inputs = (torch.ones(2, 2, dtype=torch.float),)
        # The serialized program buffer. This must live longer than cls.module,
        # because the C++ pybindings will have a pointer to it. But none of the
        # tests should need to touch it.
        cls.__buffer: bytes = to_edge(export(model, inputs)).to_executorch().buffer
        cls.module = _load_for_executorch_from_buffer(cls.__buffer)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        cls.inputs_flattened, _ = tree_flatten(inputs)
        cls.module.run_method("forward", tuple(cls.inputs_flattened))
        prof_dump = _dump_profile_results()
        # Raise explicitly instead of using `assert`, which `python -O` strips;
        # without this check an unprofiled build would fail later and less
        # clearly.
        if len(prof_dump) == 0:
            raise AssertionError(
                "prof_dump is empty; may need to build with `-c executorch.prof_enabled=true`"
            )
        cls.prof_results, cls.mem_results = deserialize_profile_results(prof_dump)
        # Profiler event names the tests expect for the mul + add in forward.
        cls.expect_ops = ["native_call_add.out", "native_call_mul.out"]

    def test_profiler_new_block(self) -> None:
        """Runs after ``_create_profile_block(name)`` are grouped under ``name``."""
        block_names = ["block_1", "block_2"]
        # Reset first: the assertion below expects the dump to contain only
        # the blocks created in this test.
        _reset_profile_results()
        for block_name in block_names:
            _create_profile_block(block_name)
            self.module.run_method("forward", tuple(self.inputs_flattened))
        prof_dump = _dump_profile_results()
        self.assertGreater(
            len(prof_dump),
            0,
            "prof_dump is empty; may need to build with `-c executorch.prof_enabled=true`",
        )
        prof_results, _ = deserialize_profile_results(prof_dump)
        # Compare the whole ordered key list so block names, order and count
        # are all checked by one assertion with a readable diff on failure.
        self.assertEqual(list(prof_results.keys()), block_names)

    def test_profiler_expected_ops(self) -> None:
        """Each expected kernel event appears exactly once, in the default block."""
        found_ops = []
        for block_name, prof_data_list in self.prof_results.items():
            for prof_event in prof_data_list:
                if prof_event.name in self.expect_ops:
                    found_ops.append(prof_event.name)
                    self.assertEqual(block_name, "default")
        # A per-op comparison also catches one op appearing twice while
        # another is missing, which a bare total count would accept.
        self.assertCountEqual(found_ops, self.expect_ops)

    def test_profile_framework_tax(self) -> None:
        """Framework-tax aggregation yields a single 'default' entry with one sample."""
        prof_agg_data = profile_aggregate_framework_tax(self.prof_results)
        # Requiring exactly the "default" key also stops the test from
        # passing vacuously when the aggregation returns nothing.
        self.assertEqual(list(prof_agg_data.keys()), ["default"])
        framework_tax = prof_agg_data["default"]
        # forward ran once in setUpClass, so each series holds one sample.
        self.assertEqual(len(framework_tax.exec_time), 1)
        self.assertEqual(len(framework_tax.kernel_and_delegate_time), 1)
        self.assertEqual(len(framework_tax.framework_tax), 1)
        # NOTE(review): the < 100 bound suggests framework_tax is a
        # percentage; confirm against profile_aggregate_framework_tax.
        self.assertLess(float(framework_tax.framework_tax[0]), 100)

    def test_gen_profile_table(self) -> None:
        """The rendered profile tables mention each expected op exactly once."""
        prof_table = profile_table(self.prof_results)
        rendered = [entry.get_string() for table in prof_table for entry in table]
        hits = {op: sum(op in text for text in rendered) for op in self.expect_ops}
        self.assertEqual(hits, {op: 1 for op in self.expect_ops})

    def test_gen_profile_framework_tax_table(self) -> None:
        """The framework-tax table renders each summary row exactly once."""
        prof_agg_data = profile_aggregate_framework_tax(self.prof_results)
        prof_framework_tax_table = profile_framework_tax_table(prof_agg_data)
        expected_entries = [
            "Model execution time",
            "Time spent in kernels",
            "Framework tax",
        ]
        rendered = [
            entry.get_string() for table in prof_framework_tax_table for entry in table
        ]
        hits = {
            expected: sum(expected in text for text in rendered)
            for expected in expected_entries
        }
        self.assertEqual(hits, {expected: 1 for expected in expected_entries})
def main() -> None:
    """Run the tests found in ``__main__`` through unittest's command-line runner."""
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")


if __name__ == "__main__":
    # Only reached when this file is executed directly as a script.
    main()  # pragma: no cover