blob: 94dd63ea572538c383550938fc671f1c48ec24b0 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright © 2020 Arm Ltd. All rights reserved.
# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Python bindings for Arm NN
PyArmNN is a python extension for Arm NN SDK providing an interface similar to Arm NN C++ API.
"""
__version__ = None
__arm_ml_version__ = None
import logging
import os
import sys
import subprocess
from functools import lru_cache
from pathlib import Path
from itertools import chain
from setuptools import setup
from distutils.core import Extension
from setuptools.command.build_py import build_py
from setuptools.command.build_ext import build_ext
logger = logging.Logger(__name__)
DOCLINES = __doc__.split("\n")
LIB_ENV_NAME = "ARMNN_LIB"
INCLUDE_ENV_NAME = "ARMNN_INCLUDE"
def check_armnn_version(*args):
pass
__current_dir = os.path.dirname(os.path.realpath(__file__))
exec(open(os.path.join(__current_dir, 'src', 'pyarmnn', '_version.py'), encoding="utf-8").read())
class ExtensionPriorityBuilder(build_py):
"""Runs extension builder before other stages. Otherwise generated files are not included to the distribution.
"""
def run(self):
self.run_command('build_ext')
return super().run()
class ArmnnVersionCheckerExtBuilder(build_ext):
"""Builds an extension (i.e. wrapper). Additionally checks for version.
"""
def __init__(self, dist):
super().__init__(dist)
self.failed_ext = []
def build_extension(self, ext):
if ext.optional:
try:
super().build_extension(ext)
except Exception as err:
self.failed_ext.append(ext)
logger.warning('Failed to build extension %s. \n %s', ext.name, str(err))
else:
super().build_extension(ext)
if ext.name == 'pyarmnn._generated._pyarmnn_version':
sys.path.append(os.path.abspath(os.path.join(self.build_lib, str(Path(ext._file_name).parent))))
from _pyarmnn_version import GetVersion
check_armnn_version(GetVersion(), __arm_ml_version__)
def copy_extensions_to_source(self):
for ext in self.failed_ext:
self.extensions.remove(ext)
super().copy_extensions_to_source()
def linux_gcc_name():
"""Returns the name of the `gcc` compiler. Might happen that we are cross-compiling and the
compiler has a longer name.
Args:
None
Returns:
str: Name of the `gcc` compiler or None
"""
cc_env = os.getenv('CC')
if cc_env is not None:
if subprocess.Popen([cc_env, "--version"], stdout=subprocess.DEVNULL):
return cc_env
return "gcc" if subprocess.Popen(["gcc", "--version"], stdout=subprocess.DEVNULL) else None
def linux_gcc_lib_search(gcc_compiler_name: str = linux_gcc_name()):
"""Calls the `gcc` to get linker default system paths.
Args:
gcc_compiler_name(str): Name of the GCC compiler
Returns:
list: A list of paths.
Raises:
RuntimeError: If unable to find GCC.
"""
if gcc_compiler_name is None:
raise RuntimeError("Unable to find gcc compiler")
cmd1 = subprocess.Popen([gcc_compiler_name, "--print-search-dirs"], stdout=subprocess.PIPE)
cmd2 = subprocess.Popen(["grep", "libraries"], stdin=cmd1.stdout,
stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
cmd1.stdout.close()
out, _ = cmd2.communicate()
out = out.decode("utf-8").split('=')
return tuple(out[1].split(':')) if len(out) > 0 else None
def find_includes(armnn_include_env: str = INCLUDE_ENV_NAME):
"""Searches for ArmNN includes.
Args:
armnn_include_env(str): Environmental variable to use as path.
Returns:
list: A list of paths to include.
"""
armnn_include_path = os.getenv(armnn_include_env)
if armnn_include_path is not None and os.path.exists(armnn_include_path):
armnn_include_path = [armnn_include_path]
else:
armnn_include_path = ['/usr/local/include', '/usr/include']
return armnn_include_path
@lru_cache(maxsize=1)
def find_armnn(lib_name: str,
optional: bool = False,
armnn_libs_env: str = LIB_ENV_NAME,
default_lib_search: tuple = linux_gcc_lib_search()):
"""Searches for ArmNN installation on the local machine.
Args:
lib_name(str): Lib name to find.
optional(bool): Do not fail if optional. Default is False - fail if library was not found.
armnn_libs_env(str): Custom environment variable pointing to ArmNN libraries location, default is 'ARMNN_LIBS'
default_lib_search(tuple): list of paths to search for ArmNN if not found within path provided by 'ARMNN_LIBS'
env variable
Returns:
tuple: Contains name of the armnn libs, paths to the libs.
Raises:
RuntimeError: If armnn libs are not found.
"""
armnn_lib_path = os.getenv(armnn_libs_env)
lib_search = [armnn_lib_path] if armnn_lib_path is not None else default_lib_search
armnn_libs = dict(map(lambda path: (':{}'.format(path.name), path),
chain.from_iterable(map(lambda lib_path: Path(lib_path).glob(lib_name),
lib_search))))
if not optional and len(armnn_libs) == 0:
raise RuntimeError("""ArmNN library {} was not found in {}. Please install ArmNN to one of the standard
locations or set correct ARMNN_INCLUDE and ARMNN_LIB env variables.""".format(lib_name,
lib_search))
if optional and len(armnn_libs) == 0:
logger.warning("""Optional parser library %s was not found in %s and will not be installed.""", lib_name,
lib_search)
# gives back tuple of names of the libs, set of unique libs locations and includes.
return list(armnn_libs.keys()), list(set(
map(lambda path: str(path.absolute().parent), armnn_libs.values())))
class LazyArmnnFinderExtension(Extension):
"""Derived from `Extension` this class adds ArmNN libraries search on the user's machine.
SWIG options and compilation flags are updated with relevant ArmNN libraries files locations (-L) and headers (-I).
Search for ArmNN is executed only when attributes include_dirs, library_dirs, runtime_library_dirs, libraries or
swig_opts are queried.
"""
def __init__(self, name, sources, armnn_libs, include_dirs=None, define_macros=None, undef_macros=None,
library_dirs=None,
libraries=None, runtime_library_dirs=None, extra_objects=None, extra_compile_args=None,
extra_link_args=None, export_symbols=None, language=None, optional=None, **kw):
self._include_dirs = None
self._library_dirs = None
self._runtime_library_dirs = None
self._armnn_libs = armnn_libs
self._optional = optional[0]
# self.__swig_opts = None
super().__init__(name, sources, include_dirs, define_macros, undef_macros, library_dirs, libraries,
runtime_library_dirs, extra_objects, extra_compile_args, extra_link_args, export_symbols,
language, optional, **kw)
@property
def include_dirs(self):
return self._include_dirs + find_includes()
@include_dirs.setter
def include_dirs(self, include_dirs):
self._include_dirs = include_dirs
@property
def library_dirs(self):
library_dirs = self._library_dirs
for lib in self._armnn_libs:
_, lib_path = find_armnn(lib, self._optional)
library_dirs = library_dirs + lib_path
return library_dirs
@library_dirs.setter
def library_dirs(self, library_dirs):
self._library_dirs = library_dirs
@property
def runtime_library_dirs(self):
library_dirs = self._runtime_library_dirs
for lib in self._armnn_libs:
_, lib_path = find_armnn(lib, self._optional)
library_dirs = library_dirs + lib_path
return library_dirs
@runtime_library_dirs.setter
def runtime_library_dirs(self, runtime_library_dirs):
self._runtime_library_dirs = runtime_library_dirs
@property
def libraries(self):
libraries = self._libraries
for lib in self._armnn_libs:
lib_names, _ = find_armnn(lib, self._optional)
libraries = libraries + lib_names
return libraries
@libraries.setter
def libraries(self, libraries):
self._libraries = libraries
def __eq__(self, other):
return self.__class__ == other.__class__ and self.name == other.name
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return self.name.__hash__()
if __name__ == '__main__':
# mandatory extensions
pyarmnn_module = LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn',
sources=['src/pyarmnn/_generated/armnn_wrap.cpp'],
extra_compile_args=['-std=c++14'],
language='c++',
armnn_libs=['libarmnn.so'],
optional=[False]
)
pyarmnn_v_module = LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn_version',
sources=['src/pyarmnn/_generated/armnn_version_wrap.cpp'],
extra_compile_args=['-std=c++14'],
language='c++',
armnn_libs=['libarmnn.so'],
optional=[False]
)
extensions_to_build = [pyarmnn_v_module, pyarmnn_module]
# optional extensions
def add_parsers_ext(name: str, ext_list: list):
pyarmnn_optional_module = LazyArmnnFinderExtension('pyarmnn._generated._pyarmnn_{}'.format(name.lower()),
sources=['src/pyarmnn/_generated/armnn_{}_wrap.cpp'.format(
name.lower())],
extra_compile_args=['-std=c++14'],
language='c++',
armnn_libs=['libarmnn.so', 'libarmnn{}.so'.format(name)],
optional=[True]
)
ext_list.append(pyarmnn_optional_module)
add_parsers_ext('CaffeParser', extensions_to_build)
add_parsers_ext('OnnxParser', extensions_to_build)
add_parsers_ext('TfParser', extensions_to_build)
add_parsers_ext('TfLiteParser', extensions_to_build)
add_parsers_ext('Deserializer', extensions_to_build)
setup(
name='pyarmnn',
version=__version__,
author='Arm Ltd, NXP Semiconductors',
author_email='support@linaro.org',
description=DOCLINES[0],
long_description="\n".join(DOCLINES[2:]),
url='https://mlplatform.org/',
license='MIT',
keywords='armnn neural network machine learning',
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3 :: Only',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
package_dir={'': 'src'},
packages=[
'pyarmnn',
'pyarmnn._generated',
'pyarmnn._quantization',
'pyarmnn._tensor',
'pyarmnn._utilities'
],
data_files=[('', ['LICENSE'])],
python_requires='>=3.5',
install_requires=['numpy'],
cmdclass={
'build_py': ExtensionPriorityBuilder,
'build_ext': ArmnnVersionCheckerExtBuilder
},
ext_modules=extensions_to_build
)