| # mypy: allow-untyped-defs |
| from functools import wraps |
| from inspect import unwrap |
| from typing import Callable, List, Optional |
| import logging |
| |
# Module-level logger; `log_hook` reports pass return values through it.
logger = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]
| |
| # for callables which modify object inplace and return something other than |
| # the object on which they act |
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    The wrapped pass is still called for its side effects; whatever it
    returns is discarded.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place

    Returns:
        wrapped_fn (Callable[Object, Object]): calls `fn` on the object and
            then returns that same object
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # Only the in-place mutation matters here; the return value of `fn`
        # is intentionally dropped.
        fn(gm)
        return gm

    return wrapped_fn
| |
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Wrap a callable so that each call logs its return value.

    Intended for logging what a pass returns. Because inplace_wrapper
    replaces a pass's return value with the modified object, apply log_hook
    first (innermost) when the original return value is the thing to log.

    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose output should be logged
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): calls `fn`, logs the callable
            together with its result, and returns the result unchanged
    """

    @wraps(fn)
    def wrapped_fn(gm):
        result = fn(gm)
        # Lazy %-style arguments: formatting only happens if the record is
        # actually emitted at `level`.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    return wrapped_fn
| |
| |
| |
def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None):
    """
    Build a pass that applies `base_pass` repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to apply `base_pass`. A
            non-positive value makes the returned pass raise when called.
        predicate (Callable[Object, bool], optional): evaluated on the
            current value before each application; looping stops as soon as
            it returns a falsy value.

    Returns:
        new_pass (Callable[Object, Object]): the looping pass

    Raises:
        AssertionError: if both or neither of `n_iter` and `predicate` are
            given (checked when the pass is built)
    """
    assert (n_iter is None) != (
        predicate is None
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        # Fixed-count mode: only taken for a strictly positive n_iter.
        if n_iter is not None and n_iter > 0:
            result = source
            for _ in range(n_iter):
                result = base_pass(result)
            return result

        # Neither a usable count nor a predicate: nothing sensible to run.
        if predicate is None:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )

        # Predicate mode: keep applying while the predicate holds.
        result = source
        while predicate(result):
            result = base_pass(result)
        return result

    return new_pass
| |
| |
| # Pass Schedule Constraints: |
| # |
| # Implemented as 'depends on' operators. A constraint is satisfied iff a list |
| # has a valid partial ordering according to this comparison operator. |
def _validate_pass_schedule_constraint(
    constraint: Callable[[Callable, Callable], bool], passes: List[Callable]
) -> None:
    """
    Check every ordered pair of passes against a single constraint.

    `constraint(a, b)` is evaluated for each pair in which `a` appears
    earlier in `passes` than `b`. A falsy result means that ordering is not
    allowed.

    Args:
        constraint: callable taking two passes and returning a truthy value
            when the first may be scheduled before the second
        passes: the pass schedule, in execution order

    Raises:
        RuntimeError: on the first pair whose ordering the constraint
            rejects. The message reports both passes and their positions in
            `passes`.
    """
    for i, a in enumerate(passes):
        # `start` keeps j an index into `passes` itself rather than into the
        # slice, so the positions reported below are the real ones.
        for j, b in enumerate(passes[i + 1:], start=i + 1):
            if constraint(a, b):
                continue
            # `a` was found ahead of `b` and that was rejected, so the
            # required order is `b` first.
            raise RuntimeError(
                f"pass schedule constraint violated. Expected {b} before {a}"
                f" but found {a} at index {i} and {b} at index {j} in pass"
                f" list."
            )
| |
| |
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a pairwise ordering check requiring `this` to occur before `that`.

    The returned function takes two passes `(a, b)` and returns False only
    when `a == that` and `b == this`; every other pair yields True.
    """

    def depends_on(a: Callable, b: Callable):
        wrong_order = a == that and b == this
        return not wrong_order

    return depends_on
| |
| |
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a pairwise ordering check requiring `these` to occur before
    `those`, comparing passes only after they have been 'unwrapped'.

    For example, the following pass list and constraint list would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only
            when the unwrapped first argument equals `those` and the
            unwrapped second argument equals `these`
    """

    def depends_on(a: Callable, b: Callable):
        # `and` short-circuits, so `b` is only unwrapped when `a` matches.
        wrong_order = unwrap(a) == those and unwrap(b) == these
        return not wrong_order

    return depends_on
| |
| |
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), where A
            is scheduled before B, and returns a falsy value when that
            ordering is not allowed. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    passes: List[Callable]
    constraints: List[Callable]
    # Class-level default; each instance overrides it once the schedule has
    # been validated or modified through the methods below.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """Alternate constructor: build a manager from a list of passes."""
        # Instantiate through `cls` so subclasses get an instance of their
        # own type rather than a plain PassManager.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append a pass to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional ordering constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: Optional[List[str]]):
        """
        Drop every pass whose `__name__` is listed in `_passes`.

        Does nothing when `_passes` is None.
        """
        if _passes is None:
            return
        self.passes = [ps for ps in self.passes if ps.__name__ not in _passes]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """
        Substitute `_replacement` for every pass whose `__name__` matches
        that of `_target`; all other passes keep their position.
        """
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`
        """
        # Skip the pairwise checks when nothing changed since the last run.
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then run each pass in order on `source`."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out