# blob: 5a8b88f5a2081a6544cb66c62849a94a5255009c [file] [log] [blame]
import re
from typing import Tuple, List, Iterable, Iterator, Callable, Sequence, TypeVar, Optional
from enum import Enum
import contextlib
import textwrap
# Many codegen functions share logic for generating both the definition and
# the declaration of something (for example, the function signature is the
# same in both), so each of them is written as a single function that takes a
# Target saying which kind of code we want.
class Target(Enum):
    """Which kind of generated code a shared codegen function should emit.

    This is an OPEN enum: more cases may be added in the future, so be sure to
    explicitly specify with ``Union[Literal[Target.XXX]]`` which targets are
    valid for your use.
    """

    # top level namespace (not including at)
    DEFINITION = 1
    DECLARATION = 2
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = 3
    # namespace { ... }
    ANONYMOUS_DEFINITION = 4
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = 5
    NAMESPACED_DECLARATION = 6
# Template for a regex that matches an identifier only as a whole word: after
# IDENT_REGEX.format(name) it matches "foo" in "foo, bar" but not in "foobar".
# Used to search for occurrences of a parameter in a derivative formula.
# NOTE(review): the non-word characters on either side are consumed as part of
# the match, and the name is interpolated verbatim -- presumably callers pass
# plain identifiers; re.escape() anything that may contain regex metacharacters.
IDENT_REGEX = '(^|\\W){}($|\\W)'
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a function schema string into its base name and parameter list.

    ``schema`` looks like ``name(params)`` or ``name.overload(params)``,
    optionally followed by further text (e.g. ``-> ReturnType``), which is
    ignored; the overload name is discarded.  Parameters are split on ``', '``
    only at the top nesting level, so brackets/parentheses inside a parameter
    (``int[2] stride=[1, 1]``, ``Tensor(a) self``) stay intact, and a
    parenthesized return type after the parameter list is not mistaken for
    part of it.  An empty parameter list yields ``['']``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(`` or
            ``name.overload(``, or its parameter list is never closed.
    """
    # Still not a full schema parser: quotes are not tracked, so a string
    # default containing brackets or ', ' could still confuse it.
    m = re.match(r'(\w+)(\.\w+)?\(', schema)
    if m is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    name = m.group(1)
    params: List[str] = []
    depth = 0  # nesting level of ( and [ inside the parameter list
    start = i = m.end()
    while i < len(schema):
        c = schema[i]
        if c in '([':
            depth += 1
        elif c in ')]':
            if depth > 0:
                depth -= 1
            elif c == ')':
                # Closing paren of the parameter list itself; anything after
                # it (such as a return type) is ignored.
                params.append(schema[start:i])
                return name, params
            # A stray top-level ']' is simply kept as part of the parameter.
        elif depth == 0 and schema.startswith(', ', i):
            params.append(schema[start:i])
            i += 2
            start = i
            continue
        i += 1
    # Ran off the end without closing the parameter list.
    raise RuntimeError(f'Unsupported function schema: {schema}')
# Generic type variables for the mapping helpers: T is the input element type,
# S the output element type.
# NB: keep each one as its own simple ``X = TypeVar('X')`` assignment; type
# checkers such as mypy only recognize TypeVars declared in this form.
T = TypeVar('T')
S = TypeVar('S')
# The next two functions purposely return generators, by analogy with map(),
# so that there is never any confusion about when you need to list() them.
# Map a function that may return None over a sequence, omitting the Nones.
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping None results.

    Only ``None`` is dropped; other falsy results such as ``0`` or ``''`` are
    yielded.  This is a generator function, so nothing (not even iterating
    ``xs``) happens until the first ``next()``.
    """
    yield from (value for value in map(func, xs) if value is not None)
# Map a function that returns sequences over a sequence and concatenate them.
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)``, for each ``x`` in ``xs`` in order.

    ``func`` is called on an element only once the elements from all previous
    inputs have been consumed.
    """
    yield from (elem for src in xs for elem in func(src))
# Conveniently add error context to exceptions raised inside a ``with`` block,
# so we can easily say which specific item was being processed when an error
# occurred.
@contextlib.contextmanager
def context(msg: str) -> Iterator[None]:
    """Append ``msg`` to the message of any Exception escaping the block.

    Each non-blank line of ``msg`` is prefixed with a single space.  If the
    exception has arguments, the indented ``msg`` is joined to ``args[0]`` with
    a newline (``args[0]`` is formatted via an f-string, so a non-string first
    argument becomes a string); otherwise it becomes the sole argument.  The
    same exception object is then re-raised.  BaseExceptions that are not
    Exceptions (e.g. KeyboardInterrupt) pass through untouched.
    """
    try:
        yield
    except Exception as e:
        # TODO: this does the wrong thing with KeyError, whose str() is the
        # repr() of args[0], so the appended context shows up quoted, with an
        # escaped newline, instead of on its own line.
        note = textwrap.indent(msg, ' ')
        if e.args:
            e.args = (f'{e.args[0]}\n{note}',) + e.args[1:]
        else:
            e.args = (note,)
        raise