|  | # Extra utilities for working with context managers that should have been | 
|  | # in the standard library but are not | 
|  |  | 
|  | import functools | 
|  | import inspect | 
|  | import warnings | 
|  | import sys | 
|  | from typing import Any, Callable, TypeVar, cast | 
|  |  | 
# Typing helpers for the decorator usage of _DecoratorContextManager (e.g.,
# 'no_grad' and 'enable_grad'): annotating `__call__` as `(F) -> F` tells the
# type checker that the decorated callable keeps its own type.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
|  |  | 
|  |  | 
def _wrap_generator(ctx_factory, func):
    """
    Wrap each generator invocation with the context manager factory.

    The input should be a function that returns a context manager,
    not a context manager itself, to handle one-shot context managers.

    Args:
        ctx_factory: zero-argument callable returning a context manager. It is
            called anew for every resumption of the wrapped generator
            (start-up, ``send``, ``throw`` and ``close``).
        func: the generator function to wrap.

    Returns:
        A generator function carrying ``func``'s metadata (``functools.wraps``).
        Its generators forward values, sent requests, thrown exceptions and
        closure to ``func``'s generator, and return its return value.
    """
    @functools.wraps(func)
    def generator_context(*args, **kwargs):
        gen = func(*args, **kwargs)

        # Generators are suspended and unsuspended at `yield`, hence we
        # make sure the context (e.g. the grad mode) is properly set every
        # time the execution flow returns into the wrapped generator and
        # restored when it returns through our `yield` to our caller
        # (see PR #49017).
        try:
            # Issuing `None` to a generator fires it up
            with ctx_factory():
                response = gen.send(None)

            while True:
                try:
                    # Forward the response to our caller and get its next request
                    request = yield response

                except GeneratorExit:
                    # Inform the still active generator about its imminent closure
                    with ctx_factory():
                        gen.close()
                    raise

                except BaseException as exc:
                    # Propagate the exception thrown at us by the caller.
                    # The single-argument form of `throw` is used because the
                    # `(type, value, traceback)` signature is deprecated since
                    # Python 3.12; the instance already carries its traceback.
                    with ctx_factory():
                        response = gen.throw(exc)

                else:
                    # Pass the last request to the generator and get its response
                    with ctx_factory():
                        response = gen.send(request)

        # We let the exceptions raised above by the generator's `.throw` or
        # `.send` methods bubble up to our caller, except for StopIteration
        except StopIteration as e:
            # The generator informed us that it is done: take whatever its
            # returned value (if any) was and indicate that we're done too
            # by returning it (see docs for python's return-statement).
            return e.value

    return generator_context
|  |  | 
|  |  | 
def context_decorator(ctx, func):
    """
    Like contextlib.ContextDecorator, but:

    1. Is done by wrapping, rather than inheritance, so it works with context
       managers that are implemented from C and thus cannot easily inherit from
       Python classes
    2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
    3. Errors out if you try to wrap a class, because it is ambiguous whether
       or not you intended to wrap only the constructor

    The input argument can either be a context manager (in which case it must
    be a multi-shot context manager that can be directly invoked multiple times)
    or a callable that produces a context manager.

    Args:
        ctx: a context manager, or a zero-argument callable returning one.
        func: the function or generator function to wrap.

    Returns:
        A wrapper carrying ``func``'s metadata that runs ``func`` inside the
        context. Generator functions are delegated to ``_wrap_generator``.

    Raises:
        AssertionError: if ``ctx`` is both callable and has ``__enter__``.
        RuntimeError: if ``func`` is a class.
    """

    # Raised explicitly instead of through an `assert` statement so that the
    # check is not stripped when Python runs with -O.
    if callable(ctx) and hasattr(ctx, '__enter__'):
        raise AssertionError(
            f"Passed in {ctx} is both callable and also a valid context manager "
            "(has __enter__), making it ambiguous which interface to use.  If you "
            "intended to pass a context manager factory, rewrite your call as "
            "context_decorator(lambda: ctx()); if you intended to pass a context "
            "manager directly, rewrite your call as context_decorator(lambda: ctx)"
        )

    # Normalise both accepted forms to a factory.
    if not callable(ctx):
        def ctx_factory():
            return ctx
    else:
        ctx_factory = ctx

    if inspect.isclass(func):
        raise RuntimeError(
            "Cannot decorate classes; it is ambiguous whether or not only the "
            "constructor or all methods should have the context manager applied; "
            "additionally, decorating a class at definition-site will prevent "
            "use of the identifier as a conventional type.  "
            "To specify which methods to decorate, decorate each of them "
            "individually."
        )

    # Generators need the context re-entered on every resumption, not just
    # around the call that creates the generator object.
    if inspect.isgeneratorfunction(func):
        return _wrap_generator(ctx_factory, func)

    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        with ctx_factory():
            return func(*args, **kwargs)

    return decorate_context
|  |  | 
|  |  | 
class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator.

    Subclasses provide ``__enter__`` and ``__exit__``; the versions here raise
    ``NotImplementedError``. Calling an instance on a function returns that
    function wrapped by ``context_decorator``, with ``self.clone`` as the
    context manager factory.
    """

    def __call__(self, orig_func: F) -> F:
        """Decorate ``orig_func`` so that it runs inside a clone of this context.

        Decorating a class is deprecated: a warning is issued and only the
        construction of the class is wrapped.
        """
        if inspect.isclass(orig_func):
            # stacklevel=2 attributes the warning to the decoration site in the
            # caller's code instead of to this line.
            warnings.warn(
                "Decorating classes is deprecated and will be disabled in "
                "future versions. You should only decorate functions or methods. "
                "To preserve the current behavior of class decoration, you can "
                "directly decorate the `__init__` method and nothing else.",
                stacklevel=2,
            )
            # context_decorator rejects classes, so hide the class behind a
            # plain function that constructs it.
            func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
        else:
            func = orig_func

        # `self` is both callable and a context manager, which
        # context_decorator refuses as ambiguous; pass the bound `clone`
        # method as the factory instead.
        return cast(F, context_decorator(self.clone, func))

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError

    def clone(self):
        """Return a new instance of this object's class, built with no arguments."""
        # Override this method if your subclass takes __init__ parameters.
        return self.__class__()