blob: 5d40d05f751f92c62cc5406ea3c0f77b85ca0084 [file] [log] [blame]
from contextlib import contextmanager
from itertools import chain
from threading import local
import sympy
from torch.fx.graph import inplace_methods, magic_methods
from .utils import sympy_str, sympy_symbol
threadlocal = local()


class Virtualized:
    """
    A global whose attribute accesses are redirected to a handler stored
    in a thread-local slot.

    Swapping the handler lets codegen substitute different op
    implementations without touching the call sites.
    """

    def __init__(self, vname, default):
        # `default` is a zero-argument callable; it is invoked to build a
        # fallback handler whenever the slot is empty.
        self._default = default
        self._key = f"__torchinductor_{vname}"

    def _get_handler(self):
        """Return the installed handler, or a fresh default if none is set."""
        try:
            handler = getattr(threadlocal, self._key)
        except AttributeError:
            handler = self._default()
        return handler

    def _set_handler(self, value):
        """
        Install `value` as the handler immediately.

        Returns a context manager; leaving it re-installs whatever
        `_get_handler()` returned just before this call.
        """
        previous = self._get_handler()
        setattr(threadlocal, self._key, value)

        @contextmanager
        def restore_on_exit():
            try:
                yield
            finally:
                self._set_handler(previous)

        return restore_on_exit()

    def __getattr__(self, name):
        # Only reached for names not found on the Virtualized object itself;
        # those are forwarded to the active handler.
        active = self._get_handler()
        return getattr(active, name)
class NullHandler:
    """Empty placeholder handler: defines no attributes or methods of its own."""
def _arg_str(a):
    """Render an argument as text: sympy expressions via sympy_str, anything else via str()."""
    return sympy_str(a) if isinstance(a, sympy.Expr) else str(a)
class MockHandler:
    """
    Handler that renders every op call as a string instead of executing it.

    Unknown ops are formatted as ``name(arg, ..., key=value)``; operators
    listed in torch.fx's magic/in-place method tables get dedicated
    formatters once `_init_cls` has run.
    """

    def __getattr__(self, name):
        # Any attribute not otherwise defined becomes a function that
        # formats a call to `name` with its arguments.
        def inner(*args, **kwargs):
            pieces = [
                *(_arg_str(a) for a in args),
                *(f"{k}={v}" for k, v in kwargs.items()),
            ]
            call_text = f"{name}({', '.join(pieces)})"
            return self.truncate_expr(call_text)

        return inner

    @staticmethod
    def truncate_expr(expr):
        """Hook applied to every rendered expression; returns it unchanged here."""
        return expr

    @classmethod
    def masked(cls, mask, body, other):
        """Render a masked op; `body` is a callable invoked to produce the inner text."""
        rendered = f"masked({mask}, {body()}, {other})"
        return cls.truncate_expr(rendered)

    @staticmethod
    def indirect_indexing(index_var):
        """Return a sympy symbol named after the string form of `index_var`."""
        symbol_name = str(index_var)
        return sympy_symbol(symbol_name)

    @classmethod
    def _init_cls(cls):
        """Attach one static formatter per entry in magic_methods and inplace_methods."""

        def make_handler(format_string):
            # A factory gives each formatter its own `format_string`
            # binding instead of sharing the loop variable.
            def formatter(*args):
                return format_string.format(*args)

            return staticmethod(formatter)

        method_templates = chain(magic_methods.items(), inplace_methods.items())
        for name, format_string in method_templates:
            setattr(cls, name, make_handler(format_string))
class WrapperHandler:
    """
    Handler that forwards to another handler.

    Attributes not found on the wrapper itself are looked up on the
    wrapped `inner` object.
    """

    def __init__(self, inner):
        # The wrapped handler that receives every forwarded lookup.
        self._inner = inner

    def __getattr__(self, item):
        """
        Forward a failed attribute lookup to the wrapped handler.

        Raises:
            AttributeError: if the wrapped handler lacks `item`, or if this
                instance has no `_inner` yet.
        """
        # __getattr__ only runs after normal lookup has failed. If `_inner`
        # itself is the missing name (instance created without __init__, as
        # copy.copy does before restoring state), reading `self._inner`
        # below would re-enter this method until RecursionError. Report the
        # attribute as missing instead so hasattr()/copy probes work.
        if item == "_inner":
            raise AttributeError(item)
        return getattr(self._inner, item)
# Attach the magic/in-place method formatters to MockHandler before it is
# registered as the default ops handler.
MockHandler._init_cls()
# Thread-local handler slots. When nothing has been installed, each one
# falls back to a new instance of the class given as its default.
ops = Virtualized("ops", MockHandler)
_graph = Virtualized("graph", NullHandler)
_kernel = Virtualized("kernel", NullHandler)
_debug = Virtualized("debug", NullHandler)
class _V:
    """
    Namespace that bundles the handler classes, the handler setters and
    read-only accessors for the currently installed handlers.
    """

    # NOTE: these aliases must stay above the properties. In this class
    # body `ops` still refers to the module-level Virtualized object here;
    # the `ops` property defined further down rebinds the name in class
    # scope.
    MockHandler = MockHandler
    WrapperHandler = WrapperHandler

    set_ops_handler = ops._set_handler
    get_ops_handler = ops._get_handler
    set_graph_handler = _graph._set_handler
    set_kernel_handler = _kernel._set_handler
    set_debug_handler = _debug._set_handler

    @property
    def ops(self) -> MockHandler:
        """The operator handler specific to the current codegen task"""
        # Inside a method body `ops` resolves to the module global, not to
        # this property.
        current = ops._get_handler()
        return current

    @property
    def graph(self):
        """The graph currently being generated"""
        current = _graph._get_handler()
        return current

    @property
    def kernel(self):
        """The kernel currently being generated"""
        current = _kernel._get_handler()
        return current

    @property
    def debug(self):
        """The currently installed debug handler"""
        current = _debug._get_handler()
        return current
# Module-level accessor instance: V.ops, V.graph, V.kernel and V.debug
# return the handlers currently installed for the calling thread.
V = _V()