| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import typing |
| import unittest |
| from typing import Any, get_args, List, Union |
| |
| import torch |
| from executorch.bundled_program.config import BundledConfig, DataContainer |
| from executorch.bundled_program.schema import BundledAttachment |
| |
| from executorch.bundled_program.tests.common import ( |
| get_random_config, |
| get_random_config_with_eager_model, |
| MISOModel, |
| ) |
| from executorch.extension.pytree import tree_flatten |
| |
| |
class TestConfig(unittest.TestCase):
    """Checks that a generated bundled config holds the data it was built from."""

    def assertTensorEqual(self, t1: torch.Tensor, t2: torch.Tensor) -> None:
        """Assert that two tensors have the same shape and equal elements.

        Args:
            t1: First tensor.
            t2: Second tensor.
        """
        # `==` broadcasts, so tensors with different but compatible shapes
        # (e.g. [2, 2] vs [1]) would otherwise compare as equal.
        self.assertEqual(t1.shape, t2.shape)
        self.assertTrue((t1 == t2).all())

    def assertIOListEqual(
        self,
        tl1: List[Union[bool, float, int, torch.Tensor]],
        tl2: List[Union[bool, float, int, torch.Tensor]],
    ) -> None:
        """Assert that two lists of inputs/outputs match element by element.

        Tensor entries must have the same type and are compared with
        ``assertTensorEqual``; every other entry is compared with ``==``.

        Args:
            tl1: First list of values.
            tl2: Second list of values.
        """
        self.assertEqual(len(tl1), len(tl2))
        for idx, (t1, t2) in enumerate(zip(tl1, tl2)):
            if isinstance(t1, torch.Tensor):
                # Use a unittest assertion instead of a bare `assert`, which
                # is stripped when Python runs with -O.
                self.assertIs(
                    type(t1), type(t2), f"type mismatch at index {idx}"
                )
                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
                #  `Union[bool, float, int, Tensor]`.
                self.assertTensorEqual(t1, t2)
            else:
                self.assertEqual(t1, t2, f"value mismatch at index {idx}")

    def assertAttachmentDictEqual(
        self,
        attachments: List[BundledAttachment],
        kwargs: typing.Dict[str, Any],
    ) -> None:
        """Assert that attachments mirror the entries of ``kwargs`` in order.

        Each attachment's key must equal the corresponding dict key, and its
        value must equal the dict value after conversion with
        ``BundledConfig.convert_prim_val_to_attachement_val``.

        Args:
            attachments: Attachments read back from the bundled config.
            kwargs: The key/value pairs the attachments were created from.
        """
        self.assertEqual(len(attachments), len(kwargs))
        # Pairing relies on the dict's iteration order matching the order of
        # the attachments list.
        for attachment, key in zip(attachments, kwargs):
            self.assertEqual(attachment.key, key)
            self.assertEqual(
                attachment.val,
                BundledConfig.convert_prim_val_to_attachement_val(kwargs[key]),
            )

    def test_create_config(self) -> None:
        """A random config stores the inputs, outputs and metadata it was given."""
        n_sets_per_plan_test = 10
        n_execution_plan_tests = 5

        (
            rand_inputs,
            rand_expected_outputs,
            metadatas,
            attachment,
            bundled_config,
        ) = get_random_config(
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            n_model_outputs=1,
            model_output_sizes=[[2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
            n_execution_plan_tests=n_execution_plan_tests,
        )

        self.assertEqual(
            len(bundled_config.execution_plan_tests), n_execution_plan_tests
        )

        # Compare to see if the bundled execution plan tests match expectations.
        for plan_test_idx in range(n_execution_plan_tests):
            plan_test = bundled_config.execution_plan_tests[plan_test_idx]
            for testset_idx in range(n_sets_per_plan_test):
                test_set = plan_test.test_sets[testset_idx]
                self.assertIOListEqual(
                    # pyre-ignore
                    rand_inputs[plan_test_idx][testset_idx],
                    test_set.inputs,
                )
                self.assertIOListEqual(
                    # pyre-ignore
                    rand_expected_outputs[plan_test_idx][testset_idx],
                    test_set.expected_outputs,
                )
            self.assertAttachmentDictEqual(
                plan_test.metadata,
                metadatas[plan_test_idx],
            )

        self.assertAttachmentDictEqual(bundled_config.attachments, attachment)

    def test_create_config_from_eager_model(self) -> None:
        """A config built from an eager model stores that model's outputs."""
        n_sets_per_plan_test = 10
        n_execution_plan_tests = 5
        eager_model = MISOModel()

        rand_inputs, bundled_config = get_random_config_with_eager_model(
            eager_model=eager_model,
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
            n_execution_plan_tests=n_execution_plan_tests,
        )

        self.assertEqual(
            len(bundled_config.execution_plan_tests), n_execution_plan_tests
        )

        # Compare to see if the bundled test cases match expectations.
        for plan_test_idx in range(n_execution_plan_tests):
            plan_test = bundled_config.execution_plan_tests[plan_test_idx]
            for testset_idx in range(n_sets_per_plan_test):
                test_set = plan_test.test_sets[testset_idx]
                ri = rand_inputs[plan_test_idx][testset_idx]
                self.assertIOListEqual(
                    # pyre-ignore[6]
                    ri,
                    test_set.inputs,
                )

                # Re-run the eager model to get the reference outputs.
                model_outputs = eager_model(*ri)
                if isinstance(model_outputs, get_args(DataContainer)):
                    # tree_flatten returns a (leaves, spec) pair; only the
                    # leaves are comparable with the stored expected outputs.
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_eager_model_outputs, _ = tree_flatten(model_outputs)
                else:
                    # A bare (non-container) output is wrapped in a list so it
                    # can be compared the same way.
                    flatten_eager_model_outputs = [
                        model_outputs,
                    ]

                self.assertIOListEqual(
                    flatten_eager_model_outputs,
                    test_set.expected_outputs,
                )