| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import distutils.command.clean |
| import shutil |
| import glob |
| import os |
| import subprocess |
| from setuptools import setup, find_packages |
| from torch.utils.cpp_extension import ( |
| CppExtension, |
| BuildExtension, |
| ) |
| |
# Directory that contains this script; the paths below are resolved against it.
cwd = os.path.dirname(os.path.abspath(__file__))
package_name = 'functorch'

# The base version is the first line of version.txt.
version_txt = os.path.join(cwd, 'version.txt')
with open(version_txt, 'r') as f:
    version = f.readline().strip()

# Commit hash of the checkout, or 'Unknown' when the git call fails for any
# reason.
try:
    sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd)
    sha = sha.decode('ascii').strip()
except Exception:
    sha = 'Unknown'

# A non-empty BUILD_VERSION replaces the version outright; otherwise a known
# commit hash is appended as a '+<7 hex chars>' suffix.
if os.getenv('BUILD_VERSION'):
    version = os.getenv('BUILD_VERSION')
elif sha != 'Unknown':
    version = '{}+{}'.format(version, sha[:7])
| |
| |
def write_version_file():
    """Generate ``functorch/version.py`` under the module-level ``cwd``.

    The generated module defines ``__version__`` (taken from the
    module-level ``version``) and ``git_version`` (taken from the
    module-level ``sha``). An existing file is overwritten.
    """
    version_path = os.path.join(cwd, 'functorch', 'version.py')
    with open(version_path, 'w') as f:
        # Both values go through repr() so that a quote or backslash in the
        # version (e.g. one supplied via BUILD_VERSION) is escaped instead of
        # producing a syntactically invalid version.py. For ordinary version
        # strings the output is unchanged.
        f.write("__version__ = {!r}\n".format(version))
        f.write("git_version = {!r}\n".format(sha))
| |
| |
# Disabled: optional exact pin of torch taken from the environment.
# pytorch_dep = 'torch'
# if os.getenv('PYTORCH_VERSION'):
#     pytorch_dep += "==" + os.getenv('PYTORCH_VERSION')

# Install-time dependencies. The torch specifier represents a nightly version
# of PyTorch, which can be installed as a binary or from source.
requirements = ["torch>=1.10.0.dev"]
| |
| |
class clean(distutils.command.clean.clean):
    """``clean`` command that also deletes whatever ``.gitignore`` lists."""

    def run(self):
        """Remove paths matching ``.gitignore`` patterns, then run the stock clean.

        Patterns are read from ``.gitignore`` in the current working
        directory and expanded with ``glob`` relative to that directory.
        Matching files are removed; matching directories are removed
        recursively on a best-effort basis.
        """
        try:
            with open(".gitignore", "r") as f:
                ignores = f.read()
        except FileNotFoundError:
            # No .gitignore (e.g. an unpacked sdist): nothing extra to delete,
            # but the standard clean below must still run.
            ignores = ""

        for line in ignores.splitlines():
            wildcard = line.strip()
            # Blank lines, comments and negations ("!pattern" means "do not
            # ignore") are not deletion patterns.
            if not wildcard or wildcard.startswith(("#", "!")):
                continue
            # In gitignore syntax a leading "/" anchors the pattern to the
            # repository root. Passed to glob unchanged it would be read as a
            # filesystem-absolute path, so "/build" could delete the real
            # "/build". Strip it to keep the match inside the working tree.
            wildcard = wildcard.lstrip("/")
            if not wildcard:
                continue
            for filename in glob.glob(wildcard):
                try:
                    os.remove(filename)
                except OSError:
                    # Not a plain file (typically a directory).
                    shutil.rmtree(filename, ignore_errors=True)

        # The base implementation is called explicitly; the original code did
        # so because this was an old-style class on Python 2.7.
        distutils.command.clean.clean.run(self)
| |
| |
def get_extensions():
    """Describe the ``functorch._C`` C++ extension for ``setup()``.

    Compiler and linker flags depend on the ``DEBUG`` environment variable:
    ``DEBUG=1`` selects an unoptimized build (``-O0``, ``-fno-inline``); any
    other value, or no value, selects ``-O3``.

    Returns:
        A one-element list holding the ``CppExtension`` built from every
        ``*.cpp`` file located directly inside ``functorch/csrc`` next to
        this script.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "functorch", "csrc")

    # Flags shared by the release and debug configurations.
    common_cxx_flags = [
        "-g",
        "-std=c++14",
        "-fdiagnostics-color=always",
        # PyTorch SmallVector has this
        "-Wno-array-bounds",
    ]

    if os.getenv('DEBUG', '0') == '1':
        print("Compiling in debug mode")
        cxx_flags = ["-O0", "-fno-inline"] + common_cxx_flags
        extra_link_args = ["-O0", "-g"]
    else:
        cxx_flags = ["-O3"] + common_cxx_flags
        extra_link_args = []

    # glob already yields paths prefixed with extensions_dir, so no further
    # joining is needed. Sorting gives a deterministic compile/link order;
    # iterating a set of strings does not, because string hashing is
    # randomized per process.
    sources = sorted(glob.glob(os.path.join(extensions_dir, "*.cpp")))

    return [
        CppExtension(
            "functorch._C",
            sources,
            include_dirs=[this_dir],
            define_macros=[],
            extra_compile_args={"cxx": cxx_flags},
            extra_link_args=extra_link_args,
        )
    ]
| |
| |
if __name__ == '__main__':
    # Announce the build, then regenerate functorch/version.py so the
    # installed package reports the version computed above.
    print(f"Building wheel {package_name}-{version}")
    write_version_file()

    # Descriptive metadata.
    metadata = dict(
        name=package_name,
        version=version,
        author='PyTorch Core Team',
        url="https://github.com/facebookresearch/functorch",
        description='prototype of composable function transforms for PyTorch',
        license='BSD',
    )

    # What gets packaged and how it is built.
    package_info = dict(
        packages=find_packages(),
        install_requires=requirements,
        ext_modules=get_extensions(),
        cmdclass={
            "build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
            'clean': clean,
        },
    )

    setup(**metadata, **package_info)