blob: 3b1e6ce86b5d9f12ef65f5f2f6106169c9b60aca [file]
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import platform
import shutil
import subprocess
import sys
import tempfile
from datetime import datetime
from enum import auto, Enum
from pathlib import Path
from typing import Any
import pytest
import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.exir.backend.compile_spec_schema import CompileSpec
class arm_test_options(Enum):
    """Keys for the Arm test options stored for the current pytest session."""

    # Set to True when --arm_quantize_io is passed.
    quantize_io = auto()
    # Set to True when --arm_run_corstone300 is passed and the FVP is found.
    corstone300 = auto()
    # Existing directory given through --default_dump_path.
    dump_path = auto()
    # strftime/strptime pattern given through --date_format.
    date_format = auto()
    # Value of the --fast_fvp switch.
    fast_fvp = auto()
# Option values for the running session. Filled in by pytest_configure and read
# through is_option_enabled() / get_option(); options whose flag was not given
# may be absent from the mapping.
_test_options: dict[arm_test_options, Any] = {}
# ==== Pytest hooks ====
def pytest_addoption(parser):
    """Register the Arm test command-line options on the pytest parser.

    ``--arm_quantize_io``, ``--arm_run_corstone300`` and ``--fast_fvp`` are
    ``store_true`` switches; ``--default_dump_path`` defaults to ``None`` and
    ``--date_format`` defaults to ``"%d-%b-%H:%M:%S"``.
    """
    # (flag, keyword arguments) pairs, registered in this order.
    option_table = (
        ("--arm_quantize_io", {"action": "store_true"}),
        ("--arm_run_corstone300", {"action": "store_true"}),
        ("--default_dump_path", {"default": None}),
        ("--date_format", {"default": "%d-%b-%H:%M:%S"}),
        ("--fast_fvp", {"action": "store_true"}),
    )
    for flag, settings in option_table:
        parser.addoption(flag, **settings)
def pytest_configure(config):
    """Translate the parsed command-line flags into entries of ``_test_options``.

    Raises:
        RuntimeError: if ``--arm_run_corstone300`` is given but the
            ``FVP_Corstone_SSE-300_Ethos-U55`` executable is not on PATH, or
            if ``--default_dump_path`` does not name an existing directory.
    """
    cli = config.option

    if cli.arm_quantize_io:
        # The quantized ops library is loaded before the option is recorded.
        load_libquantized_ops_aot_lib()
        _test_options[arm_test_options.quantize_io] = True

    if cli.arm_run_corstone300:
        if not shutil.which("FVP_Corstone_SSE-300_Ethos-U55"):
            raise RuntimeError(
                "Tests are run with --arm_run_corstone300 but corstone300 FVP is not installed."
            )
        _test_options[arm_test_options.corstone300] = True

    if cli.default_dump_path:
        dump_path = Path(cli.default_dump_path).expanduser()
        if not (dump_path.exists() and os.path.isdir(dump_path)):
            raise RuntimeError(
                f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory."
            )
        _test_options[arm_test_options.dump_path] = dump_path

    # These two are always recorded, whatever their value.
    _test_options[arm_test_options.date_format] = cli.date_format
    _test_options[arm_test_options.fast_fvp] = cli.fast_fvp

    logging.basicConfig(level=logging.INFO, stream=sys.stdout)
def pytest_collection_modifyitems(config, items):
    """Skip every collected test with "u55" in its name unless
    ``--arm_quantize_io`` was given.
    """
    if config.option.arm_quantize_io:
        return

    skip_marker = pytest.mark.skip("u55 tests can only run with quantize_io=True.")
    for u55_item in (candidate for candidate in items if "u55" in candidate.name):
        u55_item.add_marker(skip_marker)
def pytest_sessionstart(session):
    """Pytest session-start hook; currently a no-op."""
    return None
def pytest_sessionfinish(session, exitstatus):
    """Prune old ``ArmTester_<date>.log`` files from the dump directory.

    Does nothing when no dump path option is set.
    """
    if not get_option(arm_test_options.dump_path):
        return
    dump_dir = get_option(arm_test_options.dump_path)
    # The date format doubles as the strptime pattern used to recognise logs.
    log_name_pattern = f"ArmTester_{get_option(arm_test_options.date_format)}.log"
    _clean_dir(dump_dir, log_name_pattern)
# ==== End of Pytest hooks =====
# ==== Custom Pytest decorators =====
def expectedFailureOnFVP(test_item):
    """Decorator: flag ``test_item`` as an expected failure when the
    corstone300 option is enabled; otherwise return it unchanged.
    """
    fvp_enabled = is_option_enabled("corstone300")
    if fvp_enabled:
        setattr(test_item, "__unittest_expecting_failure__", True)
    return test_item
# ==== End of Custom Pytest decorators =====
def load_libquantized_ops_aot_lib():
    """Load the ahead-of-time quantized ops library into torch, if it is built.

    Searches the ``cmake-out-aot-lib`` directory (relative to the current
    working directory) recursively for ``libquantized_ops_aot_lib`` with the
    shared-library extension of the host platform and passes the first match
    to ``torch.ops.load_library``.

    Loading is best effort: when the platform is not recognised, the build
    directory is missing, or no library file is found, a warning is logged and
    nothing is loaded.
    """
    log = logging.getLogger(__name__)

    so_ext = {
        "Darwin": "dylib",
        "Linux": "so",
        "Windows": "dll",
    }.get(platform.system())
    if so_ext is None:
        log.warning(
            "Unknown platform '%s'; not loading libquantized_ops_aot_lib.",
            platform.system(),
        )
        return

    lib_name = f"libquantized_ops_aot_lib.{so_ext}"
    search_root = Path("cmake-out-aot-lib")
    # Sorted so the same file is picked on every run when several builds exist.
    matches = (
        sorted(hit for hit in search_root.rglob(lib_name) if hit.is_file())
        if search_root.is_dir()
        else []
    )
    if not matches:
        # Never hand an empty path to torch: that fails with an obscure error.
        log.warning("Could not find %s under '%s'.", lib_name, search_root)
        return

    torch.ops.load_library(str(matches[0]))
def is_option_enabled(
    option: str | arm_test_options, fail_if_not_enabled: bool = False
) -> bool:
    """
    Tell whether an option was successfully enabled, i.e. the flag was given
    to pytest and the necessary requirements are available.

    Implemented options are:
    - corstone300.
    - quantize_io.

    A string is lower-cased and looked up by name in arm_test_options.
    With 'fail_if_not_enabled' set, a RuntimeError is raised instead of
    returning False.
    """
    option = arm_test_options[option.lower()] if isinstance(option, str) else option

    if _test_options.get(option):
        return True
    if fail_if_not_enabled:
        raise RuntimeError(f"Required option '{option}' for test is not enabled")
    return False
def get_option(option: arm_test_options) -> Any | None:
    """Return the stored value for ``option``, or None when it was never set."""
    return _test_options.get(option)
def maybe_get_tosa_collate_path() -> str | None:
"""
Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the
path to the where to store the current tests if it is set.
"""
tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH")
if tosa_test_base:
current_test = os.environ.get("PYTEST_CURRENT_TEST")
#'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)'
test_class = current_test.split("::")[1]
test_name = current_test.split("::")[-1].split(" ")[0]
if "BI" in test_name:
tosa_test_base = os.path.join(tosa_test_base, "tosa-bi")
elif "MI" in test_name:
tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
else:
tosa_test_base = os.path.join(tosa_test_base, "other")
return os.path.join(tosa_test_base, test_class, test_name)
return None
def get_tosa_compile_spec(
    tosa_version: str, permute_memory_to_nhwc=True, custom_path=None
) -> list[CompileSpec]:
    """
    Return the finalized default compile spec for TOSA tests.

    Forwards all arguments to get_tosa_compile_spec_unbuilt and calls
    .build() on the resulting builder.
    """
    builder = get_tosa_compile_spec_unbuilt(
        tosa_version, permute_memory_to_nhwc, custom_path
    )
    return builder.build()
def get_tosa_compile_spec_unbuilt(
    tosa_version: str, permute_memory_to_nhwc=False, custom_path=None
) -> ArmCompileSpecBuilder:
    """Return the ArmCompileSpecBuilder for the default TOSA tests, so the
    compile spec can be modified before .build() is called to finalize it.

    Intermediate artifacts go to 'custom_path' when given; otherwise to the
    path from maybe_get_tosa_collate_path(), or, failing that, to a fresh
    temporary directory prefixed "arm_tosa_". The directory is created if it
    does not exist yet.
    """
    intermediate_path = (
        custom_path
        or maybe_get_tosa_collate_path()
        or tempfile.mkdtemp(prefix="arm_tosa_")
    )
    if not os.path.exists(intermediate_path):
        os.makedirs(intermediate_path, exist_ok=True)

    builder = ArmCompileSpecBuilder().tosa_compile_spec(tosa_version)
    builder = builder.set_permute_memory_format(permute_memory_to_nhwc)
    builder = builder.dump_intermediate_artifacts_to(intermediate_path)
    return builder
def get_u55_compile_spec(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> list[CompileSpec]:
    """
    Return the finalized default compile spec for Ethos-U55 tests.

    Forwards all arguments to get_u55_compile_spec_unbuilt and calls
    .build() on the resulting builder.
    """
    builder = get_u55_compile_spec_unbuilt(
        permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path
    )
    return builder.build()
def get_u85_compile_spec(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> list[CompileSpec]:
    """
    Return the finalized default compile spec for Ethos-U85 tests.

    Forwards all arguments to get_u85_compile_spec_unbuilt and calls
    .build() on the resulting builder.
    """
    builder = get_u85_compile_spec_unbuilt(
        permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path
    )
    return builder.build()
def get_u55_compile_spec_unbuilt(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> ArmCompileSpecBuilder:
    """Return the ArmCompileSpecBuilder for the Ethos-U55 tests, so the
    compile spec can be modified before .build() is called to finalize it.

    Artifacts go to 'custom_path' when given, otherwise to a fresh temporary
    directory prefixed "arm_u55_"; the directory is created if missing.
    IO quantization is turned on when either the quantize_io test option is
    enabled or the 'quantize_io' argument is true.
    """
    artifact_path = (
        custom_path if custom_path else tempfile.mkdtemp(prefix="arm_u55_")
    )
    if not os.path.exists(artifact_path):
        os.makedirs(artifact_path, exist_ok=True)

    builder = ArmCompileSpecBuilder().ethosu_compile_spec(
        "ethos-u55-128",
        system_config="Ethos_U55_High_End_Embedded",
        memory_mode="Shared_Sram",
        extra_flags="--debug-force-regor --output-format=raw",
    )
    builder = builder.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
    builder = builder.set_permute_memory_format(permute_memory_to_nhwc)
    builder = builder.dump_intermediate_artifacts_to(artifact_path)
    return builder
def get_u85_compile_spec_unbuilt(
    permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
) -> ArmCompileSpecBuilder:
    """Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify
    the compile spec before calling .build() to finalize it.

    Artifacts go to 'custom_path' when given, otherwise to a fresh temporary
    directory prefixed "arm_u85_"; the directory is created if missing, as in
    get_u55_compile_spec_unbuilt. IO quantization is turned on when either the
    quantize_io test option is enabled or the 'quantize_io' argument is true.

    Returns:
        The unbuilt ArmCompileSpecBuilder (not a list of CompileSpec).
    """
    artifact_path = custom_path or tempfile.mkdtemp(prefix="arm_u85_")
    # Same handling as the U55 variant: a custom path may not exist yet.
    if not os.path.exists(artifact_path):
        os.makedirs(artifact_path, exist_ok=True)
    compile_spec = (
        ArmCompileSpecBuilder()
        .ethosu_compile_spec(
            "ethos-u85-128",
            system_config="Ethos_U85_SYS_DRAM_Mid",
            memory_mode="Shared_Sram",
            extra_flags="--output-format=raw",
        )
        .set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
        .set_permute_memory_format(permute_memory_to_nhwc)
        .dump_intermediate_artifacts_to(artifact_path)
    )
    return compile_spec
def current_time_formated() -> str:
    """Return the current local time formatted with the date_format option."""
    time_pattern = get_option(arm_test_options.date_format)
    now = datetime.now()
    return now.strftime(time_pattern)
def _clean_dir(dir: Path, filter: str, num_save=10):
    """Delete the oldest timestamped files in a directory, keeping the newest.

    Args:
        dir: Directory whose direct entries are examined.
        filter: ``datetime.strptime`` pattern that an entry's name must match
            in full; the parsed time is used as the entry's age. Entries whose
            name does not match are left alone.
        num_save: Number of matching entries, counted from the newest, to
            keep.
    """
    dated_files: list[tuple[datetime, Path]] = []
    for file in dir.iterdir():
        try:
            dated_files.append((datetime.strptime(file.name, filter), file))
        except ValueError:
            # Name does not follow the pattern: not one of ours, keep it.
            continue

    # Oldest first. Sorting on the timestamp alone keeps the order independent
    # of how iterdir() happened to enumerate the directory.
    dated_files.sort(key=lambda entry: entry[0])

    num_excess = len(dated_files) - num_save
    if num_excess > 0:
        for _, stale_file in dated_files[:num_excess]:
            stale_file.unlink()