# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module contains the user-facing API for AutoGraph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from enum import Enum
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
# TODO(mdan): Properly document the type hints.
# TODO(mdan): Reduce the type hint information to (module, type).
# (currently we require (module + class name, type))
class ConversionOptions(
collections.namedtuple('ConversionOptions',
('recursive', 'verbose', 'strip_decorators',
'force_conversion', 'arg_types'))):
"""Container for conversion flags.
Attributes:
recursive: bool, whether to recursively convert any user functions or
classes that the converted function may use.
verbose: bool, whether to log the compiled code.
    strip_decorators: Tuple[Callable], decorators that should be excluded
      from the compiled output. By default, when a function is converted
      before its decorators are applied, the compiled output will include
      those decorators.
    force_conversion: bool, whether to force converting the target entity.
When force_conversion is turned off, the converter may decide to
return the function as-is.
arg_types: Optional[Dict[Text, Type]], type hints for symbols including
function arguments.
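
  Example (an illustrative sketch):

    opts = ConversionOptions.new(recursive=True, verbose=True)
    # Unset fields get safe defaults: strip_decorators becomes () and
    # arg_types becomes {}.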
"""
@classmethod
def new(cls,
recursive=False,
verbose=False,
strip_decorators=None,
force_conversion=False,
arg_types=None):
return cls(recursive=recursive,
verbose=verbose,
strip_decorators=strip_decorators or (),
force_conversion=force_conversion,
arg_types=arg_types or {})
# TODO(mdan): This should behave like to_graph (e.g. convert statically).
def convert(recursive=False, verbose=False):
"""Decorator that compiles a function to use TensorFlow ops.
  The decorator is dynamic - it recompiles the target whenever the decorated
  function is called. This means the parameter values are known at conversion
  time. It also means that repeated calls with different types of parameters
  will be correctly processed.
Args:
recursive: bool, whether to recursively convert any functions or classes
that the converted function may use.
verbose: bool, whether to output the compiled code in the logs.
Returns:
Callable, a decorator that converts the given function into an equivalent
function that uses TensorFlow ops.
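
  Example (an illustrative sketch; assumes this module is importable as
  `autograph` and that TensorFlow is available as `tf`):

    @autograph.convert(recursive=True)
    def square_if_positive(x):
      if x > 0:
        x = x * x
      else:
        x = 0.0
      return x

    # When called with a Tensor argument, the Python `if` above is
    # recompiled into TensorFlow control flow ops at call time.
    result = square_if_positive(tf.constant(2.0))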
"""
def decorator(f):
"""Decorator implementation."""
@functools.wraps(f)
def wrapper(*args, **kwargs):
return converted_call(
f,
ConversionOptions.new(
recursive=recursive,
verbose=verbose,
force_conversion=True,
), *args, **kwargs)
wrapper = tf_decorator.make_decorator(f, wrapper)
# Sometimes the decorator is just desugared, making it impossible to detect.
# This attribute makes detection easier.
setattr(wrapper, '__pyct_is_compile_decorator', True)
return wrapper
return decorator
class RunMode(Enum):
"""Specifies the way a converted function or method should be executed in TF.
The enum values have the following semantics:
* GRAPH: Call this function directly, as-is. This is suitable for functions
that were already designed for TF graphs and contain ops.
* PY_FUNC: Wrap this function into a py_func op. This is suitable for code
that will only run correctly in Python, for example code that renders
to the display, reads keyboard input, etc.
"""
GRAPH = 1
PY_FUNC = 2
def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
"""Decorator that suppresses the conversion of a function.
See also: docs/pyfunc_dtypes.md
Args:
run_as: RunMode, specifies how to use the function in TensorFlow.
return_dtypes: Optional[Iterable[
Union[tf.DType, utils.py_func.MatchDType]]], the return data types of
the converted function, if run_as is RunMode.PY_FUNC. Ignored otherwise.
May be set to None if the function has no return values.
Returns:
Callable, a decorator that wraps the original function.
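
  Example (an illustrative sketch; assumes this module is importable as
  `autograph` and that TensorFlow is available as `tf`):

    @autograph.do_not_convert(
        run_as=autograph.RunMode.PY_FUNC, return_dtypes=[tf.float32])
    def debug_double(x):
      print(x)  # Plain Python; only runs correctly outside a graph.
      return x * 2.0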
"""
def decorator(f):
"""Decorator implementation."""
@functools.wraps(f)
def graph_wrapper(*args, **kwargs):
return f(*args, **kwargs)
@functools.wraps(f)
def py_func_wrapper(*args, **kwargs):
if kwargs:
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
# TODO(mdan): Add support for kwargs.
return py_func.wrap_py_func(
f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
if run_as == RunMode.GRAPH:
wrapper = graph_wrapper
elif run_as == RunMode.PY_FUNC:
wrapper = py_func_wrapper
else:
raise ValueError('unknown value for run_as: %s' % run_as)
# Sometimes the decorator is just desugared, making it impossible to detect.
# This attribute makes detection easier.
setattr(wrapper, '__pyct_is_compile_decorator', True)
return wrapper
return decorator
# TODO(mdan): Move to a private, undocumented module.
def converted_call(f, options, *args, **kwargs):
"""Compiles a function call inline. For internal use only."""
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
return f(*args, **kwargs)
unknown_arg_value = object() # Sentinel for arguments of unknown value
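  # Builtin functions (e.g. len, print) are dispatched to their AutoGraph
  # overloads instead of being compiled.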
if inspect_utils.isbuiltin(f):
return py_builtins.overload_of(f)(*args, **kwargs)
if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
# Regular functions
target_entity = f
arg_map_target = f
effective_args = args
f_class = inspect_utils.getmethodclass(f)
if f_class is not None:
partial_types = (f_class,)
else:
partial_types = ()
elif tf_inspect.isclass(f):
# Constructors
target_entity = f
arg_map_target = f.__init__
effective_args = args
partial_types = ()
elif hasattr(f, '__call__') and hasattr(f, '__class__'):
# Callable objects
target_entity = f.__call__
arg_map_target = f.__call__
effective_args = (f,) + args
partial_types = (f.__class__,)
else:
    raise NotImplementedError('unknown callable type "%s"' % type(f))
arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
for name, arg in arg_values.items():
if arg is unknown_arg_value:
continue
arg_class = arg.__class__
    # If arg_types already specifies a hint for this name, keep it;
    # otherwise infer one from the argument's class.
if name not in options.arg_types:
options.arg_types[name] = (arg_class.__name__, arg_class)
# When called from within a decorator, this is the only indication that
# the function is a method - it appears that the decorator is applied
# before the method is bound.
if not partial_types:
if 'self' in arg_values:
if tf_inspect.isclass(arg_values['self'].__class__):
partial_types = (arg_values['self'].__class__,)
elif 'cls' in arg_values:
if tf_inspect.isclass(arg_values['cls']):
partial_types = (arg_values['cls'],)
converted_f = to_graph(
target_entity,
recursive=options.recursive,
verbose=options.verbose,
arg_values=arg_values,
arg_types=options.arg_types,
partial_types=partial_types,
strip_decorators=options.strip_decorators)
return converted_f(*effective_args, **kwargs)
# TODO(mdan): Rename: to_ops?
# TODO(mdan): Look into overloading as function and decorator, like tfe.defun?
# TODO(mdan): Remove partial_types.
def to_graph(e,
recursive=True,
verbose=False,
arg_values=None,
arg_types=None,
partial_types=None,
strip_decorators=None):
"""Converts a Python entity into equivalent code that uses TensorFlow ops.
Supported Python entities include:
* functions
* classes
Classes are converted by converting all their methods into a new class.
Args:
e: Union[Callable, Type], the Python entity to convert.
recursive: bool, whether to recursively convert any functions that the
converted function may call.
verbose: bool, whether to output the compiled code in the logs.
arg_values: Optional[Dict[Text, Any]], value hints for symbols including
function arguments.
arg_types: Optional[Dict[Text, Type]], type hints for symbols including
function arguments.
partial_types: Set[Type], reserved for internal use.
strip_decorators: Tuple[Callable], same as
ConversionOptions.strip_decorators.
Returns:
    Union[Callable, Type], the converted entity, which is the same kind as e
    (that is, a function if e is a function, a class if e is a class, etc.),
    but whose code has been converted to use TF ops.
Raises:
ValueError: If the entity could not be converted.
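
  Example (an illustrative sketch; assumes TensorFlow is available as `tf`):

    def f(x):
      if x > 0:
        x = x * x
      return x

    converted_f = to_graph(f)
    # converted_f has the same signature as f, but builds TF ops when
    # called with Tensor arguments.
    result = converted_f(tf.constant(2.0))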
"""
if strip_decorators is None:
strip_decorators = ()
strip_decorators += (convert, do_not_convert, converted_call)
program_ctx = converter.ProgramContext(
recursive=recursive,
autograph_decorators=strip_decorators,
partial_types=partial_types,
autograph_module=tf_inspect.getmodule(to_graph),
uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
_, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
arg_types)
nodes = []
for dep in reversed(program_ctx.conversion_order):
nodes.extend(program_ctx.dependency_cache[dep])
compiled_module, compiled_src = compiler.ast_to_object(
nodes,
source_prefix=program_ctx.required_imports,
include_source_map=True)
# The compiled code should see everything the entry entity saw.
# TODO(mdan): This might not work well if the call tree spans modules?
for key, val in namespace.items():
# Avoid overwriting entities that have been transformed.
if key not in compiled_module.__dict__:
compiled_module.__dict__[key] = val
compiled = getattr(compiled_module, name)
# Need this so the source_mapping attribute is available for the context
# manager to access for runtime errors.
#
# Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
# symbol to the compiled module.
# TODO(mdan): Record this statically in the generated code.
# TODO(mdan): Rename this attribute to 'autograph_info__'
source_map_attribute_name = 'ag_source_map'
if getattr(compiled, source_map_attribute_name, None) is not None:
    raise ValueError('cannot convert %s because it has an attribute '
                     '"%s", which is reserved for AutoGraph.' %
                     (compiled, source_map_attribute_name))
setattr(compiled, source_map_attribute_name,
compiled_module.__dict__['ag_source_map__'])
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
return compiled
def to_code(e,
recursive=True,
arg_values=None,
arg_types=None,
partial_types=None,
indentation=' '):
"""Returns the equivalent code that uses TensorFlow ops.
Also see: `to_graph`, `convert`
Args:
e: Union[Callable, Type], the Python entity to convert.
recursive: bool, whether to recursively convert any functions that the
converted function may call.
arg_values: Optional[Dict[Text, Any]], value hints for symbols including
function arguments.
arg_types: Optional[Dict[Text, Type]], type hints for symbols including
function arguments.
partial_types: Set[Type], reserved for internal use.
    indentation: Text, what to use for each level of indentation.
Returns:
Text, the converted code.
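
  Example (an illustrative sketch):

    def f(x):
      if x > 0:
        x = -x
      return x

    # Prints the generated Python source, prefixed by its required imports.
    print(to_code(f))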
"""
program_ctx = converter.ProgramContext(
recursive=recursive,
autograph_decorators=(convert, do_not_convert, converted_call),
partial_types=partial_types,
autograph_module=tf_inspect.getmodule(to_graph),
uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
for dep in reversed(program_ctx.conversion_order))
return program_ctx.required_imports + '\n\n' + code