# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions"""
from array import array
import abc
import sys
from collections import namedtuple, OrderedDict
import collections.abc
from bisect import bisect_right
from concurrent.futures import Future, ThreadPoolExecutor, CancelledError
from contextlib import contextmanager
from itertools import tee, filterfalse, takewhile, chain
from enum import Enum
import warnings
import weakref
import inspect
import logging
import operator
import os
from os.path import (
dirname,
isabs,
join as pjoin,
normpath,
)
import threading
import time
import re
import numpy as np
# Just import pint eagerly, since we rely on it during
# bootstrap anyway.
# TODO(dancol): figure out whether we can get pint's load time down
import pint
from cytoolz import first, second, merge
from cytoolz.dicttoolz import valmap as dtz_valmap
from modernmp.util import (
ChainableFuture,
STARTUP_CWD,
assert_seq_type,
cached_property,
once,
)
from modernmp.shm import SharedObject
warnings.filterwarnings("ignore", category=pint.UnitStrippedWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
log = logging.getLogger(__name__)
@once()
def _get_timing_tls():
return threading.local()
@once()
def get_cpu_count():
"""Get the number of cores on the system"""
return os.cpu_count()
@once()
def get_process_thread_pool():
"""Get the global process thread pool"""
return ThreadPoolExecutor(max_workers=get_cpu_count())
def _dummy_for_freeze():
# pylint: disable=unused-variable,unused-import
import cytoolz._signatures
def module_loader(module_name):
"""Caching and reporting for module load function"""
def _decorator(real_importer):
loader_holder = [None]
# pylint: disable=dangerous-default-value
def _loader(real_importer=real_importer,
loader_holder=loader_holder,
module_name=module_name): # pylint: disable=unused-argument
#TODO(dancol): module loading benchmarks
#with Timed("loading {}".format(module_name)):
# module = real_importer()
module = real_importer()
def _fast_loader(module=module):
return module
loader_fn = loader_holder[0]
loader_fn.__code__ = _fast_loader.__code__
loader_fn.__defaults__ = _fast_loader.__defaults__
return module
loader_holder[0] = _loader
return _loader
return _decorator
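# Illustrative sketch of the self-patching behavior above (hypothetical
# call site): the first call to a decorated loader pays the import cost,
# then rewrites its own __code__ and __defaults__, so later calls reduce
# to a bare "return module":
#
#   pd = load_pandas()  # imports pandas, then patches the loader
#   pd = load_pandas()  # near-free: returns the cached module object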
@module_loader("pandas")
def load_pandas(_lock=threading.Lock(), _state=[False]): # pylint: disable=dangerous-default-value
"""Load pandas"""
# Keep open-coded so freeze notices the module
import pandas
return pandas
@module_loader("networkx")
def load_networkx():
"""Load networkx"""
# Keep open-coded so freeze notices the module
import networkx
return networkx
@once()
def ureg():
"""Return the DCTV Pint unit registry"""
return pint.UnitRegistry(pjoin(dirname(__file__), "units.txt"))
def make_pd_dataframe(column_dict):
"""Make Pandas dataframe from column contents without copying"""
pd = load_pandas()
column_df = [pd.DataFrame(v, columns=[k], copy=False)
for k, v in column_dict.items()]
return pd.concat(column_df, axis=1, copy=False)
def do_mt(func, *args, **kwargs):
"""Call FUNC in a worker thread in the process thread pool"""
return get_process_thread_pool().submit(func, *args, **kwargs)
def map_mt(func, arg_list):
"""Apply FUNC to each item in ARG_LIST in process thread pool"""
return map(Future.result,
(get_process_thread_pool().submit(func, arg)
for arg in arg_list))
def lmap(*args, **kwargs):
"""Like map, but build a list synchronously"""
return list(map(*args, **kwargs))
def get_object_cache(obj,
_ga=getattr,
_sa=object.__setattr__):
"""Retrieve the cache directory for OBJ.
Creates the dictionary if it does not already exist.
"""
cache = _ga(obj, "__cache", None)
if cache is None:
cache = {}
_sa(obj, "__cache", cache)
return cache
TIMING_FIELDS = [("wall", time.monotonic)]
if hasattr(time, "CLOCK_THREAD_CPUTIME_ID"):
TIMING_FIELDS.append(
("threadcpu",
lambda: time.clock_gettime(time.CLOCK_THREAD_CPUTIME_ID)))
class TimeSnapshot(
namedtuple("TimeSnapshot", map(lambda tf: tf[0], TIMING_FIELDS))): # pylint: disable=undefined-variable
"""Snapshot of process and thread performance counts information"""
__slots__ = []
@staticmethod
def current():
"""Sample performance counters"""
return TimeSnapshot(*(x[1]() for x in TIMING_FIELDS))
def __sub__(self, other):
return TimeSnapshot(*map(operator.sub, self, other))
class NoStripString(str):
"""Hack to prevent tabulate from stripping leading whitespace"""
def strip(self, chars=None):
return NoStripString(self.rstrip(chars))
def lstrip(self, chars=None):
return self
def __str__(self):
return self
class DummyTimed:
"""Swap in for Timed to make a noop"""
def __init__(self, *args, **kwargs):
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class Timed:
# pylint: disable=attribute-defined-outside-init
"""Context manager that records timing for code sections"""
INDENT = 2
def __init__(self, name, independent=False):
self.name = name
self.children = []
self.independent = independent
def __enter__(self):
if not hasattr(_get_timing_tls(), "current"):
_get_timing_tls().current = None
self.parent = _get_timing_tls().current
_get_timing_tls().current = self
self.start = TimeSnapshot.current()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop = TimeSnapshot.current()
assert _get_timing_tls().current is self
_get_timing_tls().current = self.parent
if self.parent is not None:
self.parent.children.append(self)
self.parent = None # Avoid cycle
if not self.independent:
return
if exc_type is not None:
return
xlog = inspect.currentframe().f_back.f_globals.get("log", log)
from io import StringIO
buf = StringIO()
self.dump(buf, self.INDENT)
sys.stdout.flush()
xlog.debug("TIMING INFORMATION %s\n%s",
self.name,
buf.getvalue())
def __generate_rows(self, indent):
elapsed = self.stop - self.start
yield ([NoStripString("%*s%s" % (indent, "", self.name))] +
[x * 1000.0 for x in elapsed])
for child in self.children:
# pylint: disable=protected-access
yield from child.__generate_rows(indent + self.INDENT)
def dump(self, out, indent=0):
"""Print accumulated timing information"""
rows = list(self.__generate_rows(indent))
from tabulate import tabulate
out.write(
tabulate(
rows,
tablefmt="plain",
floatfmt="4.3f",
headers=[""] + list(map(str.upper, TimeSnapshot._fields))
))
out.write("\n")
def collect_profile(fn, profile_out_file_name):
"""Call a function with a profile; dump in kcachegrind format"""
import cProfile
profiler = cProfile.Profile(time.process_time)
def _wrapper():
try:
fn()
except Exception:  # pylint: disable=broad-except
log.exception("fn failed")
profiler.runcall(_wrapper)
from pyprof2calltree import convert
convert(profiler.getstats(), profile_out_file_name)
log.info("wrote profiler output to %r", profile_out_file_name)
@contextmanager
def envvar_set(variable, value):
"""Sets an environment variable while code runs"""
old_value = os.environ.get(variable)
os.environ[variable] = value
try:
yield
finally:
if old_value is None:
del os.environ[variable]
else:
os.environ[variable] = old_value
class FrozenDict(collections.abc.Hashable,
collections.abc.Mapping):
"""Immutable, hashable dictionary"""
def __init__(self, *args, **kwargs):
self._data = dict(*args, **kwargs)
self._hash = None
def __getitem__(self, key):
return self._data.__getitem__(key)
def __iter__(self):
return iter(self._data)
def __len__(self):
return len(self._data)
def __hash__(self):
my_hash = self._hash
if my_hash is None:
self._hash = my_hash = hash(frozenset(self._data.items()))
return my_hash
def __eq__(self, other):
if type(self) is not type(other):
return NotImplemented
# pylint: disable=protected-access
return self._data == other._data
def __repr__(self):
return "FrozenDict({})".format(
", ".join("{!r}: {!r}".format(k, v)
for k, v in sorted(self._data.items())))
def copy(self):
"""Clone dict: in immutable case, no need to copy"""
return self
def __reduce__(self):
return FrozenDict, (self._data,)
__slots__ = "_data", "_hash"
# Autonumber
class AutoNumberSafeMeta(type(Enum)):
"""Metaclass for AutoNumberSafe"""
def __new__(mcs, name, bases, dict_): # pylint: disable=bad-classmethod-argument
cls = super().__new__(mcs, name, bases, dict_)
cls.owns = lambda value, _cls=cls: isinstance(value, _cls)
return cls
class AutoNumberSafe(Enum, metaclass=AutoNumberSafeMeta): # pylint: disable=invalid-metaclass
"""Enum wrapper"""
def __new__(cls):
value = len(cls.__members__) + 1
obj = object.__new__(cls)
obj._value_ = value # pylint: disable=protected-access
return obj
class AutoNumberFastMeta(type):
"""Metaclass for AutoNumberFast"""
def __new__(mcs, cls_name, bases, dict_):
# TODO(dancol): pseudorandomly vary start so collisions are less
# likely even though we're using ints
start = dict_.pop("__start__", 1)
idx = start
members = OrderedDict()
for name, value in tuple(dict_.items()):
if value == ():
dict_[name] = idx
members[name] = idx
idx += 1
valrange = range(start, idx)
@staticmethod
def _owns(value, _range=valrange):
return isinstance(value, int) and value in _range
dict_["__members__"] = members
dict_["__range__"] = valrange
new_cls = super().__new__(mcs, cls_name, bases, dict_)
return new_cls
def owns(cls, value):
"""Return whether this enum-like class owns the given value"""
return isinstance(value, int) and value in cls.__range__
def __iter__(cls):
return iter(cls.__range__)
def label_of(cls, value):
""""String name of enumeration value"""
for xkey, xvalue in cls.__members__.items():
if xvalue == value:
return xkey
raise KeyError(value)
def __call__(cls, value):
# pylint: disable=no-value-for-parameter
if isinstance(value, str):
return getattr(cls, value.upper())
assert cls.owns(value)
return value
def __instancecheck__(cls, value):
return cls.owns(value) # pylint: disable=no-value-for-parameter
class AutoNumberFast(metaclass=AutoNumberFastMeta):
"""Lightweight integer mapping"""
value_type = int
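# Illustrative sketch (hypothetical enum): members declared as "()" are
# rewritten to consecutive ints starting at __start__ (default 1), so
# enumeration values are plain integers:
#
#   class Color(AutoNumberFast):
#     RED = ()
#     GREEN = ()
#
#   Color.RED          # => 1
#   Color.owns(2)      # => True
#   Color.label_of(2)  # => "GREEN"
#   Color("green")     # => 2 (string lookup is case-insensitive)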
if __debug__:
#AutoNumber = AutoNumberSafe
AutoNumber = AutoNumberFast
else:
AutoNumber = AutoNumberFast
# Paranoid checking for object inheritance. Objects with metaclass
# ExplicitInheritanceMeta fail construction when they don't satisfy
# Java-style constraints on methods. This little facility helps
# address the "fragile base class" problem.
class InheritMode(AutoNumber):
"""Ways in which a field or class can be inherited"""
ABSTRACT = ()
FINAL = ()
OVERRIDE = ()
OVERRIDE_FINAL = ()
UNSPECIFIED = ()
_INHERIT_MODE = "_inherit_mode"
_inherit_by_type = weakref.WeakKeyDictionary()
def _make_inheritance_decorator(mode):
def _decorator(obj, mode=mode):
obj_inherit_info = (not isinstance(obj, (classmethod, staticmethod))
and _inherit_by_type.get(obj))
if obj_inherit_info:
obj_class_mode, obj_class_fields = obj_inherit_info
if (mode is obj_class_mode or
obj_class_mode == InheritMode.UNSPECIFIED):
obj_class_mode = mode
else:
raise ValueError(
"conflicting class types for {!r}: {!r} vs {!r}"
.format(obj, mode, obj_class_mode))
_inherit_by_type[obj] = (obj_class_mode, obj_class_fields)
return obj
preset_mode = getattr(obj, _INHERIT_MODE, None)
if preset_mode:
if ((mode == InheritMode.FINAL
and preset_mode == InheritMode.OVERRIDE)
or
(mode == InheritMode.OVERRIDE
and preset_mode == InheritMode.FINAL)):
mode = InheritMode.OVERRIDE_FINAL
else:
raise ValueError("obj already marked")
setattr(obj, _INHERIT_MODE, mode)
return obj
setattr(_decorator, "_for_inherit_mode", mode)
return _decorator
abstract = _make_inheritance_decorator(InheritMode.ABSTRACT)
final = _make_inheritance_decorator(InheritMode.FINAL)
override = _make_inheritance_decorator(InheritMode.OVERRIDE)
override_final = _make_inheritance_decorator(InheritMode.OVERRIDE_FINAL)
class InheritanceConstraintViolationError(RuntimeError):
"""Exception raised when a class violates inheritance constraints"""
def _get_cls_inherit_info(base):
info = _inherit_by_type.get(base)
if not info:
# Outside ExplicitInheritance system, so synthesize
member_info = {
member_name: (InheritMode.ABSTRACT
if getattr(member, "__isabstractmethod__", False)
else InheritMode.UNSPECIFIED)
for member_name, member in base.__dict__.items()}
_inherit_by_type[base] = info = InheritMode.UNSPECIFIED, member_info
return info
def _check_inheritance(cls,
dict_,
marked_abstract,
marked_inherit,
virtual_mro):
violations = []
def _scan_base(base):
base_mode, base_fields = _get_cls_inherit_info(base)
if base_mode == InheritMode.FINAL:
violations.append(
"cannot derive from final class {!r}".format(base.__name__))
return base_fields
def _scan_bases(cls, virtual_mro):
all_base_fields = {}
virtual_base_fields = {}
seen_bases = set()
for base in chain(reversed(cls.mro()[1:]), reversed(virtual_mro)):
if base in seen_bases:
continue
seen_bases.add(base)
base_fields = _scan_base(base)
all_base_fields.update(base_fields)
if base in virtual_mro:
virtual_base_fields.update(base_fields)
return all_base_fields, virtual_base_fields
all_base_fields, virtual_base_fields = _scan_bases(cls, virtual_mro)
explicit_inheritance = dtz_valmap(
lambda v: getattr(v, "_for_inherit_mode"),
marked_inherit)
def _scan_field(name, inherit_mode, base_inherit_mode):
if base_inherit_mode in (InheritMode.FINAL,
InheritMode.OVERRIDE_FINAL):
violations.append("final method {!r} overridden".format(name))
if (base_inherit_mode and
inherit_mode not in (InheritMode.ABSTRACT,
InheritMode.OVERRIDE,
InheritMode.OVERRIDE_FINAL)):
if name not in (
"__abstract__",
"__classcell__",
"__doc__",
"__inherit__",
"__module__",
"__qualname__",
"__slots__",
):
violations.append(
"method {!r} overridden without @override".format(name))
if (inherit_mode in (InheritMode.OVERRIDE, InheritMode.OVERRIDE_FINAL)
and base_inherit_mode is None):
violations.append(
"method {!r} is marked @override but overrides nothing".format(name))
if inherit_mode == InheritMode.ABSTRACT:
if base_inherit_mode == InheritMode.ABSTRACT:
violations.append(
"abstract method {!r} overrides abstract method".format(name))
def _scan_fields(fields):
field_info = {}
for name, value in fields:
if isinstance(value, property):
field_inherit_mode = InheritMode.FINAL
elif hasattr(value, _INHERIT_MODE):
field_inherit_mode = getattr(value, _INHERIT_MODE)
delattr(value, _INHERIT_MODE)
if name in explicit_inheritance:
violations.append("method {!r} already described in __inherit__"
.format(name))
else:
field_inherit_mode = \
explicit_inheritance.get(name, InheritMode.UNSPECIFIED)
field_info[name] = field_inherit_mode
base_field_inherit_mode = all_base_fields.get(name, None)
_scan_field(name, field_inherit_mode, base_field_inherit_mode)
return field_info
field_info = _scan_fields(dict_.items())
unused_explicit_inheritance = set(explicit_inheritance) - set(field_info)
if unused_explicit_inheritance:
violations.append(
"fields {} in __inherit__ but not in class".format(
",".join(sorted(unused_explicit_inheritance))))
# Need __abstract__ because @abstract class decorator won't run
# until after we're done chewing through the class dictionary here.
has_any_abstract = InheritMode.ABSTRACT in field_info.values()
is_cls_abstract = marked_abstract
if not (is_cls_abstract or has_any_abstract):
for name, mode in all_base_fields.items():
if mode == InheritMode.ABSTRACT and name not in field_info:
violations.append(
"abstract method {} not overridden".format(name))
if violations:
raise InheritanceConstraintViolationError(
"inheritance constraints violated for {}: {}".format(
cls,
"; ".join(sorted(violations))))
cls_mode = InheritMode.UNSPECIFIED
if is_cls_abstract:
cls_mode = InheritMode.ABSTRACT
if virtual_base_fields:
field_info = merge(virtual_base_fields, field_info)
_inherit_by_type[cls] = (cls_mode, field_info)
def _cls_new_members(cls):
"""Return members introduced in a class"""
members = set(cls.__dict__)
for base in cls.mro()[1:]:
members -= set(base.__dict__)
return members
def _do_brain_suck_abc(brain_suck_abc, dict_, bases):
"""Fake inheritance by copying members into a new class"""
real_bases = frozenset(chain.from_iterable(base.mro() for base in bases))
new_members = set()
for base in reversed(brain_suck_abc.mro()):
if base not in real_bases:
new_members.update(_cls_new_members(base))
for name in sorted(new_members):
if (name not in (
"__abstractmethods__",
"__module__",
"__slots__",
"__subclasshook__",
) and not name.startswith("_abc_")
and name not in dict_):
new_member = getattr(brain_suck_abc, name)
if not getattr(new_member, "__isabstractmethod__", False):
dict_[name] = new_member
class ExplicitInheritanceMeta(type):
"""Metaclass that enforces inheritance invariants"""
def __new__(mcs, cls_name, bases, dict_,
tweaked_dict=None,
brain_suck_abc=None):
# Need to remove these from dict_ *before* creating the type
marked_abstract = dict_.pop("__abstract__", False)
marked_inherit = dict_.pop("__inherit__", {})
virtual_mro = ()
if brain_suck_abc:
# Special case that allows for "inheritance" from container ABC
# helper classes without involving ABCMeta in the picture and
# screwing up our own metaclass hierarchy. This approach works
# only because all of these mixins are stateless.
assert issubclass(type(brain_suck_abc), abc.ABCMeta)
virtual_mro = brain_suck_abc.mro()
tweaked_dict = tweaked_dict or dict_.copy()
_do_brain_suck_abc(brain_suck_abc, dict_, bases)
new_cls = super().__new__(mcs, cls_name, bases, dict_)
_check_inheritance(new_cls,
tweaked_dict or dict_,
marked_abstract,
marked_inherit,
virtual_mro)
if brain_suck_abc:
brain_suck_abc.register(new_cls)
return new_cls
def __init__(cls, cls_name, bases, dict_, **kwargs):
kwargs.pop("brain_suck_abc", None)
super().__init__(cls_name, bases, dict_, **kwargs)
class ExplicitInheritance(metaclass=ExplicitInheritanceMeta):
"""Base class for objects with explicit inheritance checking"""
# Immutable
NO_DEFAULT = object()
class ImmutableProperty(ExplicitInheritance):
"""Description of an immutable property
Classes deriving from Immutable auto-generate constructors that set
the fields described by these properties, including those found in
subclasses.
"""
@override
def __init__(self, type_=None, *,
assert_checker=None,
type_if_debug=None,
default=NO_DEFAULT,
kwonly=False,
converter=None,
name=None,
nullable=False):
"""Build an immutable property.
TYPE_ is either a type or a tuple of types (just like for
isinstance). It's idiomatically given to this constructor by
position; the rest of the arguments are keyword-only. In debug
builds, we assert that the class property matches TYPE_ (but after
CONVERTER, if present, runs).
ASSERT_CHECKER is a callable that runs only in debug mode and that
checks values for validity. It should return true if everything
is okay. (The generated code calls ASSERT_CHECKER inside an
assert.) ASSERT_CHECKER runs _after_ CONVERTER if both
are present.
TYPE_IF_DEBUG is a callable that yields the actual type of the
iattr. (TODO(dancol): get rid of this. We used it for ureg()
types, but we can load pint eagerly now.)
DEFAULT is a default value for this property, or NO_DEFAULT if the
property should be required. Default property values obey usual
Python function argument rules: non-defaults must precede
all defaults.
If KWONLY is true, this property can be given to the class
constructor only by keyword, not by position.
If CONVERTER is given, it's a callable that takes as input the raw
value given to the class constructor and returns the actual value
the class field should have. It runs before ASSERT_CHECKER.
NAME, if given, overrides the constructor argument name, which
normally matches the property name.
NULLABLE is a convenience flag that automatically adds
NoneType to the list of allowed types and that avoids calling
CONVERTER and ASSERT_CHECKER if the value given to the field is
None. (You can do the same thing by hand, but it's annoying.)
N.B. NULLABLE and DEFAULT are orthogonal: if a value is nullable
but has no default, callers must specify a value for the field,
but that value may be None. If a value has a default but is not
nullable, then the field may not be None (unless None
is explicitly allowed) and the default value must match the type
spec, but callers don't have to explicitly supply the value.
"""
if __debug__:
if type_:
assert not type_if_debug
elif type_if_debug:
type_ = type_if_debug()
assert not type_ or \
isinstance(type_, type) or \
assert_seq_type(tuple, type, type_)
self.type_ = type_
assert not assert_checker or callable(assert_checker)
self._assert_checker = assert_checker
self._default = default
self._kwonly = bool(kwonly)
assert not converter or callable(converter)
self._converter = converter
assert not name or isinstance(name, str)
self.name = name
self.propname = None
self._ns = {}
assert isinstance(nullable, bool)
self.nullable = nullable
def format_value(self, value): # pylint: disable=no-self-use
"""Format a field value for __repr__"""
return repr(value)
def set_name(self, name):
"""Set the name of the property; used during metaclass setup"""
assert not self.name
assert name and isinstance(name, str)
if __debug__:
if name.startswith("__") or name.startswith("_R_"):
raise ValueError("invalid iprop name {!r}".format(name))
self.name = name
@property
def has_default(self):
"""Do we have a default value?"""
return self._default is not NO_DEFAULT
@property
def kwonly(self):
"""Is this property keyword-only?"""
return self._kwonly
def _mkref(self, attr):
ds = "_R_v_{attr}_{propname}".format(
attr=attr, propname=self.propname)
self._ns[ds] = getattr(self, attr)
return ds
@cached_property
def _default_ref(self):
return self._mkref("_default")
@cached_property
def _type_ref(self):
return self._mkref("type_")
@cached_property
def _converter_ref(self):
return self._mkref("_converter")
@cached_property
def _assert_checker_ref(self):
return self._mkref("_assert_checker")
@property
def _convert_expr(self):
if self._converter:
return "{attr._converter_ref}({attr.name})".format(attr=self)
return self.name
@property
def _checked_convert_expr(self):
if self.nullable and self._converter:
return ("None if {attr.name} is None else {attr._convert_expr}"
.format(attr=self))
return self._convert_expr
@property
def __is_specified_expr(self):
# N.B. Python compiles "if True" to a no-op, so returning True
# here doesn't impact runtime performance.
if not self.nullable:
return "True"
return "{attr.name} is not None".format(attr=self)
@staticmethod
def has_conversion_code(do_intern, emit_checks):
"""Return whether we need a separate conversion function step"""
return emit_checks or do_intern
def emit_conversion_code(self, cg, do_intern, emit_checks):
"""Emit the type conversion and checking code: used during codegen"""
# TODO(dancol): we can avoid the intermediate value assignment if
# we're smarter.
if not self.has_conversion_code(do_intern, emit_checks):
return # Will call converter in assignment
have_checks = emit_checks and (self._assert_checker or self.type_)
if have_checks or self._converter:
with cg.indent("if {}:".format(self.__is_specified_expr)):
if self._converter:
cg("{attr.name} = {attr._convert_expr}", attr=self)
if emit_checks and self._assert_checker:
assert have_checks
cg("assert {attr._assert_checker_ref}({attr.name}), "
"{msg!r}.format(name={attr.name!r}, "
"checker={attr._assert_checker_ref})",
msg=("type check {checker} "
"failed for {name!r}"),
attr=self)
if emit_checks and self.type_:
assert have_checks
cg("assert isinstance({attr.name}, {attr._type_ref}), "
"{msg!r}.format(name={attr.name!r}, "
"have={attr.name}, want={attr._type_ref}, cls=_R_the_cls)",
msg=("type mismatch in {cls} for {name!r}: "
"have {have!r} want {want!r}"),
attr=self)
def emit_assignment_code(self, cg, do_intern, emit_checks):
"""Emit the field assignment line: used during codegen"""
if not self.has_conversion_code(do_intern, emit_checks):
cg("_R_obj.{attr.propname} = {attr._checked_convert_expr}", attr=self)
else:
cg("_R_obj.{attr.propname} = {attr.name}", attr=self)
def emit_evolve_assignment_code(self, cg, emit_checks):
"""Emit code to implement a field assignment in evolve()"""
if not self.has_conversion_code(False, emit_checks):
cg("_R_obj.{attr.propname} = _R_old.{attr.propname} "
"if {attr.name} is _R_dummy else ({attr._checked_convert_expr})",
attr=self)
else:
with cg.indent("if {attr.name} is _R_dummy:", attr=self):
cg("_R_obj.{attr.propname} = _R_old.{attr.propname}",
attr=self)
with cg.indent("else:"):
self.emit_conversion_code(cg, False, emit_checks)
self.emit_assignment_code(cg, False, emit_checks)
@property
def arg_decl(self):
"""Constructor argument expression: used during codegen"""
assert self.name
if self.has_default:
return "{attr.name}={attr._default_ref}".format(attr=self)
return self.name
def augment_ns(self, ns):
"""Update a dict with this property: used during codegen"""
ns.define_globals(**self._ns)
# N.B. Type argument must be called "type_" for pylint inference to
# work correctly.
iattr = ImmutableProperty
def sattr(type_=None, converter=frozenset, nonempty=False, **kwargs):
"""Immutable attribute with set value"""
def _assert_checker(value):
if nonempty:
assert value
if type_:
assert _assert_seq_elements(type_, value)
return True
return ImmutableProperty(
frozenset,
converter=converter,
assert_checker=_assert_checker,
**kwargs)
def tattr(type_=None, converter=tuple, nonempty=False, **kwargs):
"""Immutable attribute with tuple value"""
def _assert_checker(value):
if nonempty:
assert value
if type_:
assert _assert_seq_elements(type_, value)
return True
return ImmutableProperty(
tuple,
converter=converter,
assert_checker=_assert_checker,
**kwargs)
class ImmutablePropertyEnum(ImmutableProperty):
"""Immutable property holding an enumeration"""
@override
def __init__(self, type_, **kwargs):
if isinstance(type_, tuple):
assert all(issubclass(enum, AutoNumber) for enum in type_)
def _checker(value):
return any(enum.owns(value) for enum in type_)
assert_checker = _checker
self.__enum_types = type_
else:
assert issubclass(type_, AutoNumber)
assert_checker = type_.owns
self.__enum_types = (type_,)
super().__init__(AutoNumber.value_type,
assert_checker=assert_checker,
**kwargs)
@override
def format_value(self, value):
for enum_type in self.__enum_types:
if enum_type.owns(value):
return enum_type.__qualname__ + "." + enum_type.label_of(value)
return "{!r}=???".format(value)
enumattr = ImmutablePropertyEnum
class ImmutableError(ValueError):
"""Exception for immutable constraint violations"""
def __init__(self, for_, msg):
super().__init__("error in class {!r}: {}".format(for_, msg))
class ImmutableClassInfo(object):
"""Information for each immutable class type"""
name = None
call_post_init_assert = False
call_post_init_check = False
newfn_source = None
def __init__(self):
self.iprops = OrderedDict()
self.all_iprops = OrderedDict()
@staticmethod
def scan(cls_name, bases, dict_):
"""Extract information from class under construction"""
scanned_bases = set()
nr_immutable_bases = 0
iinfo = ImmutableClassInfo()
saw_inherited_nonkw = set()
def _scan_base(base):
if base in scanned_bases:
return
scanned_bases.add(base)
base_iinfo = _IMMUTABLE_INFO.get(base)
if not base_iinfo:
return
nonlocal nr_immutable_bases
nr_immutable_bases += 1
iinfo.call_post_init_check = iinfo.call_post_init_check or \
(base is not Immutable and base_iinfo.call_post_init_check)
iinfo.call_post_init_assert = iinfo.call_post_init_assert or \
(base is not Immutable and base_iinfo.call_post_init_assert)
for name, descr in base_iinfo.iprops.items():
if name in iinfo.all_iprops:
raise ImmutableError(cls_name,
"duplicate iprop {!r}".format(name))
if not descr.kwonly:
saw_inherited_nonkw.add(name)
iinfo.all_iprops[name] = descr
for nb in base.mro():
_scan_base(nb)
for base in bases:
_scan_base(base)
for name in tuple(dict_):
if name == "__init__":
raise ImmutableError(cls_name,
"do not define {!r}".format(name))
if name == "_post_init_assert":
iinfo.call_post_init_assert = True
if name == "_post_init_check":
iinfo.call_post_init_check = True
value = dict_[name]
if isinstance(value, ImmutableProperty):
assert not name.startswith("__")
value.propname = name
if not value.name:
value.set_name(name)
del dict_[name]
if name in iinfo.all_iprops:
raise ImmutableError(
cls_name, "iprop conflict for {!r}".format(name))
iinfo.all_iprops[name] = value
iinfo.iprops[name] = value
if nr_immutable_bases > 1 and iinfo.iprops and saw_inherited_nonkw:
raise ImmutableError(
cls_name,
"cannot inherit non-kw-only iprop {!r}"
.format(saw_inherited_nonkw))
return iinfo
_IMMUTABLE_INFO = weakref.WeakKeyDictionary()
class _ImmutableAssignDummy(object):
"""Temporary class for immutable object construction"""
# To make an immutable object, we first make one of these objects,
# assign its properties normally, then transmute its __class__ to
# the desired object type. This way, we bypass the __setattr__
# mutability check, but only during construction.
class CodeGenerator(object):
"""Format Python code conveniently"""
def __init__(self):
self.__lines = []
self.__indent = 0
self.__globals = {}
def __call__(self, text, *args, **kwargs):
"""Add a formatted line of text"""
self.__lines.append(" " * self.__indent +
text.format(*args, **kwargs))
@contextmanager
def indent(self, *args, **kwargs):
"""Context manager indenting subsequent lines."""
if args:
self(*args, **kwargs)
else:
assert not kwargs
self.__indent += 2
try:
yield
finally:
self.__indent -= 2
def define_global(self, name, value):
"""Define a global for subsequent exec"""
assert name not in self.__globals
self.__globals[name] = value
def define_globals(self, **kwargs):
"""Alternate variant of define_global"""
assert not any(kwarg in self.__globals for kwarg in kwargs)
self.__globals.update(kwargs)
def get_source(self):
"""Return source generated so far"""
return "\n".join(self.__lines)
def compile(self):
"""Generate code
Return a triple (LOCALS, GLOBALS, SRC) where LOCALS is the locals
dictionary populated by exec'ing the generated code, GLOBALS is the
globals dictionary used for that exec, and SRC is the generated
source code.
"""
assert not self.__indent
src = self.get_source()
dummy_locals = {}
exec(src, self.__globals, dummy_locals)
return dummy_locals, self.__globals, src
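# Illustrative usage (hypothetical snippet): CodeGenerator accumulates
# formatted lines with automatic indentation:
#
#   cg = CodeGenerator()
#   with cg.indent("def add1(x):"):
#     cg("return x + {}", 1)
#   cg.get_source()  # => "def add1(x):\n  return x + 1"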
def generate_immutable_constructor(iinfo, do_intern, emit_checks):
"""Generate code for an immutable object constructor"""
def _iinfo_sort_key(descr):
return descr.kwonly, (descr.kwonly and descr.has_default)
all_iprops = list(iinfo.all_iprops.values())
all_iprops.sort(key=_iinfo_sort_key)
cg = CodeGenerator()
# The dummy object has a sacrificial dict that we transplant into
# the new object.
cg.define_globals(
_R_object_new=object.__new__,
_R_dummy=_ImmutableAssignDummy)
new_args = ["_R_cls"]
for descr in all_iprops:
if descr.kwonly and "*" not in new_args:
new_args.append("*")
new_args.append(descr.arg_decl)
with cg.indent("def __new__({}):", ", ".join(new_args)):
for descr in all_iprops:
descr.emit_conversion_code(cg, do_intern, emit_checks)
def _gen_construction_code():
cg("_R_obj = _R_dummy()")
for descr in all_iprops:
descr.emit_assignment_code(cg, do_intern, emit_checks)
cg("_R_obj.__class__ = _R_cls")
if iinfo.call_post_init_check:
cg("_R_obj._post_init_check()")
if emit_checks and iinfo.call_post_init_assert:
cg("_R_obj._post_init_assert()")
if do_intern:
the_cache = weakref.WeakValueDictionary()
cg.define_globals(_R_cache=the_cache)
cg.define_globals(_R_cache_setdefault=the_cache.setdefault)
cache_key_parts = []
for descr in all_iprops:
cache_key_parts.append(descr.name)
cg("_R_cache_key = ({})", ", ".join(cache_key_parts))
with cg.indent("try:"):
cg("return _R_cache[_R_cache_key]")
with cg.indent("except KeyError:"):
_gen_construction_code()
cg("return _R_cache_setdefault(_R_cache_key, _R_obj)")
else:
_gen_construction_code()
cg("return _R_obj")
with cg.indent("def __immutable_key__(self):"):
cg("return ({})", ", ".join(
"self." + attr.propname
for attr in all_iprops))
with cg.indent("def __rehydrate_from_immutable_key__(cls, key):"):
def _emit_rehydrate():
cg("_R_obj = _R_dummy()")
if len(all_iprops) == 1:
cg("_R_obj.{attr.propname} = key".format(attr=all_iprops[0]))
elif len(all_iprops) > 1:
cg("({}) = key".format(", ".join(
"_R_obj.{attr.propname}".format(attr=attr) for attr in all_iprops
)))
cg("_R_obj.__class__ = cls")
if do_intern:
with cg.indent("try:"):
cg("return _R_cache[key]")
with cg.indent("except KeyError:"):
_emit_rehydrate()
cg("return _R_cache_setdefault(key, _R_obj)")
else:
_emit_rehydrate()
cg("return _R_obj")
evolve_args = ["_R_old"]
for descr in all_iprops:
if descr.kwonly and "*" not in evolve_args:
new_args.append("*")
evolve_args.append("{}=_R_dummy".format(descr.name))
with cg.indent("def evolve({}):", ", ".join(evolve_args)):
cg("_R_obj = _R_dummy()")
for descr in all_iprops:
descr.emit_evolve_assignment_code(cg, emit_checks)
cg("_R_obj.__class__ = _R_the_cls")
def _emit_post_init_checks(cg):
if iinfo.call_post_init_check:
cg("_R_obj._post_init_check()")
if emit_checks and iinfo.call_post_init_assert:
cg("_R_obj._post_init_assert()")
if do_intern:
post_checks = (iinfo.call_post_init_check or
(emit_checks and iinfo.call_post_init_assert))
if post_checks:
cg("_R_old = _R_obj")
cg("_R_obj = "
"_R_cache_setdefault(_R_obj.__immutable_key__, _R_obj)")
if post_checks:
with cg.indent("if _R_old is _R_obj:"):
_emit_post_init_checks(cg)
else:
_emit_post_init_checks(cg)
cg("return _R_obj")
for descr in all_iprops:
descr.augment_ns(cg)
return cg.compile()
def _munge_immutable(cls_name, bases, dict_, do_intern, emit_checks):
# TODO(dancol): investigate doing the immutable initialization in
# C instead of using codegen.
iinfo = ImmutableClassInfo.scan(cls_name, bases, dict_)
dummy_locals, dummy_globals, src = generate_immutable_constructor(
iinfo, do_intern, emit_checks)
iinfo.newfn_source = src # For pylint
new_func = dummy_locals["__new__"]
if "__new__" in dict_:
dict_["_do_new"] = new_func # __new__ must invoke
else:
dict_["__new__"] = new_func
dict_["__immutable_key__"] = \
cached_property(dummy_locals["__immutable_key__"])
dict_["__rehydrate_from_immutable_key__"] = \
classmethod(dummy_locals["__rehydrate_from_immutable_key__"])
dict_["evolve"] = dummy_locals["evolve"]
# Tweak the dict so inheritance checks work correctly for
# immutable properties and so that we don't get a false positive on
# __new__.
tweaked_dict = dict_.copy()
tweaked_dict.pop("__new__", None)
tweaked_dict.pop("__immutable_key__", None)
tweaked_dict.pop("__rehydrate_from_immutable_key__", None)
tweaked_dict.pop("evolve")
for descr in iinfo.iprops.values():
tweaked_dict[descr.propname] = False # Dummy
return iinfo, tweaked_dict, dummy_globals
class ImmutableMeta(ExplicitInheritanceMeta):
"""Metaclass for immutable system"""
def __new__(mcs, cls_name, bases, dict_,
do_intern=False,
emit_checks=__debug__,
**kwargs):
iinfo, tweaked_dict, dummy_globals = \
_munge_immutable(cls_name, bases, dict_, do_intern, emit_checks)
cls = super().__new__(mcs, cls_name, bases, dict_,
tweaked_dict=tweaked_dict,
**kwargs)
dummy_globals["_R_the_cls"] = cls
_IMMUTABLE_INFO[cls] = iinfo
return cls
def __init__(cls, cls_name, bases, dict_, **kwargs):
kwargs.pop("emit_checks", None)
kwargs.pop("do_intern", None)
super().__init__(cls_name, bases, dict_, **kwargs)
def get_newfn_source(cls):
"""For pylint: retrieve source code of generated __new__ function"""
return _IMMUTABLE_INFO[cls].newfn_source
@property
def fields(cls):
"""Get immutable field information
Return a mapping. Keys are field names; values are
ImmutableClassInfo objects. The fields appear in the dict in the
order in which they were declared.
"""
return _IMMUTABLE_INFO[cls].all_iprops
class Immutable(ExplicitInheritance,
metaclass=ImmutableMeta):
"""Base class for objects immutable after initialization"""
def _post_init_assert(self):
"""Function called to verify invariants, but only when __debug__
Call through to super.
"""
def _post_init_check(self):
"""Function called to verify invariants in all builds
Call through to super.
"""
# Unless someone defined comparison operators, define them ourselves
# to make the resulting class unhashable and non-comparable. If you
# want unsafe by-identity non-interned comparison, you have to opt
# into it.
@override
def __eq__(self, other):
if type(self) is not type(other):
return NotImplemented
raise TypeError("{} is not comparable".format(type(self)))
@override
def __hash__(self):
raise TypeError("{} is not hashable".format(type(self)))
@final
@override
def __setattr__(self, _attrib, _value):
raise RuntimeError("immutable object is immutable")
@final
@override
def __delattr__(self, _attrib):
raise RuntimeError("immutable object is immutable")
@override
def __reduce__(self):
# pylint: disable=no-member
return _rebuild_immutable, (type(self), self.__immutable_key__)
def _no_print_keys(self): # pylint: disable=no-self-use
return frozenset()
@cached_property
def __cached_repr(self):
no_print_keys = self._no_print_keys()
iinfo = _IMMUTABLE_INFO[type(self)]
return "<{} {}>".format(
type(self).__name__,
", ".join("{}={}".format(k, i.format_value(getattr(self, k)))
for k, i in iinfo.all_iprops.items()
if k not in no_print_keys))
@override
def __repr__(self):
return self.__cached_repr
def _rebuild_immutable(immutable_type, key):
return immutable_type.__rehydrate_from_immutable_key__(key)
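# Illustrative sketch of the immutable system (hypothetical class):
# iattr declarations become constructor arguments, instances reject
# mutation, and evolve() builds a modified copy:
#
#   class Point(Immutable):
#     x = iattr(int)
#     y = iattr(int, default=0)
#
#   p = Point(1, y=2)
#   p.x            # => 1
#   p.evolve(y=3)  # new Point(1, 3); p itself is unchanged
#   p.x = 5        # raises RuntimeError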
class IdentityHashedImmutable(Immutable):
"""Immutable with default object hash and equality semantics"""
__inherit__ = dict(
__eq__=override_final,
__hash__=override_final,
)
__eq__ = object.__eq__
__hash__ = object.__hash__
class EqImmutable(Immutable):
"""Immutable with content comparison"""
@override
def __eq__(self, other):
# TODO(dancol): generate in metaclass?
if type(self) is not type(other):
return NotImplemented
return (self.__hash == other.__hash and # pylint: disable=protected-access
self.__immutable_key__ == other.__immutable_key__)
@cached_property
def __hash(self):
return hash(self.__immutable_key__)
@override
def __hash__(self):
return self.__hash
class InternedMeta(ImmutableMeta):
"""Metaclass for interned immutables"""
def __new__(mcs, cls_name, bases, dict_, **kwargs): # pylint: disable=signature-differs
return super().__new__(mcs, cls_name, bases, dict_,
do_intern=True, **kwargs)
class Interned(Immutable, metaclass=InternedMeta):
"""Immutable objects that we hash-cons"""
__inherit__ = dict(
__eq__=override_final,
__hash__=override_final,
)
# Because the object is interned, we can just use the fast Python
# built-in comparators.
__eq__ = object.__eq__
__hash__ = object.__hash__
class FrozenIntRangeDict(collections.abc.Mapping):
"""Mapping from number ranges to values"""
def __init__(self, contents=(), end=None):
"""Make a dictionary mapping integer ranges to values.
CONTENTS is a sequence of (KEY, VALUE) pairs sorted according to
KEY. As a special case, if the last item of CONTENTS is an
integer, use that value as END. END is the total size of the
mapped range. KEY must be an integer. VALUE can be anything. KEY
values must be unique.
"""
index = array("i")
values = []
for item in contents:
if isinstance(item, int):
end = item
break
key, value = item
if __debug__:
i = len(index)
assert isinstance(key, int)
assert not i or index[i - 1] < key, "keys must be strictly increasing"
try:
index.append(key)
except OverflowError:
index = array("l", index)
index.append(key)
values.append(value)
assert not index or end > index[-1], "end too small or not given"
self._index = index
self._values = values
self._end = end
def chunks(self):
"""Yield (CHUNK_START, CHUNK_END, VALUE) seq describing true contents"""
index = self._index
if index:
for i in range(1, len(index)):
yield index[i - 1], index[i], self._values[i - 1]
yield index[-1], self._end, self._values[-1]
def __getitem__(self, key):
if key >= self._end:
raise KeyError(key)
index = bisect_right(self._index, key)
if not index:
raise KeyError(key)
return self._values[index - 1]
def __iter__(self):
for chunk_start, chunk_end, _ in self.chunks():
yield from range(chunk_start, chunk_end)
def __len__(self):
if self._index:
return self._end - self._index[0]
return 0
__slots__ = "_index", "_values", "_end"
def future_exception(future):
"""Return exception for non-success completion of FUTURE
If the future is canceled, return a CancelledError instead of
raising it. (Future.exception() bizarrely raises CancelledError
instead of returning it.)
"""
try:
return future.exception()
except CancelledError as ex:
return ex
def future_success_p(future):
"""Return (without blocking) whether FUTURE has completed successfully."""
return future.done() and not future.cancelled() and not future.exception()
def future_result_now(future, default=None):
"""Return future result or DEFAULT if not completed successfully"""
return default if ((not future.done()) or future_exception(future)) \
else future.result()
class AllCompleteFuture(ChainableFuture, ExplicitInheritance):
"""Future that completes when all indicated futures complete"""
@override
def __init__(self, futures):
super().__init__()
self.__futures = tuple(futures)
self.__pending_count = len(self.__futures)
for future in self.__futures:
future.add_done_callback(self.__on_done)
def __cancel_dependencies(self):
for future in self.__futures:
future.cancel()
self.__futures = ()
@override
def cancel(self):
super().cancel()
self.__cancel_dependencies()
@override
def set_exception(self, exception):
super().set_exception(exception)
self.__cancel_dependencies()
@override
def set_result(self, value):
super().set_result(value)
self.__cancel_dependencies()
def __on_done(self, future):
assert self.__pending_count
self.__pending_count -= 1
if not self.done():
ex = future_exception(future)
if ex:
self.set_exception(ex)
elif not self.__pending_count:
self.__futures = ()
self.set_result(None)
class DeferredSynchronousFuture(ChainableFuture, ExplicitInheritance):
"""Future that resolves its value by calling a function.
The function is called in the caller's context on demand."""
@override
def __init__(self, fn=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
super().__init__()
if not fn:
assert not args
assert not kwargs
else:
self.__fn = (fn, args, kwargs)
def __complete(self):
# pylint: disable=protected-access
if self.__base:
self.__base.__complete()
del self.__base
return
if self.__fn:
with self._condition: # pylint: disable=protected-access
if self.done() or self.running():
return
if not self.set_running_or_notify_cancel():
return
fn, args, kwargs = self.__fn
del self.__fn
try:
self.set_result(fn(*args, **kwargs))
except BaseException as ex:
self.set_exception(ex)
@override
def result(self, timeout=None):
self.__complete()
return super().result(timeout)
@override
def exception(self, timeout=None):
self.__complete()
return super().exception(timeout)
@override
def then(self, on_fulfilled=None, on_rejected=None):
new_future = super().then(on_fulfilled, on_rejected)
assert isinstance(new_future, type(self))
new_future.__base = self # pylint: disable=protected-access
return new_future
__fn = None
__base = None
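# Illustrative usage (hypothetical function): the wrapped callable runs
# lazily, in whichever thread first demands the result:
#
#   f = DeferredSynchronousFuture(expensive_compute, arg)
#   ...         # nothing has run yet
#   f.result()  # runs expensive_compute(arg) here, then caches it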
CAPTURE_GROUP_RE = re.compile(r"(?<!\\)\(\?P<([^>]+)>")
def passivize_re(re_str, predicate=lambda _: False):
"""Conditionally make re_str named capture groups insert.
PREDICATE is a function of one argument, a regex named capture group
name. If PREDICATE returns true, keep the match in the returned
regex string; otherwise, transform the named group to a
non-capturing one.
"""
def _callback(m):
if predicate(m.group(1)):
return m.group(0)
return r"(?:" # Transform to insert group
return CAPTURE_GROUP_RE.sub(_callback, re_str)
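# Illustrative example (hypothetical pattern): groups rejected by
# PREDICATE become non-capturing:
#
#   passivize_re(r"(?P<pid>\d+):(?P<msg>.*)", {"pid"}.__contains__)
#   # => r"(?P<pid>\d+):(?:.*)"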
class TmpDir(SharedObject):
"""Temporary directory cleaned up via modernmp"""
def __init__(self, suffix="", prefix="tmp", dir_=None):
super().__init__()
from tempfile import mkdtemp
self.name = mkdtemp(suffix, prefix, dir_)
def _resman_destroy(self):
from shutil import rmtree
rmtree(self.name, ignore_errors=True)
# Misc
class UsageTrackingDictionary(collections.abc.MutableMapping):
"""Dict wrapper that tracks which keys were accessed"""
def __init__(self, data):
assert isinstance(data, collections.abc.Mapping)
self.__data = data
self.__used_keys = set()
@property
def used(self):
"""Return a set of keys which we accessed"""
return frozenset(self.__used_keys)
def __getitem__(self, key):
value = self.__data[key]
self.__used_keys.add(key)
return value
def __iter__(self):
return iter(self.__data)
def clear(self):
self.__data.clear()
def __len__(self):
return len(self.__data)
def __delitem__(self, key):
del self.__data[key]
def __setitem__(self, key, value):
raise Exception("assigning inputs dict not supported")
def timeline_to_eta(timeline):
"""Compute time-to-arrival for entries in TIMELINE.
Each value in sequence TIMELINE is a set of events; each event
"arrives" at TIMELINE[t], 0 <= t < len(TIMELINE).
Return a list of the same length as TIMELINE in which each item is
a dict mapping each pending event to the remaining "time" (as
measured by t above) until that event arrives. Having arrived, an
event disappears from subsequent dicts.
If an event arrives more than once, its distance is reset at each
point at which it arrives.
"""
arrival_timeline = []
arrivals = {}
for t in range(len(timeline) - 1, -1, -1):
for event in tuple(arrivals):
arrivals[event] += 1
for event in timeline[t]:
arrivals[event] = 0
arrival_timeline.append(arrivals.copy())
arrival_timeline.reverse()
return arrival_timeline
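# Worked example (hypothetical events):
#
#   timeline_to_eta([{"a"}, set(), {"a", "b"}])
#   # => [{"a": 0, "b": 2}, {"a": 1, "b": 1}, {"a": 0, "b": 0}]
#
# At t=0, "a" arrives now and "b" is two steps away; "a" arrives again
# at t=2, so at t=1 its distance counts down to that next arrival.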
def unlink_if_exists(file_name):
"""Like os.unlink, but succeeds if the file did not exist"""
try:
os.unlink(file_name)
except FileNotFoundError:
pass
DCTV_DIR = None
def dctv_dir():
"""Return the top-level DCTV distribution directory"""
global DCTV_DIR
the_dctv_dir = DCTV_DIR
if not the_dctv_dir:
mydir = dirname(__file__)
DCTV_DIR = the_dctv_dir = normpath(pjoin(mydir, "..", ".."))
return the_dctv_dir
NoneType = type(None)
# Numpy utilities
# Using these ends up being much faster than calling np.dtype
# (np.dtype(np.int64) costs around 158ns.)
INT8 = np.dtype(np.int8)
UINT8 = np.dtype(np.uint8)
INT16 = np.dtype(np.int16)
UINT16 = np.dtype(np.uint16)
INT32 = np.dtype(np.int32)
UINT32 = np.dtype(np.uint32)
INT64 = np.dtype(np.int64)
UINT64 = np.dtype(np.uint64)
BOOL = np.dtype(np.bool_)
FLOAT32 = np.dtype(np.float32)
FLOAT64 = np.dtype(np.float64)
def xabspath(path):
"""Like os.path.abspath, but used cached cwd"""
if not isabs(path):
path = pjoin(STARTUP_CWD, path)
return path
def ignore_numpy_errors():
"""Context manager that ignores all numpy errors"""
return np.errstate(divide="ignore",
over="ignore",
under="ignore",
invalid="ignore")
def is_nondecreasing(arr):
"""Determine whether array ARR is nondecreasing.
Returns a boolean.
"""
return (arr[:-1] <= arr[1:]).all()
# SortedDict (Banyan replacement)
class SortedDict(ExplicitInheritance,
brain_suck_abc=collections.abc.MutableMapping):
"""Dictionary that maintains its keys in sorted order"""
@override
def __init__(self):
self._data = {}
self._sorted_keys = None
@override
def __getitem__(self, key):
return self._data[key]
@override
def __setitem__(self, key, value):
self._data[key] = value
self._sorted_keys = None
@override
def __delitem__(self, key):
del self._data[key]
self._sorted_keys = None
@override
def __iter__(self):
return iter(self.keys())
@override
def __len__(self):
return len(self.keys())
@override
def keys(self):
keys = self._sorted_keys
if keys is None:
self._sorted_keys = keys = tuple(sorted(self._data))
return keys
__slots__ = ["_data", "_sorted_keys"]
# Sequence utilities
def all_unique(sequence):
"""Return whether all items in SEQUENCE are unique"""
return len(sequence) == len(set(sequence))
def all_same(sequence):
"""Return whether all items in SEQUENCE are equal"""
items = list(sequence)
if not items:
return True
first_item = items[0]
return all(item == first_item for item in items[1:])
def drain(iterable):
"""Deplete ITERABLE. Like list(map(...)), but without storing results."""
for _ in iterable:
pass
def partition(predicate, sequence):
"""Split a sequence by a predicate.
Return a tuple (YES_SEQUENCE, NO_SEQUENCE). Each sequence contains
the elements of SEQUENCE that did or did not, respectively, match
the predicate.
"""
t1, t2 = tee(sequence)
return filter(predicate, t1), filterfalse(predicate, t2)
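# Illustrative example (hypothetical predicate): both returned iterators
# are lazy views over independent tees of SEQUENCE:
#
#   evens, odds = partition(lambda x: x % 2 == 0, range(5))
#   list(evens), list(odds)  # => ([0, 2, 4], [1, 3])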
def _assert_seq_elements(elem_type, seq):
if __debug__:
if not all(isinstance(elem, elem_type) for elem in seq):
bad = [elem for elem in seq if not isinstance(elem, elem_type)]
raise AssertionError(
"elements did not meet type requirement {!r}: {!r}"
.format(elem_type, bad))
return True
def argmin(sequence):
"""Return the index of the smallest item in SEQUENCE"""
return min(enumerate(sequence), key=second)[0]
def argmax(sequence):
"""Return the index of the largest item in SEQUENCE"""
return max(enumerate(sequence), key=second)[0]
def common_prefix(compare, *seqs):
"""Lazily compute the common prefix of SEQS.
The comparison is done with the COMPARE function, which should take
as many arguments as there are sequences in SEQS and return whether
all equal.
"""
return map(first, takewhile(lambda items: compare(*items), zip(*seqs)))
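# Illustrative example: with operator.eq as the comparator,
# common_prefix yields the shared leading items of two sequences:
#
#   list(common_prefix(operator.eq, "abcd", "abef"))  # => ["a", "b"]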
def consume_list(alist):
"""Iterate over ALIST, destroying it as we go.
Assumes ownership of ALIST. The function makes no guarantees about
what ALIST looks like while we're iterating over it.
"""
assert isinstance(alist, list)
alist.reverse()
while alist:
yield alist.pop()
class CaseInsensitiveCasePreservingDict(
ExplicitInheritance,
brain_suck_abc=collections.abc.MutableMapping):
"""Case-preserving case-insensitive dict"""
__slots__ = ("_contents",)
@override
def __init__(self):
self._contents = {}
@override
def __getitem__(self, key):
try:
return second(self._contents[key.lower()])
except KeyError:
return self.__missing__(key)
@override
def __setitem__(self, key, value):
self._contents[key.lower()] = (key, value)
@override
def __delitem__(self, key):
del self._contents[key.lower()]
@override
def __iter__(self):
for original_name, _ in self._contents.values():
yield original_name
@override
def __len__(self):
return len(self._contents)
@override
def clear(self):
self._contents.clear()
def __missing__(self, key):
raise KeyError(key)
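# Illustrative usage (hypothetical values): lookups ignore case, but
# iteration yields the keys as originally spelled:
#
#   d = CaseInsensitiveCasePreservingDict()
#   d["Content-Type"] = "text/html"
#   d["content-type"]  # => "text/html"
#   list(d)            # => ["Content-Type"]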
def vard(*varnames):
"""Debug facility: dump VARNAMES from enclosing scope
Each VARNAME in VARNAMES is a string.
"""
xlog = inspect.currentframe().f_back.f_globals.get("log", log)
for varname in varnames:
value = inspect.currentframe().f_back.f_locals[varname]
# Don't hit on greps for the three-X string
xlog.debug("X""XX %s=%r", varname, value)