| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import unittest |
| from typing import get_args, List, Union |
| |
| import torch |
| from executorch.devtools.bundled_program.config import DataContainer |
| |
| from executorch.devtools.bundled_program.util.test_util import ( |
| get_random_test_suites, |
| get_random_test_suites_with_eager_model, |
| SampleModel, |
| ) |
| from executorch.extension.pytree import tree_flatten |
| |
| |
class TestConfig(unittest.TestCase):
    """Checks that ``MethodTestSuite`` objects carry the inputs/outputs they were built from."""

    def assertTensorEqual(self, t1: torch.Tensor, t2: torch.Tensor) -> None:
        """Assert that two tensors have the same shape and identical elements.

        The shape is checked first because ``t1 == t2`` broadcasts: without the
        check, a ``[2, 2]`` tensor would compare "equal" to a one-element tensor
        holding the same value everywhere.
        """
        self.assertEqual(t1.shape, t2.shape)
        self.assertTrue((t1 == t2).all())

    def assertIOListEqual(
        self,
        tl1: List[Union[bool, float, int, torch.Tensor]],
        tl2: List[Union[bool, float, int, torch.Tensor]],
    ) -> None:
        """Assert that two input/output lists match element-wise.

        Tensor elements are compared with :meth:`assertTensorEqual`. All other
        elements (bool/float/int) are compared with ``assertEqual``.
        """
        self.assertEqual(len(tl1), len(tl2))
        for t1, t2 in zip(tl1, tl2):
            if isinstance(t1, torch.Tensor):
                # Plain assert (not assertIsInstance) so pyre narrows t2's type.
                assert isinstance(t2, torch.Tensor)
                self.assertTensorEqual(t1, t2)
            else:
                self.assertEqual(t1, t2)

    def test_create_test_suites(self) -> None:
        """Suites from ``get_random_test_suites`` must echo the generated names and IO."""
        n_sets_per_plan_test = 10
        n_method_test_suites = 5

        (
            rand_method_names,
            rand_inputs,
            rand_expected_outputs,
            method_test_suites,
        ) = get_random_test_suites(
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            n_model_outputs=1,
            model_output_sizes=[[2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
            n_method_test_suites=n_method_test_suites,
        )

        self.assertEqual(len(method_test_suites), n_method_test_suites)

        # Compare to see if bundled execution plan tests match expectations.
        for method_test_suite_idx in range(n_method_test_suites):
            method_test_suite = method_test_suites[method_test_suite_idx]
            self.assertEqual(
                method_test_suite.method_name,
                rand_method_names[method_test_suite_idx],
            )
            for testset_idx in range(n_sets_per_plan_test):
                test_case = method_test_suite.test_cases[testset_idx]
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    rand_inputs[method_test_suite_idx][testset_idx],
                    test_case.inputs,
                )
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    rand_expected_outputs[method_test_suite_idx][testset_idx],
                    test_case.expected_outputs,
                )

    def test_create_test_suites_from_eager_model(self) -> None:
        """Expected outputs of eager-model suites must equal re-running the eager model."""
        n_sets_per_plan_test = 10
        eager_model = SampleModel()
        method_names: List[str] = eager_model.method_names

        rand_inputs, method_test_suites = get_random_test_suites_with_eager_model(
            eager_model=eager_model,
            method_names=method_names,
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
        )

        self.assertEqual(len(method_test_suites), len(method_names))

        # Compare to see if bundled testcases match expectations.
        for method_test_suite_idx, method_name in enumerate(method_names):
            method_test_suite = method_test_suites[method_test_suite_idx]
            self.assertEqual(method_test_suite.method_name, method_name)
            method = getattr(eager_model, method_name)
            for testset_idx in range(n_sets_per_plan_test):
                test_case = method_test_suite.test_cases[testset_idx]
                ri = rand_inputs[method_test_suite_idx][testset_idx]
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    ri,
                    test_case.inputs,
                )

                model_outputs = method(*ri)
                if isinstance(model_outputs, get_args(DataContainer)):
                    # tree_flatten returns (leaves, treespec). Only the leaves
                    # are comparable with the flat expected_outputs list.
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_eager_model_outputs, _ = tree_flatten(model_outputs)
                else:
                    flatten_eager_model_outputs = [
                        model_outputs,
                    ]

                self.assertIOListEqual(
                    flatten_eager_model_outputs,
                    test_case.expected_outputs,
                )