blob: 716e8daee237e13e3cd25856420b5d1d0a758618 [file]
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
#
# Main implementation of AoT flow to partition and preprocess for Arm target
# backends. Converts via TOSA as an intermediate form supported by AoT and
# JIT compiler flows.
#
import re
from typing import List
from executorch.exir.backend.compile_spec_schema import CompileSpec
from packaging.version import Version
class TosaSpecification:
"""
This class implements a representation of TOSA specification
(https://www.mlplatform.org/tosa/tosa_spec.html) with a version, a profile
(with extension) and a level (8k).
For 0.80 releases the profile is BI or MI, with u55 handled as an inofficial extension
For 1.00 releases the profile is INT or FP, and the extensions are for
INT: int16, int4, var, cf
FP: bf16, fp8e4m3, fp8e5m2, fft, var, cf
The TOSA specification is encoded in the string represenatation
TOSA-major.minor.patch+profile[+level][+extensions]
For 0.80 MI implies BI, while for 1.0 the profiles has to explicitely be specified.
Profiles are uppercase letters and extensions and level is lowercase.
"""
version: Version
def support_integer(self) -> bool:
"""
Returns true if any integer operations are supported for the specification.
"""
raise NotImplementedError
def support_float(self) -> bool:
"""
Returns true if any float operations are supported for the specification.
"""
raise NotImplementedError
def __init__(self, version: Version):
self.version = version
@staticmethod
def create_from_compilespecs(
compile_specs: List[CompileSpec],
) -> "TosaSpecification":
"""
Search the CompileSpec list for 'tosa_version' and instantiate a
class from the found value or return None on failure.
"""
for spec in compile_specs:
if spec.key == "tosa_version":
return TosaSpecification.create_from_string(spec.value.decode())
raise ValueError(
"No TOSA version key found in any of the supplied CompileSpecs"
)
@staticmethod
def create_from_string(repr: str) -> "TosaSpecification":
"""
Creates a TOSA specification class from a string representation:
TOSA-0.80.0+MI
TOSA-0.80.0+BI+8k
TOSA-0.80.0+BI+u55 # Ethos-U55 extension to handle TOSA subset
TOSA-0.90.0+MI
TOSA-1.00.0+INT+FP+int4+cf
"""
pattern = r"^(TOSA)-([\d.]+)\+(.+)$"
match = re.match(pattern, repr)
if match:
name = match.group(1)
version = Version(match.group(2))
extras = match.group(3).split("+")
if name != "TOSA":
raise ValueError(f"Malformed TOSA specification representation: {repr}")
match version:
case _ if version.major == 0 and version.minor == 80:
return Tosa_0_80(version, extras)
case _ if version.major == 1 and version.minor == 0:
return Tosa_1_00(version, extras)
case _:
raise ValueError(f"Wrong TOSA version: {version} from {repr}")
raise ValueError(f"Failed to parse TOSA specification representation: {repr}")
class Tosa_0_80(TosaSpecification):
profile: str
level_8k: bool
is_U55_subset: bool
available_profiles = ["BI", "MI"] # MT is not defined
def __init__(self, version: Version, extras: List[str]):
super().__init__(version)
assert version >= Version("0.80") and version < Version("0.90")
# Check that we only have one profile in the extensions list
if [e in Tosa_0_80.available_profiles for e in extras].count(True) != 1:
raise ValueError(
f"Bad combination of extras: {extras}, more than one of {Tosa_0_80.available_profiles} found."
)
# The list contains one profile at most, so pick it
self.profile = [e for e in extras if e in Tosa_0_80.available_profiles][0]
extras.remove(self.profile)
self.level_8k = "8k" in extras
if self.level_8k:
extras.remove("8k")
self.is_U55_subset = "u55" in extras
if self.is_U55_subset:
extras.remove("u55")
if len(extras) > 0:
raise ValueError(f"Unhandled extras found: {extras}")
def __repr__(self):
extensions = ""
if self.level_8k:
extensions += "+8K"
if self.is_U55_subset:
extensions += "+u55"
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
def __hash__(self) -> int:
return hash(str(self.version) + self.profile)
def __eq__(self, other: object) -> bool:
if isinstance(other, Tosa_0_80):
return (self.version == other.version) and (self.profile == other.profile)
return False
def support_integer(self):
return True
def support_float(self):
return self.profile == "MI"
class Tosa_1_00(TosaSpecification):
profiles: List[str]
level_8k: bool
extensions: List[str]
available_profiles = ["INT", "FP"]
valid_extensions = {
"INT": ["int16", "int4", "var", "cf"],
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
}
def __init__(self, version: Version, extras: List[str]):
super().__init__(version)
# Check that we have at least one profile in the extensions list
if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0:
raise ValueError(
f"No profile ({Tosa_1_00.available_profiles}) found in: {extras}."
)
# and not more than number of available profiles
if [e in Tosa_1_00.available_profiles for e in extras].count(True) > len(
Tosa_1_00.available_profiles
):
raise ValueError(
f"Too many profiles ({Tosa_1_00.available_profiles}) found in: {extras}."
)
# The list contains one profile at least, so pick them
self.profiles = [e for e in extras if e in Tosa_1_00.available_profiles]
for p in self.profiles:
extras.remove(p)
self.level_8k = "8k" in extras
if self.level_8k:
extras.remove("8k")
combined_extensions = []
for p in self.profiles:
combined_extensions += Tosa_1_00.valid_extensions[p]
if not all(e in combined_extensions for e in extras):
raise ValueError(
f"Bad extensions for TOSA-{version}{self._get_profiles_string()}: {extras}"
)
# all the rest of the extras are handled extensions
self.extensions = extras
def _get_profiles_string(self) -> str:
return "".join(["+" + p for p in self.profiles])
def _get_extensions_string(self) -> str:
return "".join(["+" + e for e in self.extensions])
def __repr__(self):
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
def __hash__(self) -> int:
return hash(str(self.version) + self._get_profiles_string())
def __eq__(self, other: object) -> bool:
if isinstance(other, Tosa_1_00):
return (self.version == other.version) and (
self._get_profiles_string() == other._get_profiles_string()
)
return False
def support_integer(self):
return "INT" in self.profiles
def support_float(self):
return "FP" in self.profiles