| # Copyright (C) 2020 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| #     http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| """Utility functions""" |
| from array import array |
| import abc |
| import sys |
| |
| from collections import namedtuple, OrderedDict |
| import collections.abc |
| |
| from bisect import bisect_right |
| from concurrent.futures import Future, ThreadPoolExecutor, CancelledError |
| from contextlib import contextmanager |
| from itertools import tee, filterfalse, takewhile, chain |
| from enum import Enum |
| |
| import warnings |
| import weakref |
| import inspect |
| import logging |
| import operator |
| import os |
| from os.path import ( |
| dirname, |
| isabs, |
| join as pjoin, |
| normpath, |
| ) |
| |
| import threading |
| import time |
| import re |
| |
| import numpy as np |
| |
| # Just import pint eagerly, since we rely on it during |
| # bootstrap anyway. |
| |
| # TODO(dancol): figure out whether we can get pint's load time down |
| import pint |
| |
| from cytoolz import first, second, merge |
| from cytoolz.dicttoolz import valmap as dtz_valmap |
| |
| from modernmp.util import ( |
| ChainableFuture, |
| STARTUP_CWD, |
| assert_seq_type, |
| cached_property, |
| once, |
| ) |
| |
| from modernmp.shm import SharedObject |
| |
# Silence pint's UnitStrippedWarning, which pint emits when a Quantity's
# units are dropped (e.g. when it is coerced to a plain magnitude).
warnings.filterwarnings("ignore", category=pint.UnitStrippedWarning)
# NOTE(review): this suppresses *all* DeprecationWarnings process-wide, not
# only those triggered by this module -- confirm that is intended.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Module logger.  Timed.__exit__ also falls back to it when the calling
# module has no "log" global of its own.
log = logging.getLogger(__name__)
| |
@once()
def _get_timing_tls():
    """Return the threading.local() store that Timed uses for nesting.

    @once() (modernmp.util) presumably memoizes the result, so all callers
    share a single threading.local object; each thread sees its own
    attributes on it.
    """
    tls = threading.local()
    return tls
| |
@once()
def get_cpu_count():
    """Return the CPU count reported by os.cpu_count().

    The value may be None when the OS cannot determine it.  @once()
    presumably caches the first answer.
    """
    cpus = os.cpu_count()
    return cpus
| |
@once()
def get_process_thread_pool():
    """Return the process-wide worker ThreadPoolExecutor.

    The pool is sized to get_cpu_count().  If that is None,
    ThreadPoolExecutor falls back to its own default worker count.
    """
    workers = get_cpu_count()
    return ThreadPoolExecutor(max_workers=workers)
| |
| |
| |
def _dummy_for_freeze():
    """Name cytoolz._signatures in an import statement for freezing tools.

    Nothing visible here calls this function.  Its body exists so that a
    freezer's static import scan sees cytoolz._signatures and bundles it.
    NOTE(review): presumably cytoolz imports that module in a way the
    scanner cannot detect -- confirm before removing.
    """
    # pylint: disable=unused-variable,unused-import
    import cytoolz._signatures
| |
def module_loader(module_name):
    """Caching and reporting for module load function

    Decorator factory.  The decorated zero-argument REAL_IMPORTER is replaced
    by a loader that runs it on the first call and afterwards returns the
    resulting module without calling REAL_IMPORTER again.  MODULE_NAME is
    currently unused (see the TODO about load-time reporting).
    """
    def _decorator(real_importer):
        # One-element cell through which _loader finds its own function
        # object (the name "_loader" itself is not visible inside it).
        loader_holder = [None]
        # pylint: disable=dangerous-default-value
        def _loader(real_importer=real_importer,
                    loader_holder=loader_holder,
                    module_name=module_name): # pylint: disable=unused-argument
            #TODO(dancol): module loading benchmarks
            #with Timed("loading {}".format(module_name)):
            #  module = real_importer()
            module = real_importer()
            def _fast_loader(module=module):
                return module
            # Rewrite this very function object in place: swap in
            # _fast_loader's code and defaults so every later call just
            # returns MODULE.  This works because both functions take all
            # their state through default arguments (neither has a closure).
            # NOTE(review): no lock -- concurrent first calls may each run
            # REAL_IMPORTER.
            loader_fn = loader_holder[0]
            loader_fn.__code__ = _fast_loader.__code__
            loader_fn.__defaults__ = _fast_loader.__defaults__
            return module
        loader_holder[0] = _loader
        return _loader
    return _decorator
| |
@module_loader("pandas")
def load_pandas():
    """Load pandas

    Imports pandas on the first call.  module_loader makes every later call
    return the cached module.  (module_loader calls this importer with no
    arguments, so the former unused `_lock`/`_state` mutable-default
    parameters were dead and have been removed.)
    """
    # Keep open-coded so freeze notices the module
    import pandas
    return pandas
| |
@module_loader("networkx")
def load_networkx():
    """Load networkx

    Imports networkx on the first call; module_loader makes later calls
    return the cached module.
    """
    # Keep open-coded so freeze notices the module
    import networkx
    return networkx
| |
@once()
def ureg():
    """Return the DCTV Pint unit registry.

    The registry is loaded from the "units.txt" definitions file that sits
    in the same directory as this module.
    """
    units_path = pjoin(dirname(__file__), "units.txt")
    return pint.UnitRegistry(units_path)
| |
| |
| |
def make_pd_dataframe(column_dict):
    """Make Pandas dataframe from column contents without copying

    COLUMN_DICT maps column name -> column data.  Each entry becomes a
    single-column DataFrame built with copy=False, and the frames are
    concatenated side by side (axis=1), also with copy=False.  An empty
    COLUMN_DICT yields an empty DataFrame.
    """
    pd = load_pandas()
    if not column_dict:
        # pd.concat raises ValueError when given no objects to concatenate.
        return pd.DataFrame()
    column_df = [pd.DataFrame(v, columns=[k], copy=False)
                 for k, v in column_dict.items()]
    return pd.concat(column_df, axis=1, copy=False)
| |
| |
| |
def do_mt(func, *args, **kwargs):
    """Schedule FUNC(*ARGS, **KWARGS) on the process thread pool.

    Returns the concurrent.futures.Future for that call.
    """
    pool = get_process_thread_pool()
    return pool.submit(func, *args, **kwargs)
| |
def map_mt(func, arg_list):
    """Apply FUNC to each item in ARG_LIST in process thread pool

    All calls are submitted before this function returns, so they run
    concurrently.  The returned iterator yields results in ARG_LIST order,
    blocking on each in turn; an exception raised by FUNC is re-raised when
    its result is reached.
    """
    pool = get_process_thread_pool()
    # Submit eagerly.  Passing a lazy generator to map() would submit each
    # task only after the previous result had been collected, which runs
    # the calls one at a time.
    futures = [pool.submit(func, arg) for arg in arg_list]
    return map(Future.result, futures)
| |
def lmap(*args, **kwargs):
    """Eager map(): evaluate map(*ARGS, **KWARGS) now and return a list."""
    return [*map(*args, **kwargs)]
| |
def get_object_cache(obj,
                     _ga=getattr,
                     _sa=object.__setattr__):
    """Return the per-object cache dict stored on OBJ, creating it on demand.

    The dict lives in OBJ's "__cache" attribute.  It is installed with
    object.__setattr__, which bypasses any __setattr__ override on OBJ's
    class (e.g. one that forbids mutation).  _GA and _SA are bound as
    defaults only for fast lookup; callers should not pass them.
    """
    existing = _ga(obj, "__cache", None)
    if existing is not None:
        return existing
    fresh = {}
    _sa(obj, "__cache", fresh)
    return fresh
| |
# (field name, zero-argument clock function) pairs sampled by TimeSnapshot.
# Wall-clock time is always recorded; per-thread CPU time is added only on
# platforms whose time module exposes CLOCK_THREAD_CPUTIME_ID.
TIMING_FIELDS = [("wall", time.monotonic)]
if hasattr(time, "CLOCK_THREAD_CPUTIME_ID"):
    TIMING_FIELDS += [
        ("threadcpu",
         lambda: time.clock_gettime(time.CLOCK_THREAD_CPUTIME_ID)),
    ]
| |
class TimeSnapshot(
        namedtuple("TimeSnapshot", [field for field, _ in TIMING_FIELDS])):
    """Reading of every clock in TIMING_FIELDS at one moment.

    The fields are named after TIMING_FIELDS.  Subtracting two snapshots
    yields a snapshot of the per-field differences.
    """
    __slots__ = []

    @staticmethod
    def current():
        """Sample performance counters"""
        readings = [clock() for _, clock in TIMING_FIELDS]
        return TimeSnapshot(*readings)

    def __sub__(self, other):
        deltas = (mine - theirs for mine, theirs in zip(self, other))
        return TimeSnapshot(*deltas)
| |
class NoStripString(str):
    """str whose leading whitespace survives tabulate's stripping.

    strip() removes only trailing characters (and keeps the subclass), and
    lstrip() is a no-op, so indentation used to show nesting is preserved.
    """
    def strip(self, chars=None):
        right_trimmed = self.rstrip(chars)
        return NoStripString(right_trimmed)

    def lstrip(self, chars=None):  # pylint: disable=unused-argument
        return self

    def __str__(self):
        return self
| |
class DummyTimed:
    """No-op stand-in for Timed: same call shape, records nothing."""

    def __init__(self, *_args, **_kwargs):
        """Accept and ignore whatever arguments Timed would take."""

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception propagate.
        return None
| |
class Timed:
    # pylint: disable=attribute-defined-outside-init
    """Context manager that records timing for code sections

    Sections nest per thread: a Timed entered while another is active on
    the same thread becomes that section's child.  On exit without an
    exception, a section with no parent -- or a nested one created with
    INDEPENDENT=True -- logs its whole subtree at DEBUG level as a table of
    TimeSnapshot field deltas in milliseconds.
    """
    # Extra columns of indentation per nesting level in dump() output.
    INDENT = 2

    def __init__(self, name, independent=False):
        self.name = name
        self.children = []
        self.independent = independent

    def __enter__(self):
        # Lazily create this thread's "current section" slot.
        if not hasattr(_get_timing_tls(), "current"):
            _get_timing_tls().current = None
        # Push self as the current section, remembering the enclosing one.
        self.parent = _get_timing_tls().current
        _get_timing_tls().current = self
        self.start = TimeSnapshot.current()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop = TimeSnapshot.current()
        # Sections must exit in LIFO order on a given thread.
        assert _get_timing_tls().current is self
        _get_timing_tls().current = self.parent
        # NOTE(review): the source's indentation was lost; the nesting of
        # the independent check under this "if" is inferred -- confirm.
        if self.parent is not None:
            self.parent.children.append(self)
            self.parent = None # Avoid cycle
            # A nested section is reported as part of its parent's table
            # unless it asked to be reported on its own.
            if not self.independent:
                return
        # Don't report sections that raised (the exception still propagates,
        # since this returns None).
        if exc_type is not None:
            return

        # Log through the calling module's "log" global when it has one,
        # otherwise through this module's logger.  Must stay directly in
        # __exit__ so f_back is the caller's frame.
        xlog = inspect.currentframe().f_back.f_globals.get("log", log)

        from io import StringIO
        buf = StringIO()
        self.dump(buf, self.INDENT)
        sys.stdout.flush()
        xlog.debug("TIMING INFORMATION %s\n%s",
                   self.name,
                   buf.getvalue())

    def __generate_rows(self, indent):
        """Yield one table row per section in this subtree, depth-first.

        Each row is the name indented by INDENT spaces (a NoStripString so
        tabulate keeps the indentation) followed by each elapsed TimeSnapshot
        field multiplied by 1000 (the clocks report seconds, so milliseconds).
        """
        elapsed = self.stop - self.start
        yield ([NoStripString("%*s%s" % (indent, "", self.name))] +
               [x * 1000.0 for x in elapsed])
        for child in self.children:
            # pylint: disable=protected-access
            yield from child.__generate_rows(indent + self.INDENT)

    def dump(self, out, indent=0):
        """Print accumulated timing information

        Writes a plain-format tabulate table of this section and its
        descendants to the file-like OUT, then a newline.  INDENT is the
        indentation of the top-level row's name.
        """
        rows = list(self.__generate_rows(indent))
        from tabulate import tabulate
        out.write(
            tabulate(
                rows,
                tablefmt="plain",
                floatfmt="4.3f",
                headers=[""] + list(map(str.upper, TimeSnapshot._fields))
            ))
        out.write("\n")
| |
def collect_profile(fn, profile_out_file_name):
    """Call a function with a profile; dump in kcachegrind format

    FN is called with no arguments under cProfile, using time.process_time
    as the timer.  If FN raises an ordinary exception, it is logged with its
    traceback and the profile collected so far is still converted and
    written to PROFILE_OUT_FILE_NAME.  KeyboardInterrupt and SystemExit are
    no longer swallowed.
    """
    import cProfile
    profiler = cProfile.Profile(time.process_time)
    def _wrapper():
        try:
            fn()
        except Exception:  # pylint: disable=broad-except
            # Best effort: keep going so the partial profile is still saved.
            log.exception("fn failed")
    profiler.runcall(_wrapper)
    from pyprof2calltree import convert
    convert(profiler.getstats(), profile_out_file_name)
    log.info("wrote profiler output to %r", profile_out_file_name)
| |
@contextmanager
def envvar_set(variable, value):
    """Sets an environment variable while code runs

    Sets VARIABLE to VALUE for the duration of the with-block.  On exit,
    VARIABLE is restored to its previous value, or removed if it was
    previously unset.  Removal tolerates the block having already unset it.
    """
    old_value = os.environ.get(variable)
    os.environ[variable] = value
    try:
        yield
    finally:
        if old_value is None:
            # pop() rather than del: the body may have removed it already,
            # and a KeyError here would mask the body's own outcome.
            os.environ.pop(variable, None)
        else:
            os.environ[variable] = old_value
| |
class FrozenDict(collections.abc.Hashable,
                 collections.abc.Mapping):
    """Immutable, hashable dictionary

    A FrozenDict compares equal only to another FrozenDict with equal
    contents.  Its hash is computed from the items on first use and cached,
    so hash() requires all values to be hashable.
    """
    __slots__ = "_data", "_hash"

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key):
        return self._data.__getitem__(key)

    def __contains__(self, key):
        # Direct dict lookup instead of Mapping's try/__getitem__ fallback.
        return key in self._data

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __hash__(self):
        my_hash = self._hash
        if my_hash is None:
            self._hash = my_hash = hash(frozenset(self._data.items()))
        return my_hash

    def __eq__(self, other):
        if type(self) is not type(other):
            return NotImplemented
        # pylint: disable=protected-access
        return self._data == other._data

    def __repr__(self):
        items = self._data.items()
        try:
            # Sorted for a deterministic repr when keys are comparable.
            items = sorted(items)
        except TypeError:
            # Keys of mixed, unorderable types: fall back to insertion order
            # rather than letting repr() itself raise.
            pass
        return "FrozenDict({})".format(
            ", ".join("{!r}: {!r}".format(k, v) for k, v in items))

    def copy(self):
        """Clone dict: in immutable case, no need to copy"""
        return self

    def __reduce__(self):
        return FrozenDict, (self._data,)
| |
| |
| # Autonumber |
| |
class AutoNumberSafeMeta(type(Enum)):
    """Enum metaclass that attaches an owns(value) predicate to each class.

    CLS.owns(value) is true exactly when isinstance(value, CLS).  This
    mirrors AutoNumberFastMeta.owns so the two implementations are
    interchangeable.
    """
    def __new__(mcs, name, bases, dict_): # pylint: disable=bad-classmethod-argument
        enum_cls = super().__new__(mcs, name, bases, dict_)
        def _owns(value, _cls=enum_cls):
            return isinstance(value, _cls)
        enum_cls.owns = _owns
        return enum_cls
| |
class AutoNumberSafe(Enum, metaclass=AutoNumberSafeMeta): # pylint: disable=invalid-metaclass
    """Enum base whose members, declared as NAME = (), are numbered 1, 2, ...

    Each new member's value is one more than the number of members that
    already exist.
    """
    def __new__(cls):
        member = object.__new__(cls)
        member._value_ = len(cls.__members__) + 1 # pylint: disable=protected-access
        return member
| |
class AutoNumberFastMeta(type):
    """Metaclass for AutoNumberFast

    Every class attribute whose value is the empty tuple () is replaced by
    a consecutive int.  Numbering starts at the optional __start__ class
    attribute (default 1; __start__ is removed from the class).  The class
    also gets __members__ (name -> int, in definition order) and __range__
    (the range of assigned ints).  Values are plain ints, so numbers from
    different classes can collide.
    """
    def __new__(mcs, cls_name, bases, dict_):
        # TODO(dancol): pseudorandomly vary start so collisions are less
        # likely even though we're using ints
        start = dict_.pop("__start__", 1)
        idx = start
        members = OrderedDict()
        for name, value in tuple(dict_.items()):
            # Check the type first: "value == ()" on arbitrary class
            # attributes (e.g. numpy arrays) need not produce a plain bool
            # and may even raise.
            if isinstance(value, tuple) and not value:
                dict_[name] = idx
                members[name] = idx
                idx += 1
        dict_["__members__"] = members
        dict_["__range__"] = range(start, idx)
        return super().__new__(mcs, cls_name, bases, dict_)

    def owns(cls, value):
        """Return whether this enum-like class owns the given value"""
        return isinstance(value, int) and value in cls.__range__

    def __iter__(cls):
        return iter(cls.__range__)

    def label_of(cls, value):
        """String name of enumeration value

        Raises KeyError if VALUE is not one of this class's members.
        """
        for xkey, xvalue in cls.__members__.items():
            if xvalue == value:
                return xkey
        raise KeyError(value)

    def __call__(cls, value):
        """Look up a member by (case-insensitive) name, or validate an int."""
        # pylint: disable=no-value-for-parameter
        if isinstance(value, str):
            return getattr(cls, value.upper())
        assert cls.owns(value)
        return value

    def __instancecheck__(cls, value):
        # Makes isinstance(x, Cls) mean "x is one of Cls's member ints".
        return cls.owns(value) # pylint: disable=no-value-for-parameter
| |
class AutoNumberFast(metaclass=AutoNumberFastMeta):
    """Lightweight integer mapping

    Subclasses declare members as NAME = (); AutoNumberFastMeta replaces
    each one with a consecutive int.
    """
    # Python type of member values; ImmutablePropertyEnum uses it as the
    # declared type of enum-valued properties.
    value_type = int
| |
# Enum-like base class used throughout this module.  Debug and optimized
# builds both use the lightweight int-based implementation (the former
# if/else on __debug__ chose the same class on both branches).  Caution:
# AutoNumberSafe is not a drop-in swap here -- it lacks the value_type
# attribute that ImmutablePropertyEnum reads from AutoNumber.
AutoNumber = AutoNumberFast
| |
| |
| # Paranoid checking for object inheritance. Objects with metaclass |
| # ExplicitInheritanceMeta fail construction when they don't satisfy |
| # Java-style constraints on methods. This little facility helps |
| # address the "fragile base class" problem. |
| |
class InheritMode(AutoNumber):
    """Ways in which a field or class can be inherited"""
    # Each () becomes a distinct int (starting at 1, so every mode is truthy)
    # via AutoNumber's metaclass.
    ABSTRACT = ()        # must be overridden unless the subclass is abstract
    FINAL = ()           # may not be overridden / derived from
    OVERRIDE = ()        # overrides a base-class member
    OVERRIDE_FINAL = ()  # overrides a base member and may not be overridden
    UNSPECIFIED = ()     # no explicit marking
| |
# Attribute that @abstract/@final/@override set on a decorated member; the
# metaclass reads and then deletes it when scanning the class body.
_INHERIT_MODE = "_inherit_mode"
# class -> (class InheritMode, {member name: InheritMode}).  Weak keys so
# this registry does not keep classes alive.
_inherit_by_type = weakref.WeakKeyDictionary()
| |
def _make_inheritance_decorator(mode):
    """Build a decorator that marks a class or member with inheritance MODE.

    Applied to a class already registered in _inherit_by_type, the
    decorator sets that class's mode.  It raises ValueError if the class
    already has a different mode other than UNSPECIFIED.

    Applied to anything else (typically a method), it records MODE in the
    object's _INHERIT_MODE attribute for ExplicitInheritanceMeta to consume.
    Stacking @final and @override (in either order) yields OVERRIDE_FINAL;
    any other double marking raises ValueError.

    The returned decorator exposes MODE as its _for_inherit_mode attribute,
    which is how __inherit__ declarations are interpreted.
    """
    def _decorator(obj, mode=mode):
        # The registry lookup is skipped for classmethod/staticmethod
        # wrappers (presumably because they may not support weak references).
        obj_inherit_info = (not isinstance(obj, (classmethod, staticmethod))
                            and _inherit_by_type.get(obj))
        if obj_inherit_info:
            obj_class_mode, obj_class_fields = obj_inherit_info
            # Modes are ints: compare by value.  "is" only worked thanks to
            # CPython's small-int cache.
            if (mode == obj_class_mode or
                    obj_class_mode == InheritMode.UNSPECIFIED):
                obj_class_mode = mode
            else:
                raise ValueError(
                    "conflicting class types for {!r}: {!r} vs {!r}"
                    .format(obj, mode, obj_class_mode))
            _inherit_by_type[obj] = (obj_class_mode, obj_class_fields)
            return obj

        preset_mode = getattr(obj, _INHERIT_MODE, None)
        if preset_mode:
            if ((mode == InheritMode.FINAL
                 and preset_mode == InheritMode.OVERRIDE)
                    or
                    (mode == InheritMode.OVERRIDE
                     and preset_mode == InheritMode.FINAL)):
                mode = InheritMode.OVERRIDE_FINAL
            else:
                raise ValueError("obj already marked")
        setattr(obj, _INHERIT_MODE, mode)
        return obj
    setattr(_decorator, "_for_inherit_mode", mode)
    return _decorator
| |
# Member/class decorators checked by ExplicitInheritanceMeta.  @final and
# @override may be stacked on one member to get OVERRIDE_FINAL.
abstract = _make_inheritance_decorator(InheritMode.ABSTRACT)
final = _make_inheritance_decorator(InheritMode.FINAL)
override = _make_inheritance_decorator(InheritMode.OVERRIDE)
override_final = _make_inheritance_decorator(InheritMode.OVERRIDE_FINAL)
| |
class InheritanceConstraintViolationError(RuntimeError):
    """Raised when a new class breaks its explicit-inheritance constraints.

    The message lists every violation found for the class.
    """
| |
def _get_cls_inherit_info(base):
    """Return (class mode, {member name: mode}) for BASE, synthesizing if absent.

    A class outside the ExplicitInheritance system gets UNSPECIFIED for
    itself and for each member in its own __dict__, except members with a
    true __isabstractmethod__ (e.g. abc.abstractmethod), which get ABSTRACT.
    The synthesized record is cached in _inherit_by_type.
    """
    known = _inherit_by_type.get(base)
    if known:
        return known
    # Outside ExplicitInheritance system, so synthesize
    member_info = {}
    for member_name, member in base.__dict__.items():
        abstract_flag = getattr(member, "__isabstractmethod__", False)
        member_info[member_name] = (
            InheritMode.ABSTRACT if abstract_flag else InheritMode.UNSPECIFIED)
    synthesized = (InheritMode.UNSPECIFIED, member_info)
    _inherit_by_type[base] = synthesized
    return synthesized
| |
def _check_inheritance(cls,
                       dict_,
                       marked_abstract,
                       marked_inherit,
                       virtual_mro):
    """Validate CLS against its bases and record its inheritance info.

    CLS is the freshly created class.  DICT_ is the namespace whose members
    are checked.  MARKED_ABSTRACT is the class's __abstract__ flag.
    MARKED_INHERIT maps member names to inheritance decorators (the class's
    __inherit__).  VIRTUAL_MRO is the MRO of a brain_suck_abc ABC, or ().

    All violations are collected and raised together as
    InheritanceConstraintViolationError.  On success, stores
    (class mode, {member name: InheritMode}) in _inherit_by_type[CLS].
    Side effect: deletes the _INHERIT_MODE marker from decorated members.
    """
    violations = []

    def _scan_base(base):
        # Return BASE's {member: mode} map; flag derivation from a final class.
        base_mode, base_fields = _get_cls_inherit_info(base)
        if base_mode == InheritMode.FINAL:
            violations.append(
                "cannot derive from final class {!r}".format(base.__name__))
        return base_fields

    def _scan_bases(cls, virtual_mro):
        # Merge member modes from all real ancestors (visited from object
        # toward CLS), then the virtual MRO likewise; later visits overwrite
        # earlier ones.  Fields from virtual bases are also kept separately.
        all_base_fields = {}
        virtual_base_fields = {}
        seen_bases = set()
        for base in chain(reversed(cls.mro()[1:]), reversed(virtual_mro)):
            if base in seen_bases:
                continue
            seen_bases.add(base)
            base_fields = _scan_base(base)
            all_base_fields.update(base_fields)
            if base in virtual_mro:
                virtual_base_fields.update(base_fields)
        return all_base_fields, virtual_base_fields

    all_base_fields, virtual_base_fields = _scan_bases(cls, virtual_mro)
    # {member name: InheritMode} from the decorators listed in __inherit__.
    explicit_inheritance = dtz_valmap(
        lambda v: getattr(v, "_for_inherit_mode"),
        marked_inherit)

    def _scan_field(name, inherit_mode, base_inherit_mode):
        # Check one member's mode against the mode of the base member it
        # shadows (BASE_INHERIT_MODE is None when it shadows nothing).
        if base_inherit_mode in (InheritMode.FINAL,
                                 InheritMode.OVERRIDE_FINAL):
            violations.append("final method {!r} overridden".format(name))
        if (base_inherit_mode and
                inherit_mode not in (InheritMode.ABSTRACT,
                                     InheritMode.OVERRIDE,
                                     InheritMode.OVERRIDE_FINAL)):
            # Class bookkeeping names that commonly appear in every class
            # body are exempt from the @override requirement.
            if name not in (
                    "__abstract__",
                    "__classcell__",
                    "__doc__",
                    "__inherit__",
                    "__module__",
                    "__qualname__",
                    "__slots__",
            ):
                violations.append(
                    "method {!r} overridden without @override".format(name))
        if (inherit_mode in (InheritMode.OVERRIDE, InheritMode.OVERRIDE_FINAL)
                and base_inherit_mode is None):
            violations.append(
                "method {!r} is marked @override but overrides nothing".format(name))
        if inherit_mode == InheritMode.ABSTRACT:
            if base_inherit_mode == InheritMode.ABSTRACT:
                violations.append(
                    "abstract method {!r} overrides abstract method".format(name))

    def _scan_fields(fields):
        # Determine, check, and return {name: mode} for the class's members.
        field_info = {}
        for name, value in fields:
            if isinstance(value, property):
                # Properties are always treated as final.
                field_inherit_mode = InheritMode.FINAL
            elif hasattr(value, _INHERIT_MODE):
                # Marked by a decorator: take the mode and consume the marker.
                field_inherit_mode = getattr(value, _INHERIT_MODE)
                delattr(value, _INHERIT_MODE)
                if name in explicit_inheritance:
                    violations.append("method {!r} already described in __inherit__"
                                      .format(name))
            else:
                field_inherit_mode = \
                    explicit_inheritance.get(name, InheritMode.UNSPECIFIED)
            field_info[name] = field_inherit_mode
            base_field_inherit_mode = all_base_fields.get(name, None)
            _scan_field(name, field_inherit_mode, base_field_inherit_mode)
        return field_info

    field_info = _scan_fields(dict_.items())

    # Every name listed in __inherit__ must actually exist in the class body.
    unused_explicit_inheritance = set(explicit_inheritance) - set(field_info)
    if unused_explicit_inheritance:
        violations.append(
            "fields {} in __inherit__ but not in class".format(
                ",".join(sorted(unused_explicit_inheritance))))

    # Need __abstract__ because @abstract class decorator won't run
    # until after we're done chewing through the class dictionary here.
    has_any_abstract = InheritMode.ABSTRACT in field_info.values()
    is_cls_abstract = marked_abstract
    # A concrete class (not marked abstract and declaring no abstract member
    # itself) must override every abstract member it inherits.
    if not (is_cls_abstract or has_any_abstract):
        for name, mode in all_base_fields.items():
            if mode == InheritMode.ABSTRACT and name not in field_info:
                violations.append(
                    "abstract method {} not overridden".format(name))

    if violations:
        raise InheritanceConstraintViolationError(
            "inheritance constraints violated for {}: {}".format(
                cls,
                "; ".join(sorted(violations))))

    cls_mode = InheritMode.UNSPECIFIED
    if is_cls_abstract:
        cls_mode = InheritMode.ABSTRACT
    # Members from a brain_suck_abc ABC count as this class's own; merge()
    # gives this class's own modes precedence.
    if virtual_base_fields:
        field_info = merge(virtual_base_fields, field_info)
    _inherit_by_type[cls] = (cls_mode, field_info)
| |
| def _cls_new_members(cls): |
| """Return members introduced in a class""" |
| members = set(cls.__dict__) |
| for base in cls.mro()[1:]: |
| members -= set(base.__dict__) |
| return members |
| |
def _do_brain_suck_abc(brain_suck_abc, dict_, bases):
    """Fake inheritance by copying members into a new class

    Considers every member introduced by BRAIN_SUCK_ABC or one of its
    ancestors that is not already an ancestor of BASES.  Each such member is
    copied into the class namespace DICT_ unless:
      - DICT_ already defines it;
      - it is ABC bookkeeping (_abc_* or one of the excluded dunders); or
      - it is abstract.
    Names are processed in sorted order.
    """
    excluded = (
        "__abstractmethods__",
        "__module__",
        "__slots__",
        "__subclasshook__",
    )
    real_bases = frozenset(
        ancestor for base in bases for ancestor in base.mro())
    candidates = set()
    for abc_ancestor in reversed(brain_suck_abc.mro()):
        if abc_ancestor in real_bases:
            continue
        candidates |= _cls_new_members(abc_ancestor)
    for name in sorted(candidates):
        if name in excluded or name.startswith("_abc_") or name in dict_:
            continue
        member = getattr(brain_suck_abc, name)
        if getattr(member, "__isabstractmethod__", False):
            continue
        dict_[name] = member
| |
class ExplicitInheritanceMeta(type):
    """Metaclass that enforces inheritance invariants

    Optional class keywords:
      brain_suck_abc: an ABC whose concrete mixin members are copied into
        the new class (see _do_brain_suck_abc).  The new class is then
        registered as a virtual subclass of it.
      tweaked_dict: a namespace to run the inheritance checks against
        instead of the class body's own namespace.

    The class body may also set __abstract__ (mark the class abstract) and
    __inherit__ ({member name: inheritance decorator}).  Both are removed
    from the namespace before the class object is created.
    """
    def __new__(mcs, cls_name, bases, dict_,
                tweaked_dict=None,
                brain_suck_abc=None):
        # Need to remove these from dict_ *before* creating the type
        marked_abstract = dict_.pop("__abstract__", False)
        marked_inherit = dict_.pop("__inherit__", {})
        virtual_mro = ()
        if brain_suck_abc:
            # Special case that allows for "inheritance" from container ABC
            # helper classes without involving ABCMeta in the picture and
            # screwing up our own metaclass hierarchy. This approach works
            # only because all of these mixins are stateless.
            assert issubclass(type(brain_suck_abc), abc.ABCMeta)
            virtual_mro = brain_suck_abc.mro()
            # Snapshot the namespace before the ABC's members are copied in,
            # so only the class's own members are checked.
            tweaked_dict = tweaked_dict or dict_.copy()
            _do_brain_suck_abc(brain_suck_abc, dict_, bases)
        new_cls = super().__new__(mcs, cls_name, bases, dict_)
        _check_inheritance(new_cls,
                           tweaked_dict or dict_,
                           marked_abstract,
                           marked_inherit,
                           virtual_mro)
        if brain_suck_abc:
            brain_suck_abc.register(new_cls)
        return new_cls

    def __init__(cls, cls_name, bases, dict_, **kwargs):
        # Strip every class keyword that __new__ consumes, so neither one
        # reaches type.__init__ or a cooperating metaclass's __init__.
        kwargs.pop("brain_suck_abc", None)
        kwargs.pop("tweaked_dict", None)
        super().__init__(cls_name, bases, dict_, **kwargs)
| |
class ExplicitInheritance(metaclass=ExplicitInheritanceMeta):
    """Base class for objects with explicit inheritance checking

    When a subclass is created, ExplicitInheritanceMeta checks its members.
    Members that shadow a base member must be marked @override (or
    @override_final).  Members may also be marked @final or @abstract.
    """
| |
| |
| # Immutable |
| |
# Sentinel meaning "no default"; distinct from None so that None can itself
# serve as a property's default.
NO_DEFAULT = object()
| |
class ImmutableProperty(ExplicitInheritance):
    """Description of an immutable property

    Classes deriving from Immutable auto-generate constructors that set
    the field described by these properties, including those found in
    subclasses.

    Besides holding a field's configuration, an ImmutableProperty emits the
    source fragments for that field used by the generated constructor and
    evolve() code (the emit_* methods).  Objects that the generated code must
    reference (default, type, converter, checker) are published to it as
    globals named _R_v_<attr>_<propname>; see _mkref and augment_ns.
    """

    @override
    def __init__(self, type_=None, *,
                 assert_checker=None,
                 type_if_debug=None,
                 default=NO_DEFAULT,
                 kwonly=False,
                 converter=None,
                 name=None,
                 nullable=False):
        """Build an immutable property.

        TYPE_ is either a type or a tuple of types (just like for
        isinstance).  It's idiomatically given to this constructor by
        position; the rest of the arguments are keyword-only.  In debug
        builds, we assert that the class property matches TYPE_ (after
        CONVERTER, if present, runs).

        ASSERT_CHECKER is a callable that runs only in debug mode and that
        checks values for validity.  It should return true if everything
        is okay.  (The generated code calls ASSERT_CHECKER inside an
        assert.)  ASSERT_CHECKER runs _after_ CONVERTER if both
        are present.

        TYPE_IF_DEBUG is a callable that yields the actual type of the
        iattr.  It is consulted only in debug builds.  (TODO(dancol): get rid
        of this.  We used it for ureg() types, but we can load pint eagerly
        now.)

        DEFAULT is a default value for this property, or NO_DEFAULT if the
        property should be required.  Default property values obey usual
        Python function argument rules: non-defaults must precede
        all defaults.

        If KWONLY is true, this property can be given to the class
        constructor only by keyword, not by position.

        If CONVERTER is given, it's a callable that takes as input the raw
        value given to the class constructor and returns the actual value
        the class field should have.  It runs before ASSERT_CHECKER.

        NAME, if given, overrides the constructor argument name, which
        normally matches the property name.

        NULLABLE is a convenience flag that effectively allows None in
        addition to TYPE_: when the value given to the field is None, the
        generated code skips CONVERTER, ASSERT_CHECKER, and the type check.
        (You can do the same thing by hand, but it's annoying.)

        N.B. NULLABLE and DEFAULT are orthogonal: if a value is nullable
        but has no default, callers must specify a value for the field,
        but that value may be None.  If a value has a default but is not
        nullable, then the value may not have the value None (unless None
        is explicitly allowed) and the default value must match the type
        spec, but callers don't have to explicitly supply the value.
        """

        if __debug__:
            if type_:
                assert not type_if_debug
            elif type_if_debug:
                type_ = type_if_debug()
            assert not type_ or \
                isinstance(type_, type) or \
                assert_seq_type(tuple, type, type_)
        self.type_ = type_
        assert not assert_checker or callable(assert_checker)
        self._assert_checker = assert_checker
        self._default = default
        self._kwonly = bool(kwonly)
        assert not converter or callable(converter)
        self._converter = converter
        assert not name or isinstance(name, str)
        # Constructor argument name (defaults to the attribute name).
        self.name = name
        # Attribute name on the owning class; assigned during class setup.
        self.propname = None
        # Globals (name -> object) that the generated code needs; see _mkref.
        self._ns = {}
        assert isinstance(nullable, bool)
        self.nullable = nullable

    def format_value(self, value): # pylint: disable=no-self-use
        """Format a field value for __repr__"""
        return repr(value)

    def set_name(self, name):
        """Set the name of the property; used during metaclass setup"""
        assert not self.name
        assert name and isinstance(name, str)
        if __debug__:
            # Names starting with "__" or "_R_" could clash with mangled
            # names and with the _R_* names used by generated code.
            if name.startswith("__") or name.startswith("_R_"):
                raise ValueError("invalid iprop name {!r}".format(name))
        self.name = name

    @property
    def has_default(self):
        """Do we have a default value?"""
        return self._default is not NO_DEFAULT

    @property
    def kwonly(self):
        """Is this property keyword-only?"""
        return self._kwonly

    def _mkref(self, attr):
        # Publish getattr(self, ATTR) to the generated code under a global
        # name unique to this attribute and property, and return that name.
        ds = "_R_v_{attr}_{propname}".format(
            attr=attr, propname=self.propname)
        self._ns[ds] = getattr(self, attr)
        return ds

    @cached_property
    def _default_ref(self):
        return self._mkref("_default")

    @cached_property
    def _type_ref(self):
        return self._mkref("type_")

    @cached_property
    def _converter_ref(self):
        return self._mkref("_converter")

    @cached_property
    def _assert_checker_ref(self):
        return self._mkref("_assert_checker")

    @property
    def _convert_expr(self):
        # Source for the converted value: "conv(name)", or just "name".
        if self._converter:
            return "{attr._converter_ref}({attr.name})".format(attr=self)
        return self.name

    @property
    def _checked_convert_expr(self):
        # Like _convert_expr, but passes None through untouched when nullable.
        if self.nullable and self._converter:
            return ("None if {attr.name} is None else {attr._convert_expr}"
                    .format(attr=self))
        return self._convert_expr

    @property
    def __is_specified_expr(self):
        # Source for the guard under which conversion and checks run.
        # N.B. Python compiles "if True" to a no-op, so returning True
        # here doesn't impact runtime performance.
        if not self.nullable:
            return "True"
        return "{attr.name} is not None".format(attr=self)

    @staticmethod
    def has_conversion_code(do_intern, emit_checks):
        """Return whether we need a separate conversion function step"""
        return emit_checks or do_intern

    def emit_conversion_code(self, cg, do_intern, emit_checks):
        """Emit the type conversion and checking code: used during codegen

        When a separate conversion step is needed, emits, under the
        "value specified" guard: the converter (rebinding the argument name),
        then (if EMIT_CHECKS) the ASSERT_CHECKER assert and the isinstance
        assert.  The isinstance message refers to _R_the_cls, which is
        presumably defined by the caller's generated code.
        """
        # TODO(dancol): we can avoid the intermediate value assignment if
        # we're smarter.
        if not self.has_conversion_code(do_intern, emit_checks):
            return # Will call converter in assignment
        have_checks = emit_checks and (self._assert_checker or self.type_)
        if have_checks or self._converter:
            with cg.indent("if {}:".format(self.__is_specified_expr)):
                if self._converter:
                    cg("{attr.name} = {attr._convert_expr}", attr=self)
                if emit_checks and self._assert_checker:
                    assert have_checks
                    cg("assert {attr._assert_checker_ref}({attr.name}), "
                       "{msg!r}.format(name={attr.name!r}, "
                       "checker={attr._assert_checker_ref})",
                       msg=("type check {checker} "
                            "failed for {name!r}"),
                       attr=self)
                if emit_checks and self.type_:
                    assert have_checks
                    cg("assert isinstance({attr.name}, {attr._type_ref}), "
                       "{msg!r}.format(name={attr.name!r}, "
                       "have={attr.name}, want={attr._type_ref}, cls=_R_the_cls)",
                       msg=("type mismatch in {cls} for {name!r}: "
                            "have {have!r} want {want!r}"),
                       attr=self)

    def emit_assignment_code(self, cg, do_intern, emit_checks):
        """Emit the field assignment line: used during codegen

        Assigns to _R_obj.<propname>.  Without a separate conversion step,
        the converter is applied inline; otherwise the already-converted
        argument is stored.
        """
        if not self.has_conversion_code(do_intern, emit_checks):
            cg("_R_obj.{attr.propname} = {attr._checked_convert_expr}", attr=self)
        else:
            cg("_R_obj.{attr.propname} = {attr.name}", attr=self)

    def emit_evolve_assignment_code(self, cg, emit_checks):
        """Emit code to implement a field assignment in evolve()

        If the argument is the _R_dummy placeholder, copies the field from
        _R_old; otherwise converts/checks it like the constructor does.
        """
        if not self.has_conversion_code(False, emit_checks):
            cg("_R_obj.{attr.propname} = _R_old.{attr.propname} "
               "if {attr.name} is _R_dummy else ({attr._checked_convert_expr})",
               attr=self)
        else:
            with cg.indent("if {attr.name} is _R_dummy:", attr=self):
                cg("_R_obj.{attr.propname} = _R_old.{attr.propname}",
                   attr=self)
            with cg.indent("else:"):
                self.emit_conversion_code(cg, False, emit_checks)
                self.emit_assignment_code(cg, False, emit_checks)

    @property
    def arg_decl(self):
        """Constructor argument expression: used during codegen"""
        assert self.name
        if self.has_default:
            return "{attr.name}={attr._default_ref}".format(attr=self)
        return self.name

    def augment_ns(self, ns):
        """Update a dict with this property: used during codegen

        Publishes every object registered by _mkref through
        NS.define_globals.
        """
        ns.define_globals(**self._ns)
| |
| # N.B. Type argument must be called "type_" for pylint inference to |
| # work correctly. |
| |
# Short alias used in class bodies: x = iattr(int) declares an immutable field.
iattr = ImmutableProperty
def _make_collection_iattr(collection_type, type_, converter, nonempty,
                           kwargs):
    """Build an ImmutableProperty whose value is a COLLECTION_TYPE.

    This is the shared implementation of sattr() and tattr().  The raw value
    passes through CONVERTER.  In debug builds, the result must be a
    COLLECTION_TYPE, must be non-empty if NONEMPTY, and, when TYPE_ is
    given, must pass _assert_seq_elements(TYPE_, value) (presumably "every
    element is a TYPE_").  KWARGS are forwarded to ImmutableProperty.
    """
    def _assert_checker(value):
        if nonempty:
            assert value
        if type_:
            assert _assert_seq_elements(type_, value)
        return True
    return ImmutableProperty(
        collection_type,
        converter=converter,
        assert_checker=_assert_checker,
        **kwargs)


def sattr(type_=None, converter=frozenset, nonempty=False, **kwargs):
    """Immutable attribute with set value

    The field holds a frozenset built by CONVERTER (frozenset by default).
    TYPE_ is the optional element type; NONEMPTY rejects empty sets.  Other
    keywords go to ImmutableProperty.
    """
    return _make_collection_iattr(frozenset, type_, converter, nonempty,
                                  kwargs)


def tattr(type_=None, converter=tuple, nonempty=False, **kwargs):
    """Immutable attribute with tuple value

    The field holds a tuple built by CONVERTER (tuple by default).  TYPE_ is
    the optional element type; NONEMPTY rejects empty tuples.  Other
    keywords go to ImmutableProperty.
    """
    return _make_collection_iattr(tuple, type_, converter, nonempty,
                                  kwargs)
| |
class ImmutablePropertyEnum(ImmutableProperty):
    """Immutable property holding an AutoNumber enumeration value

    TYPE_ is one AutoNumber class or a tuple of them.  The stored value has
    type AutoNumber.value_type, and in debug builds it must be owned by one
    of those classes.  format_value renders it as Class.LABEL.
    """

    @override
    def __init__(self, type_, **kwargs):
        if not isinstance(type_, tuple):
            assert issubclass(type_, AutoNumber)
            checker = type_.owns
            self.__enum_types = (type_,)
        else:
            assert all(issubclass(enum, AutoNumber) for enum in type_)
            def _owned_by_any(value):
                return any(enum.owns(value) for enum in type_)
            checker = _owned_by_any
            self.__enum_types = type_
        super().__init__(AutoNumber.value_type,
                         assert_checker=checker,
                         **kwargs)

    @override
    def format_value(self, value):
        owner = next((enum_type for enum_type in self.__enum_types
                      if enum_type.owns(value)), None)
        if owner is None:
            return "{!r}=???".format(value)
        return owner.__qualname__ + "." + owner.label_of(value)
| |
# Short alias: x = enumattr(SomeAutoNumber) declares an enum-valued field.
enumattr = ImmutablePropertyEnum
| |
class ImmutableError(ValueError):
    """Raised when an immutable class definition is invalid.

    FOR_ names the offending class and MSG describes the problem.
    """
    def __init__(self, for_, msg):
        text = "error in class {!r}: {}".format(for_, msg)
        super().__init__(text)
| |
class ImmutableClassInfo(object):
    """Information for each immutable class type"""
    # Class-level defaults; scan() overrides the flags per instance.
    name = None
    # True if the class, or an immutable base other than Immutable, defines
    # _post_init_assert / _post_init_check respectively.
    call_post_init_assert = False
    call_post_init_check = False
    # NOTE(review): not assigned in this part of the file; presumably the
    # generated constructor source, kept for debugging -- confirm.
    newfn_source = None
    def __init__(self):
        # iprops: properties declared directly on this class.
        # all_iprops: inherited properties followed by own ones.
        self.iprops = OrderedDict()
        self.all_iprops = OrderedDict()
    @staticmethod
    def scan(cls_name, bases, dict_):
        """Extract information from class under construction

        For each base registered in _IMMUTABLE_INFO, merges that base's own
        iprops and post-init flags, then recurses through the base's MRO.
        Then every ImmutableProperty in the class namespace DICT_ is named,
        removed from DICT_ (mutating it), and recorded.

        Raises ImmutableError if DICT_ defines __init__, if an iprop name is
        declared twice, or if more than one immutable class is found among
        the bases and their ancestors while the class both declares new
        iprops and inherits non-kw-only ones.
        """
        scanned_bases = set()
        nr_immutable_bases = 0
        iinfo = ImmutableClassInfo()
        saw_inherited_nonkw = set()
        def _scan_base(base):
            if base in scanned_bases:
                return
            scanned_bases.add(base)
            base_iinfo = _IMMUTABLE_INFO.get(base)
            if not base_iinfo:
                return
            nonlocal nr_immutable_bases
            nr_immutable_bases += 1
            # Post-init hooks are inherited, except from the Immutable root.
            iinfo.call_post_init_check = iinfo.call_post_init_check or \
                (base is not Immutable and base_iinfo.call_post_init_check)
            iinfo.call_post_init_assert = iinfo.call_post_init_assert or \
                (base is not Immutable and base_iinfo.call_post_init_assert)
            for name, descr in base_iinfo.iprops.items():
                if name in iinfo.all_iprops:
                    raise ImmutableError(cls_name,
                                         "duplicate iprop {!r}".format(name))
                if not descr.kwonly:
                    saw_inherited_nonkw.add(name)
                iinfo.all_iprops[name] = descr
            for nb in base.mro():
                _scan_base(nb)
        for base in bases:
            _scan_base(base)
        for name in tuple(dict_):
            # Constructors are generated; a hand-written one is an error.
            if name == "__init__":
                raise ImmutableError(cls_name,
                                     "do not define {!r}".format(name))
            if name == "_post_init_assert":
                iinfo.call_post_init_assert = True
            if name == "_post_init_check":
                iinfo.call_post_init_check = True
            value = dict_[name]
            if isinstance(value, ImmutableProperty):
                assert not name.startswith("__")
                value.propname = name
                if not value.name:
                    value.set_name(name)
                # Remove the descriptor object from the class namespace.
                del dict_[name]
                if name in iinfo.all_iprops:
                    raise ImmutableError(
                        cls_name, "iprop conflict for {!r}".format(name))
                iinfo.all_iprops[name] = value
                iinfo.iprops[name] = value
        if nr_immutable_bases > 1 and iinfo.iprops and saw_inherited_nonkw:
            raise ImmutableError(
                cls_name,
                "cannot inherit non-kw-only iprop {!r}"
                .format(saw_inherited_nonkw))
        return iinfo
| |
# Registry mapping each immutable class to its ImmutableClassInfo.  Keys
# are held weakly, so the registry does not keep classes alive.
_IMMUTABLE_INFO = weakref.WeakKeyDictionary()
| |
class _ImmutableAssignDummy(object):
    """Temporary class for immutable object construction"""
    # To make an immutable object, we first make one of these objects,
    # assign its properties normally, then transmute its __class__ to
    # the desired object type. This way, we bypass the __setattr__
    # mutability check, but only during construction.
    #
    # The generated evolve() also uses this class itself as the default
    # value of its parameters, presumably as a "not supplied" sentinel.
| |
class CodeGenerator(object):
    """Format Python code conveniently

    Accumulates lines of source text (indented two spaces per level) and
    a dictionary of globals, then executes the result via compile().
    """
    def __init__(self):
        self.__lines = []
        self.__indent = 0
        self.__globals = {}

    def __call__(self, text, *args, **kwargs):
        """Add a formatted line of text

        TEXT is a str.format template filled from ARGS and KWARGS; the
        result is prefixed with the current indentation.
        """
        self.__lines.append(" " * self.__indent +
                            text.format(*args, **kwargs))

    @contextmanager
    def indent(self, *args, **kwargs):
        """Context manager indenting subsequent lines.

        If ARGS is given, first emit the line it describes (exactly as
        __call__ would) at the current level -- typically a block header
        such as "def f():" or "try:".
        """
        if args:
            self(*args, **kwargs)
        else:
            assert not kwargs
        self.__indent += 2
        try:
            yield
        finally:
            self.__indent -= 2

    def define_global(self, name, value):
        """Define a global for subsequent exec

        NAME must not already be defined.
        """
        assert name not in self.__globals
        self.__globals[name] = value

    def define_globals(self, **kwargs):
        """Alternate variant of define_global taking NAME=VALUE pairs"""
        assert not any(kwarg in self.__globals for kwarg in kwargs)
        self.__globals.update(kwargs)

    def get_source(self):
        """Return source generated so far"""
        return "\n".join(self.__lines)

    def compile(self):
        """Generate code

        Execute the source generated so far and return a triple
        (LOCALS, GLOBALS, SRC):

          LOCALS  -- the dictionary into which the generated code's
                     top-level statements bound their names
          GLOBALS -- the globals dictionary the code executed with; it is
                     the live dictionary, so names added to it later are
                     visible to the generated functions when they run
          SRC     -- the generated source code

        All indent() blocks must have been closed.
        """
        assert not self.__indent
        src = self.get_source()
        dummy_locals = {}
        exec(src, self.__globals, dummy_locals)
        return dummy_locals, self.__globals, src
| |
def generate_immutable_constructor(iinfo, do_intern, emit_checks):
    """Generate code for an immutable object constructor

    IINFO is the ImmutableClassInfo of the class being built.  If
    DO_INTERN is true, instances are hash-consed through a weak-valued
    cache keyed by their field values.  EMIT_CHECKS is forwarded to the
    property descriptors' code emitters and also controls whether
    _post_init_assert is called.

    The generated source defines:

      __new__(_R_cls, ...)
          build (or, when interning, look up) an instance
      __immutable_key__(self)
          tuple of the iprop values -- a bare value when there is exactly
          one iprop, and () when there are none
      __rehydrate_from_immutable_key__(cls, key)
          rebuild an instance from such a key (used when unpickling)
      evolve(_R_old, ...)
          copy of _R_old with the given iprops replaced; keyword-only
          iprops are keyword-only here too, as in __new__

    Return the (LOCALS, GLOBALS, SRC) triple from CodeGenerator.compile().
    The generated evolve() refers to the global ``_R_the_cls``, which the
    caller must add to GLOBALS once the class exists.
    """
    def _iinfo_sort_key(descr):
        # Positional iprops first, then keyword-only ones without a
        # default, then keyword-only ones with a default.  sort() is
        # stable, so declaration order is kept within each group.
        return descr.kwonly, (descr.kwonly and descr.has_default)
    all_iprops = list(iinfo.all_iprops.values())
    all_iprops.sort(key=_iinfo_sort_key)
    cg = CodeGenerator()

    # The dummy object has a sacrificial dict that we transplant into
    # the new object.

    cg.define_globals(
        _R_object_new=object.__new__,
        _R_dummy=_ImmutableAssignDummy)

    # __new__(_R_cls, <positional iprops>, *, <keyword-only iprops>)
    new_args = ["_R_cls"]
    for descr in all_iprops:
        if descr.kwonly and "*" not in new_args:
            new_args.append("*")
        new_args.append(descr.arg_decl)
    with cg.indent("def __new__({}):", ", ".join(new_args)):
        for descr in all_iprops:
            descr.emit_conversion_code(cg, do_intern, emit_checks)

        def _gen_construction_code():
            # Assign fields on a dummy (which permits assignment), then
            # switch its class and run the post-init hooks.
            cg("_R_obj = _R_dummy()")
            for descr in all_iprops:
                descr.emit_assignment_code(cg, do_intern, emit_checks)
            cg("_R_obj.__class__ = _R_cls")
            if iinfo.call_post_init_check:
                cg("_R_obj._post_init_check()")
            if emit_checks and iinfo.call_post_init_assert:
                cg("_R_obj._post_init_assert()")
        if do_intern:
            # Weak values: an interned object stays cached only while
            # something else references it.
            the_cache = weakref.WeakValueDictionary()
            cg.define_globals(_R_cache=the_cache)
            cg.define_globals(_R_cache_setdefault=the_cache.setdefault)
            cache_key_parts = [descr.name for descr in all_iprops]
            cg("_R_cache_key = ({})", ", ".join(cache_key_parts))
            with cg.indent("try:"):
                cg("return _R_cache[_R_cache_key]")
            with cg.indent("except KeyError:"):
                _gen_construction_code()
                # setdefault keeps whichever object was cached first
                # should one have appeared under this key meanwhile.
                cg("return _R_cache_setdefault(_R_cache_key, _R_obj)")
        else:
            _gen_construction_code()
            cg("return _R_obj")
    with cg.indent("def __immutable_key__(self):"):
        # With exactly one iprop, "(x)" is just x rather than a 1-tuple;
        # the intern cache key and rehydration use the same convention.
        cg("return ({})", ", ".join(
            "self." + attr.propname
            for attr in all_iprops))
    with cg.indent("def __rehydrate_from_immutable_key__(cls, key):"):
        def _emit_rehydrate():
            cg("_R_obj = _R_dummy()")
            if len(all_iprops) == 1:
                cg("_R_obj.{attr.propname} = key".format(attr=all_iprops[0]))
            elif len(all_iprops) > 1:
                cg("({}) = key".format(", ".join(
                    "_R_obj.{attr.propname}".format(attr=attr)
                    for attr in all_iprops
                )))
            cg("_R_obj.__class__ = cls")
        if do_intern:
            with cg.indent("try:"):
                cg("return _R_cache[key]")
            with cg.indent("except KeyError:"):
                _emit_rehydrate()
                cg("return _R_cache_setdefault(key, _R_obj)")
        else:
            _emit_rehydrate()
            cg("return _R_obj")
    # evolve(_R_old, <iprop>=_R_dummy, ...): _R_dummy marks "not given".
    evolve_args = ["_R_old"]
    for descr in all_iprops:
        if descr.kwonly and "*" not in evolve_args:
            # Keyword-only iprops are keyword-only in evolve() as well,
            # matching __new__.
            evolve_args.append("*")
        evolve_args.append("{}=_R_dummy".format(descr.name))
    with cg.indent("def evolve({}):", ", ".join(evolve_args)):
        cg("_R_obj = _R_dummy()")
        for descr in all_iprops:
            descr.emit_evolve_assignment_code(cg, emit_checks)
        # _R_the_cls is supplied by the caller once the class exists.
        cg("_R_obj.__class__ = _R_the_cls")
        def _emit_post_init_checks(cg):
            if iinfo.call_post_init_check:
                cg("_R_obj._post_init_check()")
            if emit_checks and iinfo.call_post_init_assert:
                cg("_R_obj._post_init_assert()")
        if do_intern:
            post_checks = (iinfo.call_post_init_check or
                           (emit_checks and iinfo.call_post_init_assert))
            if post_checks:
                # Reuse _R_old to remember the freshly built object.
                cg("_R_old = _R_obj")
            cg("_R_obj = "
               "_R_cache_setdefault(_R_obj.__immutable_key__, _R_obj)")
            if post_checks:
                # Run the hooks only if our new object is the one that got
                # cached, not a pre-existing equal instance.
                with cg.indent("if _R_old is _R_obj:"):
                    _emit_post_init_checks(cg)
        else:
            _emit_post_init_checks(cg)
        cg("return _R_obj")
    for descr in all_iprops:
        descr.augment_ns(cg)
    return cg.compile()
| |
def _munge_immutable(cls_name, bases, dict_, do_intern, emit_checks):
    """Prepare the namespace DICT_ of immutable class CLS_NAME.

    Scan DICT_ (which removes its ImmutableProperty entries), generate the
    constructor and helper functions, and install them into DICT_ in
    place.  If the class defines its own __new__, the generated one is
    installed as _do_new instead, for the class's __new__ to invoke.

    Return (IINFO, TWEAKED_DICT, GLOBALS): the class info; a copy of DICT_
    for the inheritance checks, with the generated names removed and the
    class's own iprops re-added as placeholders; and the generated code's
    globals dictionary.
    """
    # TODO(dancol): investigate doing the immutable initialization in
    # C instead of using codegen.
    iinfo = ImmutableClassInfo.scan(cls_name, bases, dict_)
    generated, generated_globals, source = generate_immutable_constructor(
        iinfo, do_intern, emit_checks)
    iinfo.newfn_source = source  # For pylint

    # __new__ must invoke _do_new when the class supplies its own __new__.
    new_slot = "_do_new" if "__new__" in dict_ else "__new__"
    dict_[new_slot] = generated["__new__"]
    dict_["__immutable_key__"] = cached_property(
        generated["__immutable_key__"])
    dict_["__rehydrate_from_immutable_key__"] = classmethod(
        generated["__rehydrate_from_immutable_key__"])
    dict_["evolve"] = generated["evolve"]

    # Tweak the dict so inheritance checks work correctly for
    # immutableproperties and so that we don't get a false positive on
    # __new__.
    tweaked_dict = dict_.copy()
    for optional_name in ("__new__",
                          "__immutable_key__",
                          "__rehydrate_from_immutable_key__"):
        tweaked_dict.pop(optional_name, None)
    tweaked_dict.pop("evolve")
    for descr in iinfo.iprops.values():
        tweaked_dict[descr.propname] = False  # Dummy
    return iinfo, tweaked_dict, generated_globals
| |
class ImmutableMeta(ExplicitInheritanceMeta):
    """Metaclass for immutable system

    Accepted class keyword arguments:

      do_intern   -- hash-cons instances through a weak-valued cache
      emit_checks -- emit the descriptors' checks and call
                     _post_init_assert; defaults to __debug__
    """
    def __new__(mcs, cls_name, bases, dict_,
                do_intern=False,
                emit_checks=__debug__,
                **kwargs):
        iinfo, tweaked_dict, dummy_globals = \
            _munge_immutable(cls_name, bases, dict_, do_intern, emit_checks)
        cls = super().__new__(mcs, cls_name, bases, dict_,
                              tweaked_dict=tweaked_dict,
                              **kwargs)
        # The generated evolve() finds the finished class through this
        # global, which can only be filled in now.
        dummy_globals["_R_the_cls"] = cls
        _IMMUTABLE_INFO[cls] = iinfo
        return cls

    def __init__(cls, cls_name, bases, dict_, **kwargs):
        # Drop the keywords __new__ consumed so they do not reach the base
        # metaclass's __init__.
        kwargs.pop("emit_checks", None)
        kwargs.pop("do_intern", None)
        super().__init__(cls_name, bases, dict_, **kwargs)

    def get_newfn_source(cls):
        """For pylint: retrieve source code of generated __new__ function

        The returned text is the whole generated source, which also
        defines __immutable_key__, __rehydrate_from_immutable_key__ and
        evolve.
        """
        return _IMMUTABLE_INFO[cls].newfn_source

    @property
    def fields(cls):
        """Get immutable field information

        Return a mapping.  Keys are field (iprop) names; values are the
        ImmutableProperty descriptors that declare them.  Fields inherited
        from immutable base classes come first, followed by the class's
        own fields in the order in which they were declared.
        """
        return _IMMUTABLE_INFO[cls].all_iprops
| |
class Immutable(ExplicitInheritance,
                metaclass=ImmutableMeta):
    """Base class for objects immutable after initialization

    Subclasses declare their fields as ImmutableProperty class
    attributes; ImmutableMeta generates __new__, evolve() and pickling
    support from them.  Subclasses must not define __init__.  Instances
    reject all attribute assignment and deletion.
    """

    def _post_init_assert(self):
        """Function called to verify invariants, but only when __debug__

        (More precisely: when the class was built with emit_checks, which
        defaults to __debug__.)  Called after construction only when a
        class below Immutable in the hierarchy defines it.

        Call through to super.
        """

    def _post_init_check(self):
        """Function called to verify invariants in all builds

        Called after construction only when a class below Immutable in
        the hierarchy defines it.

        Call through to super.
        """

    # Unless someone defined comparison operators, define them ourselves
    # to make the resulting class unhashable and non-comparable. If you
    # want unsafe by-identity non-interned comparison, you have to opt
    # into it.

    @override
    def __eq__(self, other):
        # Different types: let Python fall back to its default handling.
        if type(self) is not type(other):
            return NotImplemented
        raise TypeError("{} is not comparable".format(type(self)))

    @override
    def __hash__(self):
        raise TypeError("{} is not hashable".format(type(self)))

    # Construction bypasses these by assigning on a dummy object and
    # switching its __class__ afterwards.
    @final
    @override
    def __setattr__(self, _attrib, _value):
        raise RuntimeError("immutable object is immutable")

    @final
    @override
    def __delattr__(self, _attrib):
        raise RuntimeError("immutable object is immutable")

    @override
    def __reduce__(self):
        # Pickle as (type, immutable key); _rebuild_immutable passes the
        # key to the class's generated __rehydrate_from_immutable_key__.
        # pylint: disable=no-member
        return _rebuild_immutable, (type(self), self.__immutable_key__)

    def _no_print_keys(self): # pylint: disable=no-self-use
        """Return the names of fields to omit from repr()"""
        return frozenset()

    @cached_property
    def __cached_repr(self):
        # "<TypeName field=value, ...>", formatting each value with its
        # field descriptor's format_value().
        no_print_keys = self._no_print_keys()
        iinfo = _IMMUTABLE_INFO[type(self)]
        return "<{} {}>".format(
            type(self).__name__,
            ", ".join("{}={}".format(k, i.format_value(getattr(self, k)))
                      for k, i in iinfo.all_iprops.items()
                      if k not in no_print_keys))
    @override
    def __repr__(self):
        return self.__cached_repr
| |
def _rebuild_immutable(immutable_type, key):
    """Recreate an instance of IMMUTABLE_TYPE from its immutable KEY.

    Pickle reconstructor for Immutable.__reduce__; defined at module level
    so pickle can locate it by name.
    """
    rehydrate = immutable_type.__rehydrate_from_immutable_key__
    return rehydrate(key)
| |
class IdentityHashedImmutable(Immutable):
    """Immutable with default object hash and equality semantics

    Instances are equal only to themselves and hash by identity, whereas
    plain Immutable instances refuse both comparison and hashing.
    """
    # Declaration consumed by ExplicitInheritance; presumably it marks the
    # __eq__/__hash__ below as (final) overrides of Immutable's versions.
    __inherit__ = dict(
        __eq__=override_final,
        __hash__=override_final,
    )
    __eq__ = object.__eq__
    __hash__ = object.__hash__
| |
class EqImmutable(Immutable):
    """Immutable with content comparison

    Two instances of the same type are equal when their immutable keys
    (their field values) are equal; the hash is derived from the key.
    """

    @override
    def __eq__(self, other):
        # TODO(dancol): generate in metaclass?
        if type(self) is not type(other):
            return NotImplemented
        # Comparing the (cached) hashes first cheaply rejects most
        # unequal pairs before comparing whole keys.
        return (self.__hash == other.__hash and # pylint: disable=protected-access
                self.__immutable_key__ == other.__immutable_key__)

    @cached_property
    def __hash(self):
        # Hash of the immutable key, computed via cached_property.
        return hash(self.__immutable_key__)

    @override
    def __hash__(self):
        return self.__hash
| |
class InternedMeta(ImmutableMeta):
    """Metaclass for interned immutables

    Identical to ImmutableMeta except that do_intern is always true.
    Passing do_intern as a class keyword as well is a TypeError (multiple
    values for the same keyword argument).
    """
    def __new__(mcs, cls_name, bases, dict_, **kwargs): # pylint: disable=signature-differs
        return super().__new__(mcs, cls_name, bases, dict_,
                               do_intern=True, **kwargs)
| |
class Interned(Immutable, metaclass=InternedMeta):
    """Immutable objects that we hash-cons

    Constructing an object whose immutable key matches that of a live
    instance returns that existing instance; the intern cache holds
    instances weakly.
    """

    # Declaration consumed by ExplicitInheritance; presumably it permits
    # the __eq__/__hash__ overrides below.
    __inherit__ = dict(
        __eq__=override_final,
        __hash__=override_final,
    )

    # Because the object is interned, we can just use the fast Python
    # built-in comparators.
    __eq__ = object.__eq__
    __hash__ = object.__hash__
| |
| |
| |
class FrozenIntRangeDict(collections.abc.Mapping):
    """Mapping from number ranges to values

    Each (KEY, VALUE) pair maps every integer from KEY up to (but not
    including) the next pair's KEY -- or END, for the last pair -- to
    VALUE.  Integers outside [first KEY, END) are not in the mapping.
    """
    def __init__(self, contents=(), end=None):
        """Make a dictionary mapping integer ranges to values.

        CONTENTS is a sequence of (KEY, VALUE) pairs sorted according to
        KEY. As a special case, if the last item of CONTENTS is an
        integer, use that value as END. END is the exclusive upper bound
        of the mapped range and is required whenever CONTENTS holds at
        least one pair. KEY must be an integer. VALUE can be anything.
        KEY values must be unique.
        """
        index = array("i")
        values = []
        for item in contents:
            if isinstance(item, int):
                end = item
                break
            key, value = item
            if __debug__:
                assert isinstance(key, int)
                assert not index or index[-1] < key, \
                    "keys not sorted or not unique"
            try:
                index.append(key)
            except OverflowError:
                # Key does not fit in a C int: widen to a guaranteed
                # 64-bit type code ("l" is only 32 bits on some platforms).
                index = array("q", index)
                index.append(key)
            values.append(value)

        assert not index or (end is not None and end > index[-1]), \
            "end too small or not given"
        self._index = index
        self._values = values
        self._end = end

    def chunks(self):
        """Yield (CHUNK_START, CHUNK_END, VALUE) seq describing true contents"""
        index = self._index
        if index:
            for i in range(1, len(index)):
                yield index[i - 1], index[i], self._values[i - 1]
            yield index[-1], self._end, self._values[-1]

    def __getitem__(self, key):
        end = self._end
        # END may be None for an empty mapping; such lookups must raise
        # KeyError (so `in` and get() work), not TypeError from `>= None`.
        if end is None or key >= end:
            raise KeyError(key)
        index = bisect_right(self._index, key)
        if not index:
            raise KeyError(key)  # Below the first KEY
        return self._values[index - 1]

    def __iter__(self):
        for chunk_start, chunk_end, _ in self.chunks():
            yield from range(chunk_start, chunk_end)

    def __len__(self):
        if self._index:
            return self._end - self._index[0]
        return 0

    __slots__ = "_index", "_values", "_end"
| |
def future_exception(future):
    """Return the exception, if any, that FUTURE completed with.

    Unlike Future.exception(), a cancelled FUTURE yields a CancelledError
    instance as the return value instead of raising it.  Return None for
    a future that succeeded.  (Like Future.exception(), this waits for a
    pending FUTURE.)
    """
    try:
        failure = future.exception()
    except CancelledError as cancellation:
        failure = cancellation
    return failure
| |
def future_success_p(future):
    """Return (without blocking) whether FUTURE has completed successfully.

    Pending, cancelled, and failed futures all yield False.
    """
    if not future.done() or future.cancelled():
        return False
    # Done and not cancelled, so exception() returns immediately.
    return not future.exception()
| |
def future_result_now(future, default=None):
    """Return future result or DEFAULT if not completed successfully

    Never blocks: a pending, failed, or cancelled FUTURE yields DEFAULT.
    """
    if future.done() and not future_exception(future):
        return future.result()
    return default
| |
class AllCompleteFuture(ChainableFuture, ExplicitInheritance):
    """Future that completes when all indicated futures complete

    Succeeds with None once every future in FUTURES has succeeded
    (immediately when FUTURES is empty).  If any of them fails or is
    cancelled first, this future fails with that exception (a
    CancelledError for a cancellation).  Whenever this future finishes --
    including by cancel() -- the futures it still tracks are cancelled.
    """

    @override
    def __init__(self, futures):
        super().__init__()
        self.__futures = tuple(futures)
        self.__pending_count = len(self.__futures)
        if not self.__futures:
            # No done-callback will ever fire; complete now instead of
            # staying pending forever.
            self.set_result(None)
            return
        for future in self.__futures:
            future.add_done_callback(self.__on_done)

    def __cancel_dependencies(self):
        for future in self.__futures:
            future.cancel()
        self.__futures = ()

    @override
    def cancel(self):
        """Cancel this future and every future it tracks.

        Return the base class's result (for Future, whether this future
        is now cancelled).
        """
        cancelled = super().cancel()
        self.__cancel_dependencies()
        return cancelled

    @override
    def set_exception(self, exception):
        super().set_exception(exception)
        self.__cancel_dependencies()

    @override
    def set_result(self, value):
        super().set_result(value)
        self.__cancel_dependencies()

    def __on_done(self, future):
        # Done-callback for each tracked future: the first failure or
        # cancellation completes us with that exception; otherwise the
        # last completion completes us with None.
        assert self.__pending_count
        self.__pending_count -= 1
        if not self.done():
            ex = future_exception(future)
            if ex:
                self.set_exception(ex)
            elif not self.__pending_count:
                self.__futures = ()
                self.set_result(None)
| |
class DeferredSynchronousFuture(ChainableFuture, ExplicitInheritance):
    """Future that resolves its value by calling a function.

    Function is called in caller's context on demand.

    The first result() or exception() call runs FN(*ARGS, **KWARGS) in the
    calling thread and completes the future with its return value, or
    with whatever it raised (any BaseException).  FN runs at most once.
    Futures produced by then() remember their parent and run the parent's
    deferred function on demand in the same way."""

    @override
    def __init__(self, fn=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
        super().__init__()
        if not fn:
            # No deferred work, so ARGS and KWARGS would be meaningless.
            assert not args
            assert not kwargs
        else:
            self.__fn = (fn, args, kwargs)

    def __complete(self):
        # Run whatever deferred work can complete this future.
        # pylint: disable=protected-access
        if self.__base:
            # Derived via then(): presumably completes through
            # ChainableFuture's callbacks once the parent completes.
            self.__base.__complete()
            del self.__base  # Reverts to the class-level None
            return
        if self.__fn:
            with self._condition: # pylint: disable=protected-access
                # Skip if already finished or already being run.
                if self.done() or self.running():
                    return
                # False means the future was cancelled.
                if not self.set_running_or_notify_cancel():
                    return
                fn, args, kwargs = self.__fn
                del self.__fn  # Reverts to the class-level None
            try:
                self.set_result(fn(*args, **kwargs))
            except BaseException as ex:
                self.set_exception(ex)

    @override
    def result(self, timeout=None):
        self.__complete()
        return super().result(timeout)

    @override
    def exception(self, timeout=None):
        self.__complete()
        return super().exception(timeout)

    @override
    def then(self, on_fulfilled=None, on_rejected=None):
        new_future = super().then(on_fulfilled, on_rejected)
        # ChainableFuture.then() is expected to build a future of our type.
        assert isinstance(new_future, type(self))
        # Let the derived future trigger our deferred work on demand.
        new_future.__base = self # pylint: disable=protected-access
        return new_future

    # Class-level defaults, shadowed by instance attributes while the
    # corresponding work is pending.
    __fn = None
    __base = None
| |
CAPTURE_GROUP_RE = re.compile(r"(?<!\\)\(\?P<([^>]+)>")
# Like CAPTURE_GROUP_RE, but also matches a named group preceded by escaped
# backslashes (r"\\(?P<x>..."), which CAPTURE_GROUP_RE skips.  An odd number
# of backslashes still means an escaped, literal parenthesis.  Group 1 is
# the run of backslash pairs; group 2 is the group name.
_CAPTURE_GROUP_AFTER_ESCAPES_RE = re.compile(
    r"(?<!\\)((?:\\\\)*)\(\?P<([^>]+)>")
def passivize_re(re_str, predicate=lambda _: False):
    """Conditionally make re_str named capture groups inert.

    PREDICATE is a function of one argument, a regex named capture group
    name. If PREDICATE returns true, keep the group in the returned
    regex string; otherwise, transform the named group to a
    non-capturing one. By default, every named group becomes
    non-capturing.

    Groups are found textually: "(?P<" counts unless preceded by an odd
    number of backslashes. Character classes are not parsed, so text
    such as "[(?P<x>]" is rewritten too.
    """
    def _callback(m):
        if predicate(m.group(2)):
            return m.group(0)
        # Keep any escaped backslashes; make the group non-capturing.
        return m.group(1) + "(?:"
    return _CAPTURE_GROUP_AFTER_ESCAPES_RE.sub(_callback, re_str)
| |
class TmpDir(SharedObject):
    """Temporary directory cleaned up via modernmp

    The directory is created on construction (its path is stored in
    NAME) and removed, together with its contents, by _resman_destroy.
    SUFFIX, PREFIX and DIR_ are passed to tempfile.mkdtemp.
    """

    def __init__(self, suffix="", prefix="tmp", dir_=None):
        super().__init__()
        from tempfile import mkdtemp
        self.name = mkdtemp(suffix=suffix, prefix=prefix, dir=dir_)

    def _resman_destroy(self):
        from shutil import rmtree
        # Errors (for example, a directory that is already gone) are ignored.
        rmtree(self.name, ignore_errors=True)
| |
| |
| # Misc |
| |
class UsageTrackingDictionary(collections.abc.MutableMapping):
    """Dict wrapper that tracks which keys were accessed

    Each successful lookup through __getitem__ records its key in USED,
    including lookups made by the Mapping mixin methods (get,
    __contains__, items, values, ...).  Plain iteration does not.  Keys
    can be deleted, but not assigned.
    """
    def __init__(self, data):
        assert isinstance(data, collections.abc.Mapping)
        self.__data = data
        self.__used_keys = set()

    @property
    def used(self):
        """Return a set of keys which we accessed"""
        return frozenset(self.__used_keys)

    def __getitem__(self, key):
        value = self.__data[key]
        # Only record keys that were actually found.
        self.__used_keys.add(key)
        return value

    def __iter__(self):
        return iter(self.__data)

    def clear(self):
        self.__data.clear()

    def __len__(self):
        return len(self.__data)

    def __delitem__(self, key):
        del self.__data[key]

    def __setitem__(self, key, value):
        # TypeError (rather than a bare Exception) is the conventional
        # error for an unsupported operation; handlers catching Exception
        # still match.
        raise TypeError("assigning inputs dict not supported")
| |
def timeline_to_eta(timeline):
    """Compute time-to-arrival for entries in TIMELINE.

    TIMELINE is a sequence; TIMELINE[t], for 0 <= t < len(TIMELINE), is
    the collection of events that "arrive" at step t.

    Return a list as long as TIMELINE whose item t is a dict mapping
    each event that arrives at some step t2 >= t to the distance t2 - t
    to its nearest such arrival.  After its last arrival, an event no
    longer appears.  An event that arrives more than once has its
    distance reset at each arrival.
    """
    result = []
    eta = {}
    # Walk backwards so each step only needs the following step's dict.
    for events in reversed(timeline):
        eta = {event: distance + 1 for event, distance in eta.items()}
        eta.update(dict.fromkeys(events, 0))
        result.append(eta)
    result.reverse()
    return result
| |
def unlink_if_exists(file_name):
    """Like os.unlink, but succeeds if the file did not exist

    Any other OSError (e.g. permission denied) still propagates.
    """
    try:
        os.remove(file_name)
    except FileNotFoundError:
        return
| |
DCTV_DIR = None
def dctv_dir():
    """Return the top-level DCTV distribution directory

    That is the directory two levels above the one holding this module.
    It is computed on first use and cached in DCTV_DIR.
    """
    global DCTV_DIR
    cached = DCTV_DIR
    if cached:
        return cached
    module_dir = dirname(__file__)
    cached = normpath(pjoin(module_dir, "..", ".."))
    DCTV_DIR = cached
    return cached
| |
# The type of None (types.NoneType only exists from Python 3.10 on).
NoneType = type(None)
| |
| |
| # Numpy utilities |
| |
# Using these ends up being much faster than calling np.dtype
# (np.dtype(np.int64) costs around 158ns.)
# Pre-built numpy dtype objects for the common scalar types.
INT8 = np.dtype(np.int8)
UINT8 = np.dtype(np.uint8)
INT16 = np.dtype(np.int16)
UINT16 = np.dtype(np.uint16)
INT32 = np.dtype(np.int32)
UINT32 = np.dtype(np.uint32)
INT64 = np.dtype(np.int64)
UINT64 = np.dtype(np.uint64)
BOOL = np.dtype(np.bool_)
FLOAT32 = np.dtype(np.float32)
FLOAT64 = np.dtype(np.float64)
| |
def xabspath(path):
    """Make PATH absolute using the cached startup working directory.

    Like os.path.abspath, but relative paths are resolved against
    STARTUP_CWD rather than the current directory.  Unlike abspath, the
    result is not normalized, and absolute paths are returned unchanged.
    """
    return path if isabs(path) else pjoin(STARTUP_CWD, path)
| |
def ignore_numpy_errors():
    """Context manager that ignores all numpy errors

    Divide-by-zero, overflow, underflow, and invalid-operation
    floating-point conditions are all silenced inside the ``with`` block.
    """
    return np.errstate(all="ignore")
| |
def is_nondecreasing(arr):
    """Determine whether array ARR is nondecreasing.

    Returns a boolean (a numpy bool).  Empty and single-element arrays
    count as nondecreasing.
    """
    earlier, later = arr[:-1], arr[1:]
    return (earlier <= later).all()
| |
| |
| # SortedDict (Banyan replacement) |
| |
class SortedDict(ExplicitInheritance,
                 brain_suck_abc=collections.abc.MutableMapping):
    """Dictionary that maintains its keys in sorted order

    Iteration and keys() produce the keys in sorted order.  The sorted
    key tuple is built lazily and cached until the next mutation.
    """

    @override
    def __init__(self):
        self._data = {}
        self._sorted_keys = None  # Cached sorted keys; None when stale

    @override
    def __getitem__(self, key):
        return self._data[key]

    @override
    def __setitem__(self, key, value):
        self._data[key] = value
        self._sorted_keys = None

    @override
    def __delitem__(self, key):
        del self._data[key]
        self._sorted_keys = None

    @override
    def __iter__(self):
        return iter(self.keys())

    @override
    def __len__(self):
        # The count does not depend on order, so do not sort (which costs
        # O(n log n) and fails for keys that are not mutually orderable).
        return len(self._data)

    @override
    def keys(self):
        """Return the keys as a tuple in sorted order"""
        keys = self._sorted_keys
        if keys is None:
            self._sorted_keys = keys = tuple(sorted(self._data))
        return keys

    __slots__ = ["_data", "_sorted_keys"]
| |
| |
| # Sequence utilities |
| |
def all_unique(sequence):
    """Return whether all items in SEQUENCE are unique

    SEQUENCE must support len() and have hashable items.
    """
    total = len(sequence)
    return total == len(set(sequence))
| |
def all_same(sequence):
    """Return whether all items in SEQUENCE are equal

    An empty SEQUENCE counts as all-equal.  Each later item is compared
    to the first with ==.
    """
    items = list(sequence)
    return not items or all(item == items[0] for item in items[1:])
| |
def drain(iterable):
    """Deplete ITERABLE. Like list(map(...)), but without storing results.

    Returns None.
    """
    # A zero-length deque consumes its input at C speed and keeps nothing.
    collections.deque(iterable, maxlen=0)
| |
def partition(predicate, sequence):
    """Split a sequence by a predicate.

    Return a tuple (YES_SEQUENCE, NO_SEQUENCE) of lazy iterators over
    the elements of SEQUENCE that did or did not, respectively, match
    the predicate.  SEQUENCE is read only once.
    """
    for_yes, for_no = tee(sequence, 2)
    matches = filter(predicate, for_yes)
    rejects = filterfalse(predicate, for_no)
    return matches, rejects
| |
def _assert_seq_elements(elem_type, seq):
    """Check that every element of SEQ is an instance of ELEM_TYPE.

    When __debug__ is true, raise AssertionError listing every offending
    element.  Always return True, so the call can sit inside an
    ``assert`` statement.
    """
    if __debug__:
        # Single pass: SEQ may be a one-shot iterator, so collect the
        # offenders instead of scanning once to detect and again to report.
        bad = [elem for elem in seq if not isinstance(elem, elem_type)]
        if bad:
            raise AssertionError(
                "elements did not meet type requirement {!r}: {!r}"
                .format(elem_type, bad))
    return True
| |
def argmin(sequence):
    """Return the index of the smallest item in SEQUENCE

    Ties go to the earliest index; an empty SEQUENCE raises ValueError.
    """
    best_index, _ = min(enumerate(sequence), key=operator.itemgetter(1))
    return best_index
| |
def argmax(sequence):
    """Return the index of the largest item in SEQUENCE

    Ties go to the earliest index; an empty SEQUENCE raises ValueError.
    """
    best_index, _ = max(enumerate(sequence), key=operator.itemgetter(1))
    return best_index
| |
def common_prefix(compare, *seqs):
    """Lazily compute the common prefix of SEQS.

    The comparison is done with the COMPARE function, which should take
    as many arguments as there are in SEQS and return whether they are
    all equal.  The returned iterator yields the prefix items as they
    appear in the first of SEQS.
    """
    agreeing = takewhile(lambda items: compare(*items), zip(*seqs))
    return (items[0] for items in agreeing)
| |
def consume_list(alist):
    """Iterate over ALIST, destroying it as we go.

    Assumes ownership of ALIST. The function makes no guarantees about
    what ALIST looks like while we're iterating over it.  Items are
    yielded in their original order; ALIST is empty when iteration ends.
    """
    assert isinstance(alist, list)
    # Reverse once so each item can be taken from the cheap (tail) end.
    alist.reverse()
    take = alist.pop
    while alist:
        yield take()
| |
| |
| |
class CaseInsensitiveCasePreservingDict(
        ExplicitInheritance,
        brain_suck_abc=collections.abc.MutableMapping):
    """Case-preserving case-insensitive dict

    String keys are matched without regard to case (via str.lower());
    iteration yields each key spelled as it was most recently assigned.
    Lookups of absent keys go through __missing__, which subclasses may
    override.
    """

    __slots__ = ("_contents",)

    @override
    def __init__(self):
        # key.lower() -> (key as last assigned, value)
        self._contents = {}

    @override
    def __getitem__(self, key):
        try:
            _, value = self._contents[key.lower()]
        except KeyError:
            return self.__missing__(key)
        return value

    @override
    def __setitem__(self, key, value):
        self._contents[key.lower()] = (key, value)

    @override
    def __delitem__(self, key):
        del self._contents[key.lower()]

    @override
    def __iter__(self):
        yield from (spelling for spelling, _ in self._contents.values())

    @override
    def __len__(self):
        return len(self._contents)

    @override
    def clear(self):
        self._contents.clear()

    def __missing__(self, key):
        raise KeyError(key)
| |
def vard(*varnames):
    """Debug facility: dump VARNAMES from enclosing scope

    Each VARNAME in VARNAMES is a string naming a local variable of the
    caller.  Each value is logged at DEBUG level through the caller's
    module-level ``log`` if it has one, else through this module's.
    """
    caller = inspect.currentframe().f_back
    try:
        xlog = caller.f_globals.get("log", log)
        for varname in varnames:
            # Don't hit on greps for the three-X string
            xlog.debug("X""XX %s=%r", varname, caller.f_locals[varname])
    finally:
        del caller  # Drop our reference to the caller's frame promptly