| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Core conversion logic, serves as main point of access.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import functools |
| import inspect |
| import sys |
| import unittest |
| |
| from tensorflow.python.autograph.core import config |
| from tensorflow.python.autograph.pyct import cache |
| from tensorflow.python.autograph.pyct import inspect_utils |
| from tensorflow.python.autograph.utils import ag_logging as logging |
| from tensorflow.python.eager import function |
| from tensorflow.python.util import tf_inspect |
| |
| |
# Module-level cache of entities already found to be allowlisted. It is
# indexed first by entity and then by options; see cache_allowlisted (writer)
# and is_in_allowlist_cache (reader).
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
| |
| |
def _is_of_known_loaded_module(f, module_name):
  """Tells whether `f` is a member of an already-loaded module.

  Only modules currently present in `sys.modules` are considered; nothing is
  imported here. Membership is decided by identity against the values of the
  module's namespace, and `None` is never reported as a member.

  Args:
    f: The object to look for.
    module_name: Name under which the module is registered in `sys.modules`.

  Returns:
    True if the module is loaded and one of its namespace values is `f`
    (and `f` is not `None`); False otherwise.
  """
  loaded = sys.modules.get(module_name)
  if loaded is None:
    return False
  for member in loaded.__dict__.values():
    # Identity, not equality: we want the very same object, and a `None`
    # placeholder in the namespace must not count as a match.
    if member is f and member is not None:
      return True
  return False
| |
| |
def _is_known_loaded_type(f, module_name, entity_name):
  """Tests whether the function or method is an instance of a known type.

  The type is looked up as attribute `entity_name` of the module registered
  under `module_name` in `sys.modules`. Nothing is imported: if the module is
  not loaded, or does not have that attribute, the answer is False.

  Args:
    f: The function, method or other object to test.
    module_name: Name of the module that defines the type.
    entity_name: Attribute name of the type inside that module.

  Returns:
    True if `f` is an instance of the type, or if `f` is a bound method whose
    underlying `__func__` is an instance of the type; False otherwise.
  """
  if module_name not in sys.modules:
    return False
  owner_module = sys.modules[module_name]
  if not hasattr(owner_module, entity_name):
    return False
  known_type = getattr(owner_module, entity_name)

  # The object itself is of this type. Example:
  #
  #   o = ClassType()
  #   function(o.method)()
  if isinstance(f, known_type):
    return True

  # The function underneath a bound method is of this type. Example:
  #
  #   class ClassType:
  #     @function
  #     def method(self):
  #       ...
  #   o = ClassType()
  #   o.method()
  #
  # Note: inspect is required here, to avoid unpacking tf.function decorators.
  return inspect.ismethod(f) and isinstance(f.__func__, known_type)
| |
| |
def is_unsupported(o):
  """Checks whether an entity is supported by AutoGraph at all.

  The checks run in a fixed order and the first match wins; each match is
  logged at verbosity 2 (the wrapt case additionally emits a warning).

  Args:
    o: A Python entity, typically a callable.

  Returns:
    True if `o` falls in one of the permanently allowed categories below
    (wrapt-decorated, `functools.lru_cache` wrapper, constructor, member of
    selected standard modules, TensorFlow plugin); False otherwise.
  """

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
    logging.warning(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(o))
    logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
    return True

  if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently allowed: %s: lru_cache', o)
    return True

  # Constructors are permanently allowed.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(o):
    logging.log(2, 'Permanently allowed: %s: constructor', o)
    return True

  # Other built-in modules are permanently allowed.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      _is_of_known_loaded_module(o, m)
      for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
    logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
    return True

  # Custom ops and kernels are also permanently allowed.
  # See tensorflow.framework.load_library.
  if hasattr(o, '__module__'):
    plugin_owner = o.__module__
    # `__module__` normally holds the *name* of the defining module, and a
    # plain string never carries the marker attribute. Resolve the name to the
    # loaded module object so the marker can actually be found. Non-string
    # values (and names that are not loaded) are checked as they are.
    # NOTE(review): assumes the marker is set on the module object itself -
    # confirm against load_library.
    if isinstance(plugin_owner, str):
      plugin_owner = sys.modules.get(plugin_owner, plugin_owner)
    if hasattr(plugin_owner, '_IS_TENSORFLOW_PLUGIN'):
      logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
      return True

  return False
| |
| |
| # TODO(mdan): allow_namedtuple_subclass should be hardcoded to True. |
def is_allowlisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is allowed for use in graph mode.

  Examples of allowed entities include all members of the tensorflow
  package.

  The rules are tried in order and the first one that applies decides:
  configured conversion rules for the entity's module, generator functions,
  callable objects whose `__call__` is allowed, methods (TestCase subclasses,
  then the allow status of the defining class), and named tuples. Anything
  that matches no rule is not allowed.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which classes are allowed if their
      __call__ method is allowed.
    allow_namedtuple_subclass: Reserved for internal use. When `True`, only
      named tuple types with no named tuple among their direct bases are
      allowed by the named tuple rule; subclasses of named tuples are not.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  # functools.partial objects do not have a __module__ attribute, so
  # tf_inspect.getmodule would return None for them; attribute them to
  # functools instead.
  module = (
      functools if isinstance(o, functools.partial)
      else tf_inspect.getmodule(o))

  # Examples of callables that lack a __module__ property include builtins;
  # for those the configured rules are skipped.
  if hasattr(module, '__name__'):
    for conversion_rule in config.CONVERSION_RULES:
      verdict = conversion_rule.get_action(module)
      if verdict == config.Action.CONVERT:
        logging.log(2, 'Not allowed: %s: %s', o, conversion_rule)
        return False
      if verdict == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Allowlisted: %s: %s', o, conversion_rule)
        return True

  # isgeneratorfunction crashes on objects without __code__, hence the guard.
  if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
    logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
    return True

  # Callable objects: allowed if their __call__ method is.
  call_rule_applies = (
      check_call_override and not tf_inspect.isclass(o) and
      hasattr(o, '__call__'))
  # The type comparison avoids infinite recursion around the __call__ method
  # of function objects.
  if (call_rule_applies and
      type(o) != type(o.__call__) and  # pylint: disable=unidiomatic-typecheck
      is_allowlisted(o.__call__)):
    logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
    return True

  if tf_inspect.ismethod(o):
    # Methods of allowed classes are also allowed, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
    # allowed.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be allowed.
    method_owner = inspect_utils.getmethodclass(o)
    if method_owner is function.TfMethodTarget:
      method_owner = o.__self__.target_class

    if method_owner is not None:
      if issubclass(method_owner, unittest.TestCase):
        logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
        return True

      # Judge the method by the class that actually defines it.
      defining_class = inspect_utils.getdefiningclass(o, method_owner)
      owner_allowed = is_allowlisted(
          defining_class,
          check_call_override=False,
          allow_namedtuple_subclass=True)
      if owner_allowed:
        logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
                    defining_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if not allow_namedtuple_subclass:
      logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
      return True
    if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
      logging.log(2, 'Allowlisted: %s: named tuple', o)
      return True

  logging.log(2, 'Not allowed: %s: default rule', o)
  return False
| |
| |
def is_in_allowlist_cache(entity, options):
  """Reports whether `entity` is recorded in the allowlist cache for `options`.

  Args:
    entity: The Python entity to look up.
    options: The options value the entity was cached under.

  Returns:
    Whatever the cache's `has` lookup returns, or False when the lookup raises
    TypeError.
  """
  try:
    found = _ALLOWLIST_CACHE.has(entity, options)
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    return False
  else:
    return found
| |
| |
def cache_allowlisted(entity, options):
  """Records `entity` as allowlisted under `options` in the allowlist cache.

  This is best-effort: if the cache raises TypeError for the entity, nothing
  is stored and no error is raised.

  Args:
    entity: The Python entity to record.
    options: The options value to record the entity under.
  """
  try:
    per_entity = _ALLOWLIST_CACHE[entity]
    per_entity[options] = True
  except TypeError:
    # Catch-all for entities that are unhashable or don't allow weakrefs.
    pass