import base64
import enum
import functools
import getpass
import hashlib
import logging
import multiprocessing
import os
import re
import shutil
import signal
import subprocess
import sys
import sysconfig
import tempfile
import types
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from ctypes import cdll
from threading import Thread
from time import sleep, time
from typing import Any, Dict
import torch
from torch.utils import cpp_extension
from . import config, cuda_properties, exc
LOCK_TIMEOUT = 600
# timing metrics for time spent in compilation
_cumulative_compile_time = 0
_t0 = None
def _compile_start():
global _t0
if _t0 is None:
_t0 = time()
def _compile_end():
global _cumulative_compile_time, _t0
if _t0 is not None:
t1 = time()
_cumulative_compile_time += t1 - _t0
_t0 = None
# print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)
log = logging.getLogger(__name__)
logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO)
@functools.lru_cache(None)
def cache_dir():
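    """Per-user cache root; override with the TORCHINDUCTOR_CACHE_DIR env var."""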
return os.environ.get(
"TORCHINDUCTOR_CACHE_DIR", f"/tmp/torchinductor_{getpass.getuser()}"
)
def get_lock_dir():
lock_dir = os.path.join(cache_dir(), "locks")
if not os.path.exists(lock_dir):
os.makedirs(lock_dir, exist_ok=True)
return lock_dir
def code_hash(code):
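    """Hash source code into a short, filesystem- and identifier-safe name.

    The leading "c" guarantees the name starts with a letter; lowercase
    base32 keeps it safe on case-insensitive filesystems.
    """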
return (
"c"
+ base64.b32encode(hashlib.sha256(code.encode("utf-8")).digest())[:51]
.decode("utf-8")
.lower()
)
def write(source_code, ext, extra=""):
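    """Write `source_code` to a content-addressed file under cache_dir().

    Returns (basename, path), e.g. ("cab12…", "<cache_dir>/ab/cab12….cpp").
    `extra` is folded into the hash so that, for example, the same source
    compiled with different flags gets a distinct cache entry.
    """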
basename = code_hash(source_code + extra)
subdir = os.path.join(cache_dir(), basename[1:3])
if not os.path.exists(subdir):
os.makedirs(subdir, exist_ok=True)
path = os.path.join(subdir, f"{basename}.{ext}")
if not os.path.exists(path):
        # write to a temp file, then atomically rename, so concurrent
        # writers never publish a partially written file
fd, tmp_path = tempfile.mkstemp(dir=subdir)
with os.fdopen(fd, "w") as f:
f.write(source_code)
os.rename(tmp_path, path)
return basename, path
def cpp_compiler():
if isinstance(config.cpp.cxx, (list, tuple)):
search = tuple(config.cpp.cxx)
else:
search = (config.cpp.cxx,)
return cpp_compiler_search(search)
@functools.lru_cache(1)
def cpp_compiler_search(search):
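    """Return the first working C++ compiler from `search`.

    A `None` entry means "install gcc via conda"; that install is guarded by
    a file lock so concurrent processes don't race on it.
    """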
for cxx in search:
try:
if cxx is None:
from filelock import FileLock
lock_dir = get_lock_dir()
lock = FileLock(
os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
)
with lock:
cxx = install_gcc_via_conda()
subprocess.check_output([cxx, "--version"])
return cxx
except (subprocess.SubprocessError, FileNotFoundError, ImportError):
continue
raise exc.InvalidCxxCompiler()
def install_gcc_via_conda():
"""On older systems, this is a quick way to get a modern compiler"""
prefix = os.path.join(cache_dir(), "gcc")
cxx_path = os.path.join(prefix, "bin", "g++")
if not os.path.exists(cxx_path):
log.info("Downloading GCC via conda")
        # prefer the conda that launched this environment, else any conda on PATH
        conda = os.environ.get("CONDA_EXE") or shutil.which("conda")
if conda is not None:
subprocess.check_call(
[
conda,
"create",
f"--prefix={prefix}",
"--channel=conda-forge",
"--quiet",
"-y",
"python=3.8",
"gxx",
],
stdout=subprocess.PIPE,
)
return cxx_path
def is_gcc():
    return bool(re.search(r"(gcc|g\+\+)", cpp_compiler()))
class _SupportedVecIsa(enum.Enum):
AVX512 = 1
AVX2 = 2
INVALID = -1
def __bool__(self):
return self != _SupportedVecIsa.INVALID
@staticmethod
def isa_str(supported_isa: enum.Enum):
if supported_isa == _SupportedVecIsa.AVX512:
return "avx512"
elif supported_isa == _SupportedVecIsa.AVX2:
return "avx2"
else:
return ""
@staticmethod
def vec_macro(supported_isa: enum.Enum):
if supported_isa == _SupportedVecIsa.AVX512:
return "CPU_CAPABILITY_AVX512"
elif supported_isa == _SupportedVecIsa.AVX2:
return "CPU_CAPABILITY_AVX2"
else:
return ""
# Cache the cpuinfo to avoid repeated I/O. The raw cpuinfo contains a lot of
# content that is irrelevant to the ISA check, so we cache only the key ISA
# information.
@functools.lru_cache(1)
def get_cpu_proc_info():
if sys.platform != "linux":
return []
isa_info = []
with open("/proc/cpuinfo") as _cpu_info:
_cpu_info_content = _cpu_info.read()
if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX512) in _cpu_info_content:
isa_info.append(_SupportedVecIsa.AVX512)
if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX2) in _cpu_info_content:
isa_info.append(_SupportedVecIsa.AVX2)
return isa_info
def supported_vector_isa():
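    """Return the vector ISA whose float width matches config.cpp.simdlen,
    or INVALID if vectorization is disabled or unsupported on this CPU."""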
# TODO: Add ARM Vec here.
    # Maps each supported ISA to the number of float32 elements per vector register.
vec_isa_info = {
_SupportedVecIsa.AVX512: 16,
_SupportedVecIsa.AVX2: 8,
}
if config.cpp.simdlen is None or config.cpp.simdlen <= 1:
return _SupportedVecIsa.INVALID
    supported_isas = get_cpu_proc_info()
    for isa, width in vec_isa_info.items():
        if isa in supported_isas and config.cpp.simdlen == width:
            return isa
return _SupportedVecIsa.INVALID
def cpp_compile_command(input, output, include_pytorch=False):
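    """Build the compiler command line that turns `input` into a shared
    library at `output`."""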
valid_isa = supported_vector_isa()
if include_pytorch or valid_isa:
ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")]
libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"]
macros = _SupportedVecIsa.vec_macro(valid_isa)
if macros:
macros = f"-D{macros}"
else:
        # Note: this is effectively a header-only inclusion. Using some
        # headers may result in "symbol not found" errors at load time if
        # those headers require linking against a library.
        # For those cases, include the lpaths and libs as we do for pytorch
        # above. This approach lets us only pay for what we use.
ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
lpaths = []
libs = ["gomp"]
macros = ""
ipaths = " ".join(["-I" + p for p in ipaths])
lpaths = " ".join(["-L" + p for p in lpaths])
libs = " ".join(["-l" + p for p in libs])
return re.sub(
r"[ \n]+",
" ",
f"""
{cpp_compiler()} {input} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable
{ipaths} {lpaths} {libs} {macros}
-march=native -O3 -ffast-math -fno-finite-math-only -fopenmp
-D C10_USING_CUSTOM_GENERATED_MACROS
-o{output}
""",
).strip()
class CppCodeCache:
cache = dict()
clear = staticmethod(cache.clear)
@staticmethod
def _load_library(path):
try:
return cdll.LoadLibrary(path)
except OSError as e:
if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
# hacky workaround for fbcode/buck
global _libgomp
_libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
return cdll.LoadLibrary(path)
raise
@classmethod
def load(cls, source_code):
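        """Compile `source_code` into a shared library (cached on disk by
        content hash, guarded by a file lock) and load it with ctypes."""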
key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o"))
if key not in cls.cache:
from filelock import FileLock
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_path = input_path[:-3] + "so"
if not os.path.exists(output_path):
cmd = cpp_compile_command(
input=input_path, output=output_path
).split(" ")
try:
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
                    except subprocess.CalledProcessError as e:
                        raise exc.CppCompileError(cmd, e.output) from e
cls.cache[key] = cls._load_library(output_path)
cls.cache[key].key = key
return cls.cache[key]
class PyCodeCache:
cache = dict()
clear = staticmethod(cache.clear)
@classmethod
def load(cls, source_code):
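        """Exec `source_code` as a fresh module named after its content hash."""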
key, path = write(source_code, "py")
if key not in cls.cache:
with open(path) as f:
code = compile(f.read(), path, "exec")
mod = types.ModuleType(f"{__name__}.{key}")
mod.__file__ = path
mod.key = key
exec(code, mod.__dict__, mod.__dict__)
# another thread might set this first
cls.cache.setdefault(key, mod)
return cls.cache[key]
@functools.lru_cache(None)
def patch_triton_dir():
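    """Redirect triton's kernel cache into our cache dir unless the user
    already set TRITON_CACHE_DIR."""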
os.environ["TRITON_CACHE_DIR"] = os.environ.get(
"TRITON_CACHE_DIR", os.path.join(cache_dir(), "triton")
)
class TritonCodeCache:
@staticmethod
def get_name(mod):
(name,) = [n for n in dir(mod) if n.startswith("kernel")]
return name
@classmethod
def load(cls, source_code):
patch_triton_dir()
mod = PyCodeCache.load(source_code)
return getattr(mod, cls.get_name(mod))
def _worker_compile(source_code, cc, device):
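    """Runs in a compile-worker process: pin the CUDA device, then compile
    the kernel for compute capability `cc` to warm the on-disk cache without
    loading the binary in the worker."""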
cuda_properties.set_compiler_worker_current_device(device)
kernel = TritonCodeCache.load(source_code)
kernel.precompile(warm_cache_only_with_cc=cc)
def _load_kernel(source_code):
kernel = TritonCodeCache.load(source_code)
kernel.precompile()
return kernel
class TritonFuture:
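    """A deferred triton kernel: a worker compiles it in the background, and
    `result()` loads the finished kernel into this process on first use."""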
def __init__(self, source_code, future):
self.source_code = source_code
self.future = future
def result(self):
if hasattr(self, "kernel"):
return self.kernel
# If the worker failed this will throw an exception.
self.future.result()
kernel = self.kernel = _load_kernel(self.source_code)
del self.source_code, self.future
return kernel
class AsyncCompile:
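    """Dispatches compilation to thread/process pools when
    config.compile_threads > 1; otherwise compiles eagerly in-process."""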
def __init__(self):
self._context_keepalive = None
@staticmethod
@functools.lru_cache(1)
def pool():
assert config.compile_threads > 1
return ThreadPoolExecutor(config.compile_threads)
@staticmethod
@functools.lru_cache(1)
def process_pool():
# ensure properties have been calculated before processes
# are forked
cuda_properties._properties()
assert config.compile_threads > 1
orig_ppid = os.getpid()
        # If this process dies abnormally (e.g. segfault), it will not shut
        # down the workers. Instead, the workers will be reparented to the
        # init process. Each worker therefore launches a watchdog thread that
        # detects this reparenting and kills the worker when it happens.
def init():
def run():
while True:
sleep(1)
if orig_ppid != os.getppid():
os.kill(os.getpid(), signal.SIGKILL)
global _watchdog_thread
_watchdog_thread = Thread(target=run, daemon=True)
_watchdog_thread.start()
# we rely on 'fork' because we cannot control whether users
# have an `if __name__ == '__main__'` in their main process.
fork_context = multiprocessing.get_context("fork")
pool = ProcessPoolExecutor(
config.compile_threads, mp_context=fork_context, initializer=init
)
        # When this pool is created inside a subprocess, the normal exit
        # handler doesn't run, so we need to register our own handler.
        # exitpriority has to be high, because another one of the finalizers
        # will kill the thread that sends the shutdown message to the workers...
multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
return pool
@classmethod
def warm_pool(cls):
if config.compile_threads <= 1:
return
_compile_start()
pool = cls.process_pool()
        # We have to fork processes for compiler workers, but the more memory
        # and other resources already loaded, the slower os.fork gets, quite
        # drastically. It also holds the GIL, so we can't move it to another
        # thread.
        # Examples:
        #   A simple x + x + x script: 10ms in the middle of the program, 2ms at startup
        #   tf_efficientnet_b0 benchmark: 50ms! in the middle of the program, 3ms at startup
        # So we want to start the workers early, while forking is still cheap,
        # and to give them time to get ready before we have work for them.
        # ProcessPoolExecutor also does not launch workers until it first has
        # work for them; if we waited until then, the fork would be slow and
        # we would also be waiting for the processes to initialize.
        # We force them to start here with some YOLOing of the internal methods.
if hasattr(pool, "_start_queue_management_thread"):
pool._start_queue_management_thread()
else:
for i in range(config.compile_threads):
pool._adjust_process_count()
pool._start_executor_manager_thread()
_compile_end()
@classmethod
def submit(cls, task):
if config.compile_threads <= 1:
return task()
return cls.pool().submit(task)
@classmethod
def map(cls, fn, seq):
if config.compile_threads <= 1 or len(seq) <= 1:
return list(map(fn, seq))
return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]
def triton(self, source_code):
_compile_start()
if self._context_keepalive is None:
# Workaround `CUDA: Error- context is destroyed`
self._context_keepalive = torch.tensor([1], device="cuda")
if config.compile_threads > 1:
major, minor = torch.cuda.get_device_capability()
device = torch.cuda.current_device()
cc = major * 10 + minor
future = self.process_pool().submit(
_worker_compile, source_code, cc, device
)
return TritonFuture(source_code, future)
else:
return _load_kernel(source_code)
def cpp(self, source_code):
def task():
return CppCodeCache.load(source_code).kernel
return self.submit(task)
def wait(self, scope: Dict[str, Any]):
if config.compile_threads > 1:
for key, result in list(scope.items()):
if isinstance(result, (Future, TritonFuture)):
scope[key] = result.result()
_compile_end()
AsyncCompile.warm_pool()