blob: 21f3d48042366d526a32d1cf727357c1b0e51cf8 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import unittest
from typing import get_args, List, Union
import torch
from executorch.devtools.bundled_program.config import DataContainer
from executorch.devtools.bundled_program.util.test_util import (
get_random_test_suites,
get_random_test_suites_with_eager_model,
SampleModel,
)
from executorch.extension.pytree import tree_flatten
class TestConfig(unittest.TestCase):
    """Checks that bundled-program method test suites carry the data they were built from.

    Each test builds method test suites with the helpers from
    ``bundled_program.util.test_util`` and compares every suite's method name,
    test-case inputs and expected outputs against the values the helper
    returned (or, for the eager-model variant, against a fresh eager run of
    the same method).
    """

    def assertTensorEqual(self, t1: torch.Tensor, t2: torch.Tensor) -> None:
        """Assert that two tensors have the same shape and identical elements.

        ``torch.equal`` is used rather than ``(t1 == t2).all()``: the
        elementwise form broadcasts, so e.g. a ``[2, 2]`` tensor of ones would
        compare equal to a one-element tensor holding ``1``, and
        non-broadcastable shapes would raise instead of failing the assertion.
        """
        self.assertTrue(
            torch.equal(t1, t2),
            f"tensors differ:\n{t1}\n!=\n{t2}",
        )

    def assertIOListEqual(
        self,
        tl1: List[Union[bool, float, int, torch.Tensor]],
        tl2: List[Union[bool, float, int, torch.Tensor]],
    ) -> None:
        """Assert two flat input/output lists match element by element.

        Args:
            tl1: First list of values.
            tl2: Second list of values.

        Tensors are compared with ``assertTensorEqual``; non-tensor values with
        ``assertEqual`` (i.e. ``==``). A tensor on one side paired with a
        non-tensor on the other is reported as a failure.
        """
        self.assertEqual(len(tl1), len(tl2))
        for idx, (t1, t2) in enumerate(zip(tl1, tl2)):
            if isinstance(t1, torch.Tensor) and isinstance(t2, torch.Tensor):
                self.assertTensorEqual(t1, t2)
            elif isinstance(t1, torch.Tensor) or isinstance(t2, torch.Tensor):
                # Use a unittest failure instead of a bare ``assert`` so the
                # check is not stripped under ``python -O``.
                self.fail(
                    f"element {idx}: tensor/non-tensor mismatch "
                    f"({type(t1).__name__} vs {type(t2).__name__})"
                )
            else:
                self.assertEqual(t1, t2, f"element {idx} differs")

    def test_create_test_suites(self) -> None:
        """Randomly generated suites hold the names, inputs and outputs the helper reports."""
        n_sets_per_plan_test = 10
        n_method_test_suites = 5

        (
            rand_method_names,
            rand_inputs,
            rand_expected_outputs,
            method_test_suites,
        ) = get_random_test_suites(
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            n_model_outputs=1,
            model_output_sizes=[[2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
            n_method_test_suites=n_method_test_suites,
        )

        self.assertEqual(len(method_test_suites), n_method_test_suites)

        # Compare to see if bundled execution plan tests match expectations.
        for method_test_suite_idx in range(n_method_test_suites):
            method_test_suite = method_test_suites[method_test_suite_idx]
            self.assertEqual(
                method_test_suite.method_name,
                rand_method_names[method_test_suite_idx],
            )
            # Catch suites that carry more (or fewer) cases than requested.
            self.assertEqual(len(method_test_suite.test_cases), n_sets_per_plan_test)
            for testset_idx in range(n_sets_per_plan_test):
                test_case = method_test_suite.test_cases[testset_idx]
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    rand_inputs[method_test_suite_idx][testset_idx],
                    test_case.inputs,
                )
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    rand_expected_outputs[method_test_suite_idx][testset_idx],
                    test_case.expected_outputs,
                )

    def test_create_test_suites_from_eager_model(self) -> None:
        """Suites built from an eager model store that model's actual outputs."""
        n_sets_per_plan_test = 10
        eager_model = SampleModel()
        method_names: List[str] = eager_model.method_names

        rand_inputs, method_test_suites = get_random_test_suites_with_eager_model(
            eager_model=eager_model,
            method_names=method_names,
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
        )

        self.assertEqual(len(method_test_suites), len(method_names))

        # Compare to see if bundled testcases match expectations.
        for method_test_suite_idx, method_name in enumerate(method_names):
            method_test_suite = method_test_suites[method_test_suite_idx]
            self.assertEqual(method_test_suite.method_name, method_name)
            self.assertEqual(len(method_test_suite.test_cases), n_sets_per_plan_test)
            method = getattr(eager_model, method_name)
            for testset_idx in range(n_sets_per_plan_test):
                test_case = method_test_suite.test_cases[testset_idx]
                ri = rand_inputs[method_test_suite_idx][testset_idx]
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    ri,
                    test_case.inputs,
                )

                # Re-run the eager method on the same inputs to obtain the
                # reference outputs.
                model_outputs = method(*ri)
                if isinstance(model_outputs, get_args(DataContainer)):
                    # tree_flatten returns a (leaves, spec) pair; only the
                    # leaves are comparable with the flat expected-output list.
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_eager_model_outputs, _ = tree_flatten(model_outputs)
                else:
                    flatten_eager_model_outputs = [model_outputs]

                self.assertIOListEqual(
                    flatten_eager_model_outputs,
                    test_case.expected_outputs,
                )