blob: 9d9cf0b50c380a6b25b11b4ef2b3129ed50e6bde [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Defun decorator for defining graph-mode functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import itertools
import pprint
import threading
import types as types_lib
import weakref
import numpy as np
import six
from six.moves import map
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import _pywrap_utils
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import forwardprop_util
from tensorflow.python.eager import monitoring
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace
from tensorflow.python.saved_model import save_context
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
# Loaded lazily due to a circular dependency (roughly
# tf.function->autograph->dataset->tf.function).
# TODO(b/133251390): Use a regular import.
ag_ctx = lazy_loader.LazyLoader(
"ag_ctx", globals(),
"tensorflow.python.autograph.core.ag_ctx")
# Also resolved through LazyLoader rather than a regular top-level import.
np_arrays = lazy_loader.LazyLoader(
"np_arrays", globals(),
"tensorflow.python.ops.numpy_ops.np_arrays")
# Names of function attributes used by this module. The forward/backward names
# are attached to generated backward/forward functions so that each records the
# name of its counterpart; the "_implements" attribute is dropped from generated
# forward/backward functions (see _construct_forward_backward).
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
IMPLEMENTS_ATTRIBUTE_NAME = "_implements"
SHARED_RENDEZVOUS_ATTRIBUTE_NAME = "shared_rendezvous"
# Monitoring counter; per its description it accumulates the time tf.function
# spends building graphs, in microseconds.
_graph_building_time_counter = monitoring.Counter(
"/tensorflow/core/tf_function/graph_building_time_usecs",
"Time for tf.function to build a graph (us).")
def _make_input_signature_hashable(elem):
"""Rewrite input signature to be hashable.
We replace nested variables in the input signature with TensorSpec in order to
be hashable.
Args:
elem: Input signature element
Returns:
A hashable object for the requested input signature
"""
try:
hash(elem)
except TypeError:
# TODO(slebedev): consider using nest.
if isinstance(elem, tuple):
return tuple(map(_make_input_signature_hashable, elem))
# TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
# all recognized types to be hashable.
assert isinstance(elem, weakref.ReferenceType)
v = elem()
if resource_variable_ops.is_resource_variable(v):
# We special case variables here to use unique_id as the cache key. This
# ensures we have to retrace whenever a different variable is passed in.
# This is needed to support cases where the user may use the id of a
# variable in the function perhaps as a lookup in a dictionary.
#
# This choice leads to more retracing when we could have possibly used the
# shape and dtype instead. However, we expect the number of variables in a
# program to be bounded, and correspondingly the number of retraces.
#
# Note we also include the class name to avoid collisions with strings.
return v.__class__, v._unique_id # pylint: disable=protected-access
if _is_ndarray(v):
# Numpy arrays are not hashable, but when calling functions we treat them
# in the same way as tf.Tensors.
if not hasattr(v, "shape") or not hasattr(v, "dtype"):
# TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
v = _as_ndarray(v)
return tensor_spec.TensorSpec(v.shape, v.dtype)
raise ValueError("Arguments to a tf.function must be Tensors, Variables, "
"or hashable Python objects (or nested structures of "
"these types).\nGot type: %s" % type(v).__name__)
return elem
# Immutable record bundling the seven components of a cache key. The code that
# builds and consumes these keys is outside this section of the file.
CacheKey = collections.namedtuple(
    "CacheKey",
    (
        "input_signature",
        "parent_graph",
        "device_functions",
        "colocation_stack",
        "in_cross_replica_context",
        "variable_policy",
        "xla_context_id",
    ))
def _type_spec_for(x):
  """Returns a TypeSpec for `x`, or `None` if `x` doesn't have a TypeSpec.

  Tensors yield a `TensorSpec` built from the tensor, a `TypeSpec` is returned
  as is, and a `CompositeTensor` yields its own `_type_spec`.
  """
  if isinstance(x, ops.Tensor):
    return tensor_spec.TensorSpec.from_tensor(x)
  if isinstance(x, type_spec.TypeSpec):
    return x
  if isinstance(x, composite_tensor.CompositeTensor):
    return x._type_spec  # pylint: disable=protected-access
  return None
def _is_type_subset(a, b):
"""Returns true if TypeSpec `b` is a subset of type `a` (or if a is None.)"""
if a is None:
return True
else:
return a.most_specific_compatible_type(b) == a
def _shape_relaxed_type_for_composite_tensor(x):
  """Returns a shape-relaxed TypeSpec for x (if composite) or x (if not)."""
  if not isinstance(x, composite_tensor.CompositeTensor):
    return x
  # pylint: disable=protected-access
  return x._type_spec._with_tensor_ranks_only()
def common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: A `TensorShape`, or `None` if the associated input was not a Tensor.
    y: A `TensorShape`, or `None` if the associated input was not a Tensor.

  Returns:
    `None` when both arguments are `None`. Otherwise a `TensorShape`: fully
    unknown when the ranks differ or are unknown, else a shape of the same rank
    in which each dimension is kept only where both inputs agree on a known
    value and is `None` elsewhere.

  Raises:
    RuntimeError: If exactly one of `x` and `y` is `None`.
    TypeError: If a non-`None` argument is not a `TensorShape`.
  """
  # The parentheses are required. Without them `x is None != y is None` is a
  # chained comparison, `(x is None) and (None != y) and (y is None)`, which is
  # always False and would let a None/non-None mismatch slip through.
  if (x is None) != (y is None):
    raise RuntimeError(
        "Cannot find a common shape when LHS shape is None but RHS shape "
        "is not (or vice versa): %s vs. %s" % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      # Disagreeing or unknown dimensions are relaxed to unknown.
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
def is_same_structure(structure1, structure2, check_values=False):
  """Check two structures for equality, optionally of types and of values.

  Args:
    structure1: First nested structure.
    structure2: Second nested structure.
    check_values: When true, the flattened leaves must additionally have
      pairwise identical types and compare equal.

  Returns:
    False if `nest.assert_same_structure` rejects the pair (with composites
    expanded) or, when `check_values` is set, if the leaves differ; True
    otherwise.
  """
  try:
    nest.assert_same_structure(structure1, structure2, expand_composites=True)
  except (ValueError, TypeError):
    return False
  if not check_values:
    return True
  leaves1 = nest.flatten(structure1, expand_composites=True)
  leaves2 = nest.flatten(structure2, expand_composites=True)
  # First check the types to avoid AttributeErrors.
  for leaf1, leaf2 in zip(leaves1, leaves2):
    if type(leaf1) != type(leaf2):
      return False
  return leaves1 == leaves2
def _parse_func_attrs(attributes):
"""Convert the keyword arguments into function_def attributes.
Currently only support primitive types: bool, int, float and string.
Args:
attributes: the dictionary of attributes.
Returns:
A dict of attributes where the key is the name of attribute and the value
is the AttrValue proto.
Raises:
ValueError: If the kwargs contains unallowlisted name or unsupported value
types.
"""
attrs = {}
for key, value in attributes.items():
if isinstance(value, attr_value_pb2.AttrValue):
attrs[key] = value
# bool type check has to happen before int since bool is a subclass of int.
elif isinstance(value, bool):
attrs[key] = attr_value_pb2.AttrValue(b=value)
elif isinstance(value, int):
attrs[key] = attr_value_pb2.AttrValue(i=value)
elif isinstance(value, float):
attrs[key] = attr_value_pb2.AttrValue(f=value)
elif isinstance(value, (str, bytes, six.text_type)):
attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
else:
raise ValueError("Unsupported attribute type for %s with type %s" %
(key, type(value)))
return attrs
class _InterpolateFunctionError(object):
"""Context Manager that interpolates the exception from 'top_level_func'."""
__slots__ = ["_func"]
def __init__(self, top_level_func):
self._func = top_level_func
def __enter__(self):
pass
def __exit__(self, typ, exc, tb):
if not exc or not isinstance(exc, errors.OpError):
return False
message = compat.as_text(exc.message)
_, tags = error_interpolation.parse_message(message)
g = None
func_stack = []
for t in tags:
if t.type == "function_node":
# TODO(mdan): Tests should cover this.
if t.name == compat.as_str(self._func.name):
g = self._func.graph
elif g:
next_func = g._get_function(t.name) # pylint: disable=protected-access
if next_func is not None and isinstance(next_func,
_EagerDefinedFunction):
g = next_func.graph
if g:
func_stack.append(g.name)
else:
func_stack.append("<unknown>")
if g:
message = error_interpolation.interpolate(message, g)
message += "\n\nFunction call stack:\n"
message += " -> ".join(func_stack)
message += "\n"
exc._message = message # pylint: disable=protected-access
return False
_function_callbacks = set()
def add_function_callback(function_callback):
"""Add a callback function for Function creation.
The callback function has the signature:
`def function_callback(function):`
wherein `function` is the just-created _EagerDefinedFunction.
The callback is invoked immediately after a new `_EagerDefinedFunction`
is created. The return value(s) of the callback function (if any) is ignored.
Repeated registration of the same callback function is idempotent.
After a callback is added, it can be removed with the
`remove_function_callback()` method.
Args:
function_callback: The callback to add.
"""
_function_callbacks.add(function_callback)
def remove_function_callback(function_callback):
  """Unregisters a previously registered function callback.

  Refer to `add_function_callback()` for what callbacks are and when they run.

  Args:
    function_callback: The callback to remove.

  Raises:
    KeyError: If `function_callback` is not currently registered.
  """
  _function_callbacks.remove(function_callback)
def clear_function_callbacks():
  """Clear all function callbacks, if any have been registered."""
  _function_callbacks.clear()
# Prefixes prepended to generated function names by the _forward_name,
# _backward_name and _inference_name helpers in this module.
_FORWARD_PREFIX = "__forward_"
_BACKWARD_PREFIX = "__backward_"
_INFERENCE_PREFIX = "__inference_"
def _forward_name(n):
  """Builds the name of a generated forward defun for the function named n.

  The result is the forward prefix, `n`, an underscore and a fresh `ops.uid()`.
  """
  unique_suffix = ops.uid()
  return "%s%s_%s" % (_FORWARD_PREFIX, n, unique_suffix)
def _backward_name(n):
  """Builds the name of a generated backward defun for the function named n.

  The result is the backward prefix, `n`, an underscore and a fresh
  `ops.uid()`.
  """
  unique_suffix = ops.uid()
  return "%s%s_%s" % (_BACKWARD_PREFIX, n, unique_suffix)
def _inference_name(n):
  """Builds the name of a forward-but-no-gradient defun for the function n.

  The result is the inference prefix, `n`, an underscore and a fresh
  `ops.uid()`.
  """
  unique_suffix = ops.uid()
  return "%s%s_%s" % (_INFERENCE_PREFIX, n, unique_suffix)
def _enclosing_xla_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite().

  Starting at the default graph, each graph's chain of control flow contexts
  is searched outwards; if no XLAControlFlowContext is found, the search moves
  to the graph's `outer_graph`. Returns `None` when no graph has one.
  """
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    ctx = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while (ctx is not None and
           not isinstance(ctx, control_flow_ops.XLAControlFlowContext)):
      ctx = ctx.outer_context
    if ctx is not None:
      return ctx
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None
class _EagerDefinedFunctionDeleter(object):
"""Unregister function from eager context."""
__slots__ = ["name"]
def __init__(self, name):
self.name = name
def __del__(self):
try:
context.remove_function(self.name)
except TypeError:
# Suppress some exceptions, mainly for the case when we're running on
# module deletion. Things that can go wrong include the context module
# already being unloaded, self._handle._handle_data no longer being
# valid, and so on. Printing warnings in these cases is silly
# (exceptions raised from __del__ are printed as warnings to stderr).
pass # 'NoneType' object is not callable when the handle has been
# partially unloaded.
except AttributeError:
pass # 'NoneType' object has no attribute 'eager_mode' when context has
# been unloaded. Will catch other module unloads as well.
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
class _EagerDefinedFunction(object):
"""Callable with the interface of `framework.function._DefinedFunction`.
`_EagerDefinedFunction` encapsulates a function definition and its properties,
and it provides a method for calling the encapsulated function. Some Ops
take functions as attributes, which have type `func`; an instance of this
class may be provided as the value of these `func` attributes.
"""
def __init__(self, name, graph, inputs, outputs, attrs):
"""Initializes an eager defined function.
Args:
name: str, the name for the created function.
graph: Graph, the graph containing the operations in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
attrs: dict mapping names of attributes to their AttrValue values
"""
input_ops = set(arg.op for arg in inputs)
operations = [op for op in graph.get_operations() if op not in input_ops]
graph_output_names = graph._output_names # pylint: disable=protected-access
if (graph_output_names is not None and
all(ops.tensor_id(t) in graph_output_names for t in outputs)):
output_names = [
compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
]
if len(set(output_names)) != len(output_names):
# There are duplicate names for some reason, probably an invalid
# signature. Revert to auto-naming.
output_names = []
else:
output_names = []
fn = pywrap_tf_session.TF_GraphToFunction_wrapper(
graph._c_graph, # pylint: disable=protected-access
compat.as_str(name),
False,
[o._c_op for o in operations], # pylint: disable=protected-access
[t._as_tf_output() for t in inputs], # pylint: disable=protected-access
[t._as_tf_output() for t in outputs], # pylint: disable=protected-access
output_names,
[o._c_op for o in graph.control_outputs], # pylint: disable=protected-access
[], # control_output_names
None,
compat.as_str(""))
for name, attr_value in attrs.items():
serialized = attr_value.SerializeToString()
# TODO(iga): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use status.
pywrap_tf_session.TF_FunctionSetAttrValueProto(fn, compat.as_str(name),
serialized)
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
with c_api_util.tf_buffer() as buffer_:
pywrap_tf_session.TF_FunctionToFunctionDef(fn, buffer_)
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef()
function_def.ParseFromString(compat.as_bytes(proto_data))
self._name = compat.as_bytes(function_def.signature.name)
with ops.init_scope():
if context.executing_eagerly():
context.ensure_initialized()
context.add_function(fn)
self._function_deleter = _EagerDefinedFunctionDeleter(self.name)
self._registered_on_context = True
self.definition = function_def
self.signature = function_def.signature
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
self._control_captures = graph.control_captures
# Shallow copy outputs since ConcreteFunction may mutate it.
self._func_graph_outputs = list(outputs)
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
self._grad_func = None
self.graph = graph
self._stateful_ops = tuple(op for op in operations if op._is_stateful) # pylint: disable=protected-access
for function_callback in _function_callbacks:
function_callback(self)
def add_to_graph(self, g=None):
# pylint: disable=protected-access
if not g and context.executing_eagerly():
context.context().add_function_def(self.definition)
else:
if not g._is_function(self.name):
g._add_function(self)
for f in self.graph._functions.values():
if not g._is_function(f.name):
g._add_function(f)
# pylint: enable=protected-access
@property
def name(self):
return self._name
@property
def stateful_ops(self):
return self._stateful_ops
def call(self, ctx, args, cancellation_manager=None):
"""Calls this function with `args` as inputs.
`ConcreteFunction` execution respects device annotations only if the
function won't be compiled with xla.
Args:
ctx: a Context object
args: a list of arguments to supply this function with.
cancellation_manager: a `CancellationManager` object that can be used to
cancel function execution.
Returns:
The outputs of the function call.
Raises:
ValueError: if the number of arguments is incorrect.
"""
if len(args) != len(self.signature.input_arg):
raise ValueError(
"Arguments and signature arguments do not match. "
"got: %s, expected: %s " %
(len(args), len(list(self.signature.input_arg))))
function_call_options = ctx.function_call_options
if function_call_options.config_proto_serialized is None:
config = function_utils.get_disabled_rewriter_config()
else:
config = function_call_options.config_proto_serialized
executor_type = function_call_options.executor_type or ""
executing_eagerly = ctx.executing_eagerly()
attrs = ("executor_type", executor_type, "config_proto", config)
if executing_eagerly:
with _InterpolateFunctionError(self):
if cancellation_manager is None:
outputs = execute.execute(
str(self.signature.name),
num_outputs=self._num_outputs,
inputs=args,
attrs=attrs,
ctx=ctx)
else:
outputs = execute.execute_with_cancellation(
str(self.signature.name),
num_outputs=self._num_outputs,
inputs=args,
attrs=attrs,
ctx=ctx,
cancellation_manager=cancellation_manager)
# Replace empty list with None
outputs = outputs or None
else:
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
# creates `PartitionedCallOp` kernels by default, or remove the previous
# branch if a TPU kernel is registered for `PartitionedCall`.
with _InterpolateFunctionError(self):
with ops.control_dependencies(self._control_captures):
# The caller must use record_operation to record this operation in the
# eager case, so we enforce the same requirement for the non-eager
# case by explicitly pausing recording. We don't have a gradient
# registered for PartitionedCall, so recording this operation confuses
# forwardprop code (GradientTape manages to ignore it).
with tape.stop_recording():
outputs = functional_ops.partitioned_call(
args=args,
f=self,
tout=self._output_types,
executing_eagerly=executing_eagerly,
config=config,
executor_type=executor_type)
for i, func_graph_output in enumerate(self._func_graph_outputs):
custom_gradient.copy_handle_data(func_graph_output, outputs[i])
if executing_eagerly:
return outputs
else:
# TODO(b/128924522): This additional set_shape should not be
# necessary. ShapeRefiner likely needs to inspect handle_data. Remove this
# once that's done.
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
return outputs
class _DelayedRewriteGradientFunctions(object):
"""Caches forward/backward functions with a delayed forward rewrite."""
def __init__(self, func_graph, attrs, func_graph_deleter):
"""Construct an inference function and initialize caches."""
# A map from the number of forward function outputs with accepted gradients
# to forward and backward functions, used to cache non-tape backward
# function generation.
self._cached_function_pairs = {}
self._func_graph = func_graph
self._inference_function = _EagerDefinedFunction(
_inference_name(self._func_graph.name), self._func_graph,
self._func_graph.inputs, self._func_graph.outputs, attrs)
self._attrs = attrs
self._gradient_name = None
# Note that the FuncGraph is mutated later, so we need to inspect it now to
# figure out the user-specified outputs of the inference function.
self._num_inference_outputs = len(self._func_graph.outputs)
self._func_graph_deleter = func_graph_deleter
def forward_backward(self, num_doutputs=None):
"""A possibly-cached pair of forward and backward functions."""
if num_doutputs is None:
num_doutputs = self._num_inference_outputs
forward_backward = self._cached_function_pairs.get(num_doutputs)
if forward_backward is not None:
return forward_backward
forward, backward = self._construct_forward_backward(num_doutputs)
self._cached_function_pairs[num_doutputs] = (forward, backward)
return forward, backward
def _construct_forward_backward(self, num_doutputs):
"""Constructs a pair of forward and backward functions.
Args:
num_doutputs: The constructed backprop function will take output gradients
for the first `num_doutputs` outputs of the forward function. Defaults
to the number of outputs for the inference function, but when
higher-order gradients are computed this will increase to include side
outputs.
Returns:
A pair of (forward_function, backward_function):
forward_function: A re-generated inference function (an
_EagerDefinedFunction) to account for new side outputs, if any extra
were required when building the backward pass.
backward_function: A ConcreteFunction that Takes `num_doutputs`
arguments and returns gradients with respect to inputs of the forward
function.
"""
trainable_outputs = [
output for output in self._func_graph.outputs[:num_doutputs]
if backprop_util.IsTrainable(output)]
signature = []
for t in trainable_outputs:
signature.append(
tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))
def _backprop_function(*grad_ys):
with ops.device(None):
return gradients_util._GradientsHelper( # pylint: disable=protected-access
trainable_outputs,
self._func_graph.inputs,
grad_ys=grad_ys,
src_graph=self._func_graph)
with self._func_graph.as_default():
backwards_graph = func_graph_module.FuncGraph(
_backward_name(self._func_graph.name))
func_graph_module.func_graph_from_py_func(
name=backwards_graph.name,
python_func=_backprop_function,
args=[], kwargs={},
signature=signature,
func_graph=backwards_graph)
backwards_graph_captures = backwards_graph.external_captures
captures_from_forward = [
c for c in backwards_graph_captures if
not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
forward_function_name = _forward_name(self._func_graph.name)
# NB: forward and backward function have their "_implements"
# attribute set to None if it was present. This is because we don't
# support replacing those functions. If we do want for those functions
# to have implements function we need to provide a mechanism that
# would allow to identify all functions that call this one
# and trace and update their signatures as well. At the moment
# we disable this, until the tooling for doing this becomes available.
# See:
# https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions
common_attributes = dict(self._attrs)
common_attributes.pop(IMPLEMENTS_ATTRIBUTE_NAME, None)
existing_outputs = object_identity.ObjectIdentitySet(
self._func_graph.outputs)
for capture in captures_from_forward:
if capture not in existing_outputs:
existing_outputs.add(capture)
self._func_graph.outputs.append(capture)
backward_function_attr = _parse_func_attrs(
{FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
backward_function_attr.update(common_attributes)
backward_function = ConcreteFunction(
backwards_graph, attrs=backward_function_attr)
forward_function_attr = _parse_func_attrs({
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
backward_function.name})
forward_function_attr.update(common_attributes)
forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
self._func_graph.outputs, forward_function_attr)
return forward_function, backward_function
def _rewrite_forward_and_call_backward(self, op, *doutputs):
"""Add outputs to the forward call and feed them to the grad function."""
forward_function, backwards_function = self.forward_backward(len(doutputs))
if not backwards_function.outputs:
return backwards_function.structured_outputs
forward_function.add_to_graph(op.graph)
# pylint: disable=protected-access
# Rewrite an inference call op to be a forward call op
op._set_func_attr("f", forward_function.name)
op._set_type_list_attr("Tout", forward_function._output_types)
op._add_outputs(
forward_function._output_types[len(op.outputs):],
forward_function._output_shapes[len(op.outputs):])
for i in range(len(op.outputs)):
func_graph_output = forward_function._func_graph_outputs[i]
custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
# pylint: enable=protected-access
capture_mapping = dict(
zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs))
remapped_captures = [
capture_mapping.get(ops.tensor_id(capture), capture)
for capture in backwards_function.captured_inputs
]
# Replace Nones with zeros since we're calling a graph function which
# expects numeric inputs.
cleaned_doutputs = []
for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
if backprop_util.IsTrainable(placeholder):
if isinstance(doutput, ops.IndexedSlices):
# Gradient passed to a backward ConcreteFunction must be tf.Tensor,
# so we convert tf.IndexedSlices to tf.Tensor.
cleaned_doutputs.append(ops.convert_to_tensor(doutput))
elif doutput is not None:
cleaned_doutputs.append(doutput)
else:
cleaned_doutputs.append(default_gradient.zeros_like(placeholder))
# Compute the gradients using the side outputs
return backwards_function._call_flat( # pylint: disable=protected-access
cleaned_doutputs, remapped_captures)
def get_gradient_function(self):
"""Returns gradient function.
The gradient rewrites an inference call op to a forward call op, but does
not modify a pre-existing forward call op. It then computes the gradient
from the output's gradients and the side outputs of the forward op.
"""
return self._rewrite_forward_and_call_backward
def forward(self, inference_args=None, input_tangents=None):
"""A forward function with only user-specified outputs.
The call operation for the returned inference function can be rewritten into
a forward function. This only happens if the backward function (from the
`backward` method) ends up being used to compute gradients.
This approach avoids constructing unnecessary graphs, but it only works if
we are calling this function when not executing eagerly.
Args:
inference_args: A flat list of Tensors, arguments to the inference
function. Unused, but taken for compatibility with
_TapeGradientFunctions.
input_tangents: A flat list of Tensors, jvps associated with
`inference_args`. Unused; if required, tape functions must be used
instead.
Returns:
An _EagerDefinedFunction.
"""
del inference_args # unused
if input_tangents:
# This class does not support special-cased forwardprop. The arguments are
# here for compatibility with _TapeGradientFunctions.
raise AssertionError(
"Internal error: unexpectedly got forwardprop information in a class "
"that does not support forwardprop.")
return self._inference_function
def _backward(self, outputs):
"""Fetch a backward function for `outputs` from the forward function."""
def _backward_function(*args):
call_op = outputs[0].op
return self._rewrite_forward_and_call_backward(call_op, *args)
return _backward_function, outputs
def record(self, flat_outputs, inference_args, input_tangents):
"""Record the function call operation.
_DelayedRewriteGradientFunctions supports only first-order backprop tape
gradients (and then only when graph building). It does not work with
higher-order tape gradients or forward autodiff, but does work with
higher-order symbolic gradients (tf.gradients).
Args:
flat_outputs: The result of running `forward`.
inference_args: A flat list of Tensors with inference inputs to the
operation.
input_tangents: A flat list of Tensors with input tangents consumed by the
operation.
"""
backward_function, to_record = self._backward(flat_outputs)
tape.record_operation(self._inference_function.signature.name,
to_record, inference_args + input_tangents,
backward_function)
# Contains information about a forward function wrapped to compute jvps.
_ForwardWrapper = collections.namedtuple(
"_ForwardWrapper", (
# The wrapper Graph.
"graph",
# A flat list of non-tangent Tensor outputs from the wrapped forward
# function.
"outputs",
# Indices for output tangents, same format as
# forwardprop_util.pack_tangents.
"output_indices",
# A flat list of tangents for `outputs`.
"output_tangents"))
class _TapeGradientFunctions(object):
"""Caches forward and backward functions compatible with eager gradients.
In contrast to the delayed-rewrite approach in
`_DelayedRewriteGradientFunctions` which only works with delayed execution,
the forward function generated by this class has a fixed set of outputs which
may be preserved by a tape in order to compute gradients later.
This class is abstract; its child classes differ in how many side outputs of
the forward function their backward function accepts gradients for, which
determines whether higher-order tape gradients are possible.
"""
def __init__(self, func_graph, attrs, func_graph_deleter,
forwardprop_input_indices, delayed_rewrite_functions,
need_gradients_for_jvps):
self._func_graph = func_graph
self._forward_graph = None
self._attrs = attrs
self._forward = None
self._backward = None
self._num_outputs = len(func_graph.outputs)
self._func_graph_deleter = func_graph_deleter
self._forwardprop_input_indices = forwardprop_input_indices
self._forwardprop_output_indices = None
self._num_forwardprop_outputs = 0
self._num_inference_outputs = len(func_graph.outputs)
self._num_trainable_inference_outputs = len(
[t for t in func_graph.outputs if backprop_util.IsTrainable(t)])
self._delayed_rewrite_functions = delayed_rewrite_functions
self._need_gradients_for_jvps = need_gradients_for_jvps
def _build_functions_for_outputs(
    self, outputs, inference_args, input_tangents):
  """Forward+backward functions where the backward function sees `outputs`.

  Side effect: Tensors of the forward graph which the backward graph captures
  are appended to `self._func_graph.outputs` (as side outputs).

  Args:
    outputs: Tensors of the forward graph. The trainable ones each get a
      gradient placeholder in the backward function.
    inference_args: A flat list of Tensors, arguments to the inference
      function.
    input_tangents: A flat list of Tensors, jvps associated with
      `inference_args`. When empty, no forwardprop wrapping is performed.

  Returns:
    A 5-tuple `(forward_function, forward_graph, backward_function,
    forwardprop_output_indices, num_forwardprop_outputs)`. The last two
    entries are `None` and `0` when `input_tangents` is empty.

  Raises:
    AssertionError: If the wrapped forward graph does not end up with exactly
      `len(inference_args) + len(input_tangents)` inputs.
  """
  # First figure out which of `outputs` are trainable. We'll accept gradients
  # for each of these in the backward function.
  handles_to_variables = self._func_graph.variable_captures
  trainable_outputs = []
  for output in outputs:
    if backprop_util.IsTrainable(output):
      # Swap in the Variable object for resource handles if we can so
      # sparse gradients work.
      output = handles_to_variables.get(id(output), output)
      trainable_outputs.append(output)

  backwards_graph = func_graph_module.FuncGraph(
      _backward_name(self._func_graph.name))
  with backwards_graph.as_default():
    gradients_wrt_outputs = []
    for output in trainable_outputs:
      gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
          output)
      gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
      custom_gradient.copy_handle_data(output, gradient_placeholder)
      gradients_wrt_outputs.append(gradient_placeholder)
    with ops.device(None):
      gradients_wrt_inputs = gradients_util._GradientsHelper(  # pylint: disable=protected-access
          trainable_outputs,
          self._func_graph.inputs,
          grad_ys=gradients_wrt_outputs,
          src_graph=self._func_graph)

    if input_tangents:
      # Convert IndexedSlices to dense tensors (as we do elsewhere for
      # function gradients). Our C++ bindings don't know how to handle them
      # currently.
      gradients_wrt_inputs = nest.map_structure(
          lambda x: ops.convert_to_tensor(x) if x is not None else None,
          gradients_wrt_inputs)
    captures_from_forward = [
        c for c in backwards_graph.external_captures
        if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
    ]
    existing_outputs = object_identity.ObjectIdentitySet(
        self._func_graph.outputs)
    for capture in captures_from_forward:
      if capture not in existing_outputs:
        existing_outputs.add(capture)
        self._func_graph.outputs.append(capture)

  forward_function_name = _forward_name(self._func_graph.name)
  backward_function_attr = _parse_func_attrs(
      {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
  backward_function_attr.update(self._attrs)

  # The ordering of `backwards_graph.inputs` is important: inputs of
  # `backward_function` correspond to outputs (including
  # side outputs) of `self._tape_forward_function`.
  backwards_graph.inputs = (
      gradients_wrt_outputs + backwards_graph.internal_captures)
  backwards_graph.outputs.extend(
      grad
      for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
      if grad is not None)
  backwards_graph.structured_outputs = gradients_wrt_inputs
  backward_function = ConcreteFunction(
      backwards_graph, attrs=backward_function_attr)

  forward_function_attr = _parse_func_attrs({
      BACKWARD_FUNCTION_ATTRIBUTE_NAME:
      backward_function.name})
  forward_function_attr.update(self._attrs)

  forward_function = _EagerDefinedFunction(
      forward_function_name, self._func_graph, self._func_graph.inputs,
      self._func_graph.outputs,
      forward_function_attr)

  if not input_tangents:
    # There is no need to special-case forwardprop, so we can return the
    # forward+backward pair we've created without further wrapping.
    return (forward_function, self._func_graph, backward_function,
            # No forwardprop outputs.
            None, 0)

  forward_wrapper = self._wrap_forward_function_with_jvps(
      forward_function, backward_function, inference_args, input_tangents)
  (wrapped_backwards_graph,
   forward_wrapper) = self._wrap_backward_function_with_jvp_backprop(
       backward_function, gradients_wrt_outputs, forward_wrapper)
  # Now that we've added new captures, we need to make sure forward outputs
  # are in the same order the backward function expects them to be in:
  # [inference outputs] + [jvps] + [side outputs] + [captures].
  forward_wrapper = self._shuffle_forward_outputs(forward_wrapper)
  wrapped_forward_function = _EagerDefinedFunction(
      _forward_name(self._func_graph.name), forward_wrapper.graph,
      forward_wrapper.graph.inputs, forward_wrapper.graph.outputs,
      forward_function_attr)
  wrapped_backward_function = ConcreteFunction(
      wrapped_backwards_graph, attrs=backward_function_attr)

  # Sanity check on the wrapper's signature. The count is taken with a single
  # `len`; nesting `len(len(...))` would raise a TypeError here and hide the
  # intended AssertionError.
  num_expected_inputs = len(inference_args) + len(input_tangents)
  num_actual_inputs = len(forward_wrapper.graph.inputs)
  if num_expected_inputs != num_actual_inputs:
    raise AssertionError(
        ("Internal error: the forward graph had {} inputs, but we expected"
         " {} ({} inference inputs and {} input tangents)")
        .format(num_actual_inputs,
                num_expected_inputs,
                len(inference_args), len(input_tangents)))
  return (wrapped_forward_function, forward_wrapper.graph,
          wrapped_backward_function, forward_wrapper.output_indices,
          len(forward_wrapper.output_tangents))
def _wrap_forward_function_with_jvps(
self, forward_function, backward_function,
inference_args, input_tangents):
"""Adds inline JVP computation to a forward function.
Builds a new wrapper `FuncGraph` (named with `_forward_name`) which creates
one placeholder per input of `self._func_graph`, captures the relevant
`input_tangents`, calls `forward_function` on the leading inference inputs
and packs the resulting output tangents.
Args:
forward_function: The forward function to wrap. Its `call` method and
`signature.name` are used here.
backward_function: The matching backward function; passed on to
`self._wrap_backward_function`.
inference_args: A flat list of Tensors, arguments to the inference
function.
input_tangents: A flat list of Tensors, jvps associated with
`inference_args`.
Returns:
A `_ForwardWrapper` holding the wrapper graph, the outputs of the wrapped
call, and the packed output tangent indices and (captured) tangents.
Raises:
AssertionError: If the number of wrapper graph inputs does not equal the
expected jvp index when a tangent placeholder is about to be added.
"""
forward_wrapper_graph = func_graph_module.FuncGraph(
_forward_name(self._func_graph.name))
with forward_wrapper_graph.as_default():
# Tell forward accumulators to free up space for new JVP computations,
# since one may be in the process of computing a JVP (if that computation
# triggered this function building).
#
# We'll make symbolic versions of input JVPs, run the forward function
# under forward accumulators to get symbolic output JVPs, then set those
# as outputs of the new wrapped forward function.
with forwardprop_util.push_forwardprop_state():
# Look-up table from the id of each internal capture placeholder of the
# original graph to the external tensor it stands for.
forward_captures = {
ops.tensor_id(internal): external
for external, internal in self._func_graph.captures}
for input_index, real_input in enumerate(self._func_graph.inputs):
# This loop is more or less equivalent to running tf.identity on each
# of self._func_graph.inputs. However, doing that also captures jvps
# for resource handles, which confuses the jvp capturing code below
# (since primal inputs are interwoven with jvp inputs).
input_placeholder = array_ops.placeholder(
dtype=real_input.dtype,
shape=real_input.shape)
capture = forward_captures.get(ops.tensor_id(real_input))
if capture is not None:
forward_wrapper_graph.add_capture(capture, input_placeholder)
if capture.dtype == dtypes.resource:
custom_gradient.copy_handle_data(capture, input_placeholder)
else:
forward_wrapper_graph.inputs.append(input_placeholder)
# Tie each wrapper input to its inference argument with an operation whose
# forward and backward functions both pass the value through unchanged.
for inp, arg in zip(forward_wrapper_graph.inputs, inference_args):
tape.record_operation(
"captured_value", [inp], [arg],
backward_function=lambda x: [x],
forward_function=lambda x: [x])
num_inference_inputs = len(inference_args)
for tape_indices in self._forwardprop_input_indices:
for input_index, jvp_index in tape_indices:
input_placeholder = forward_wrapper_graph.inputs[input_index]
if len(forward_wrapper_graph.inputs) != jvp_index:
raise AssertionError(
("Internal error: expected {} forward graph inputs, but "
"found {}.")
.format(jvp_index, len(forward_wrapper_graph.inputs)))
gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
input_placeholder)
jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape)
# `jvp_index` counts inference inputs too, so shift it to index into
# `input_tangents`.
external_jvp = input_tangents[jvp_index - num_inference_inputs]
forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder)
tensor_shape.TensorShape(
external_jvp.shape).assert_is_compatible_with(
jvp_placeholder.shape)
tape.record_operation(
"captured_value",
[jvp_placeholder],
[external_jvp],
backward_function=lambda x: [x],
forward_function=lambda x: [x])
# Only the first `num_inference_inputs` wrapper inputs are handed to the
# wrapped forward function.
forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs]
gradient_function = (
self._delayed_rewrite_functions._rewrite_forward_and_call_backward) # pylint: disable=protected-access
with ops.get_default_graph()._override_gradient_function( # pylint: disable=protected-access
{"PartitionedCall": gradient_function,
"StatefulPartitionedCall": gradient_function}):
forward_outputs = forward_function.call(context.context(),
forward_inputs)
if isinstance(forward_outputs, ops.Operation):
# _wrapped_backward_function expects a list, but if the function has
# no outputs its call() returns an Operation. We need to undo that
# so we don't cause problems later.
forward_outputs = []
py_backward, _ = self._wrap_backward_function(
self._func_graph, backward_function, forward_outputs)
# We will never request backward tape gradients for this operation
# directly since we're wrapping the call; forwardprop will call the
# backward function (and nested forward accumulators may build
# higher-order gradients), but any watching GradientTapes should ignore
# it.
#
# TODO(allenl): It might be better to explicitly stop backward recording
# so we don't use the second-order tape cases unnecessarily.
tape.record_operation_forwardprop_only(
forward_function.signature.name,
forward_outputs, forward_inputs, py_backward, None)
output_indices, output_tangents = (
pywrap_tfe.TFE_Py_PackJVPs(forward_outputs))
# Capture each packed tangent into the wrapper graph before returning it.
output_tangents = [forward_wrapper_graph.capture(t)
for t in output_tangents]
return _ForwardWrapper(
graph=forward_wrapper_graph, outputs=forward_outputs,
output_indices=output_indices, output_tangents=output_tangents)
def _wrap_backward_function_with_jvp_backprop(
self, backward_function, gradients_wrt_outputs, forward_wrapper):
"""Wraps `backward_function` to include gradients for JVPs.
Builds a new backward `FuncGraph` (named with `_backward_name`) which feeds
placeholder output gradients to the wrapped backward function and, when
`self._need_gradients_for_jvps` is set, also differentiates the output
tangents of the forward wrapper.
Note that `forward_wrapper.outputs` may be appended to in place when the new
backward graph captures additional tensors from the forward wrapper graph.
Args:
backward_function: The backward function to wrap.
gradients_wrt_outputs: Gradient placeholders for trainable outputs; only
their `dtype` and `shape` are read here.
forward_wrapper: The `_ForwardWrapper` describing the wrapped forward
function.
Returns:
A `(wrapped_backwards_graph, forward_wrapper)` pair, where the returned
wrapper has its `output_indices` and `output_tangents` replaced.
"""
wrapped_backwards_graph = func_graph_module.FuncGraph(
_backward_name(self._func_graph.name))
with wrapped_backwards_graph.as_default():
py_backward, recorded_outputs = self._wrap_backward_function(
self._func_graph, backward_function, forward_wrapper.outputs)
# Position of the next entry of `gradients_wrt_outputs` to mirror.
trainable_index = 0
forward_doutputs = []
doutput_args = []
for output in recorded_outputs:
if backprop_util.IsTrainable(output):
doutput = gradients_wrt_outputs[trainable_index]
doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape)
doutput_args.append(doutput_placeholder)
forward_doutputs.append(doutput_placeholder)
trainable_index += 1
else:
# Non-trainable outputs get no gradient placeholder.
doutput_args.append(None)
dinputs = py_backward(*doutput_args)
# Identity-based set of tensors already known as forward outputs or tangents.
existing_outputs = object_identity.ObjectIdentitySet(
forward_wrapper.outputs + forward_wrapper.output_tangents)
num_processed_output_tangents = 0
gradients_wrt_output_tangents = []
tangent_doutputs = []
output_tangents = forward_wrapper.output_tangents
output_indices = forward_wrapper.output_indices
if self._need_gradients_for_jvps:
# TODO(allenl): Consider using a throwaway graph to avoid extra gradient
# evaluations; gradients for jvps may have common subgraphs.
# NOTE(review): presumably repeats until no new output tangents show up;
# confirm against the properly indented file.
while num_processed_output_tangents != len(output_tangents):
for output in output_tangents[num_processed_output_tangents:]:
gradient_shape, gradient_dtype = default_gradient.shape_and_dtype(
output)
placeholder = graph_placeholder(gradient_dtype, gradient_shape)
gradients_wrt_output_tangents.append(placeholder)
tangent_doutputs.append(placeholder)
num_processed_output_tangents = len(output_tangents)
with ops.device(None):
gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access
output_tangents,
forward_wrapper.graph.inputs,
grad_ys=gradients_wrt_output_tangents,
src_graph=forward_wrapper.graph)
# Combine the gradients from the wrapped backward function with the ones
# flowing through the output tangents, pair by pair.
dinputs = [
backprop.aggregate_indexed_slices_gradients((existing, new))
for existing, new in zip(dinputs, gradients_wrt_inputs)
if existing is not None or new is not None]
dinputs.extend(gradients_wrt_inputs[len(dinputs):])
captures_from_forward = [
c for c in wrapped_backwards_graph.external_captures
if (not isinstance(c, ops.EagerTensor)
and c.graph is forward_wrapper.graph)]
for capture in captures_from_forward:
if capture not in existing_outputs:
existing_outputs.add(capture)
# Mutates the list held by `forward_wrapper` in place.
forward_wrapper.outputs.append(capture)
output_indices, output_tangents = (
forwardprop_util.pack_tangents(forward_wrapper.outputs))
output_tangents = [forward_wrapper.graph.capture(t)
for t in output_tangents]
for t in output_tangents:
existing_outputs.add(t)
# Input order: gradients for trainable inference outputs, then gradients for
# output tangents, then the remaining output gradients, then captures.
wrapped_backwards_graph.inputs = (
forward_doutputs[:self._num_trainable_inference_outputs]
+ tangent_doutputs
+ forward_doutputs[self._num_trainable_inference_outputs:]
+ wrapped_backwards_graph.internal_captures)
wrapped_backwards_graph.structured_outputs = dinputs
wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None]
return (wrapped_backwards_graph,
forward_wrapper._replace(output_indices=output_indices,
output_tangents=output_tangents))
def _shuffle_forward_outputs(self, forward_wrapper):
  """Moves output tangents to sit right after the inference outputs.

  Sets `forward_wrapper.graph.outputs` to
  [inference outputs] + [output tangents] + [remaining outputs] and rewrites
  the tangent indices to match the new positions.

  Args:
    forward_wrapper: The `_ForwardWrapper` whose graph outputs are reordered.

  Returns:
    A copy of `forward_wrapper` with `output_indices` remapped.
  """
  inference_count = self._num_inference_outputs
  output_count = len(forward_wrapper.outputs)
  tangent_count = len(forward_wrapper.output_tangents)

  def _remap(position):
    """Translates an index in the old layout to the shuffled layout."""
    if position < inference_count:
      # Inference outputs keep their place.
      shifted = position
    elif position >= output_count:
      # Tangents used to follow all outputs; now they follow inference outputs.
      shifted = position - output_count + inference_count
    else:
      # Everything else is pushed back by the tangents inserted before it.
      shifted = position + tangent_count
    return shifted

  remapped_indices = nest.map_structure(
      _remap, forward_wrapper.output_indices)
  leading = forward_wrapper.outputs[:inference_count]
  trailing = forward_wrapper.outputs[inference_count:]
  forward_wrapper.graph.outputs = (
      leading + forward_wrapper.output_tangents + trailing)
  return forward_wrapper._replace(output_indices=remapped_indices)
def forward(self, inference_args, input_tangents):
  """Construct or fetch a forward function with side-outputs.

  When graph building without a tape active, symbolic gradients rely on
  regenerating the backward function for higher-order gradients (to account
  for new side outputs of the rewritten forward function call), so there is
  no fixed backward function in that case. When a tape is active (eager or
  graph building), fixed backward and forward functions are generated at
  forward function call time instead.

  The difference between the two cases avoids building unneeded backward
  functions while graph building, where gradients may never be requested.

  The pair is built on the first call only; later calls return the cached
  forward function regardless of their arguments.

  Args:
    inference_args: A flat list of Tensors, arguments to the inference
      function.
    input_tangents: A flat list of Tensors, jvps associated with
      `inference_args`.

  Returns:
    A forward _EagerDefinedFunction.
  """
  if self._forward is not None:
    return self._forward

  built = self._forward_and_backward_functions(inference_args, input_tangents)
  (self._forward,
   self._forward_graph,
   self._backward,
   self._forwardprop_output_indices,
   self._num_forwardprop_outputs) = built
  return self._forward
def _wrap_backward_function(self, forward_graph, backward, outputs):
  """Create a backward function given `outputs` from the forward function.

  Args:
    forward_graph: The forward graph; its `outputs` correspond positionally
      to `outputs`.
    backward: The backward function. Its `captured_inputs`, `inputs`,
      `outputs`, `structured_outputs` and `_call_flat` are used.
    outputs: The values produced by calling the forward function.

  Returns:
    A `(backward_function_wrapper, recorded_outputs)` pair. The wrapper takes
    one gradient per relevant forward output and calls `backward` with the
    gradients it accepts plus the remapped captures. `recorded_outputs` are
    the leading relevant outputs collected until `backward` has received all
    of its non-capture inputs.

  Raises:
    AssertionError: If a capture of `backward` still belongs to
      `forward_graph` after remapping, i.e. it was not matched to one of
      `outputs`.
  """
  capture_mapping = dict(
      zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
  captured_inputs = backward.captured_inputs
  remapped_captures = [
      capture_mapping.get(ops.tensor_id(capture), capture)
      for capture in captured_inputs
  ]
  # Any non-eager capture still living in the forward graph was not replaced
  # by a value from `outputs`; those are the ones worth reporting.
  unmapped_captures = [
      t for t in remapped_captures
      if not isinstance(t, ops.EagerTensor) and t.graph is forward_graph]
  if unmapped_captures:
    raise AssertionError(
        "Internal error: failed to map all backward graph captures to the "
        "forward graph. Incorrectly mapped: {}".format(unmapped_captures))
  # We may need to use zeros_like to get a zero for variant Tensors with
  # unconnected gradients. We do that in advance so we don't have to hold on
  # to the outputs themselves, which may not be needed otherwise.
  variant_zeros_like = {}
  backward_function_inputs = (len(backward.inputs) - len(captured_inputs))
  recorded_outputs = []
  trainable_recorded_outputs = 0
  # Only used for membership tests below, so a set keeps lookups cheap.
  skip_positions = set()
  if self._num_forwardprop_outputs and not self._need_gradients_for_jvps:
    # Drop the forwardprop outputs, which sit right after the inference
    # outputs; the backward function takes no gradients for them.
    relevant_outputs = (
        outputs[:self._num_inference_outputs]
        + outputs[self._num_inference_outputs
                  + self._num_forwardprop_outputs:])
  else:
    relevant_outputs = outputs
  for output_index, output in enumerate(relevant_outputs):
    if trainable_recorded_outputs < backward_function_inputs:
      recorded_outputs.append(output)
    if backprop_util.IsTrainable(output):
      trainable_recorded_outputs += 1
    else:
      skip_positions.add(output_index)
    if output.dtype == dtypes.variant:
      variant_zeros_like[output_index] = default_gradient.zeros_like(output)

  def _backward_function_wrapper(*args):
    """Process output gradients and call the backward function."""
    if not backward.outputs:
      return backward.structured_outputs

    processed_args = []
    input_index = 0
    for output_index, arg in enumerate(args):
      # Convert IndexedSlices to dense tensors. The IndexedSlices optimization
      # is only really effective when doing tf.gather(variable) as the
      # adjoint functions for most operations are unlikely to preserve the
      # sparsity in IndexedSlices.
      if isinstance(arg, ops.IndexedSlices):
        arg = ops.convert_to_tensor(arg)
      if output_index in skip_positions:
        continue
      if arg is None:
        # We're calling a (non-polymorphic) ConcreteFunction, so we need to
        # have a Tensor value for each Tensor we thought would be trainable
        # based on its dtype, even if it ended up being unconnected.
        input_placeholder = backward.inputs[
            input_index]
        if input_placeholder.dtype == dtypes.variant:
          arg = variant_zeros_like[output_index]
        else:
          arg = array_ops.zeros(
              *default_gradient.shape_and_dtype(input_placeholder))
      processed_args.append(arg)
      input_index += 1
      if input_index >= backward_function_inputs:
        break
    return backward._call_flat(  # pylint: disable=protected-access
        processed_args, remapped_captures)

  return _backward_function_wrapper, recorded_outputs
def record(self, flat_outputs, inference_args, input_tangents):
  """Record the function call operation.

  For backprop, this tells the tape which backward function to use and which
  new Tensors must be watched. For forwardprop from eager, the function call
  itself has produced tangents, and those are recorded as well.

  Args:
    flat_outputs: The result of running `forward`.
    inference_args: A flat list of Tensors with inference inputs to the
      operation.
    input_tangents: A flat list of Tensors with input tangents consumed by the
      operation.
  """
  wrapped_backward, outputs_to_watch = self._wrap_backward_function(
      self._forward_graph, self._backward, flat_outputs)
  op_name = self._forward.signature.name
  all_inputs = inference_args + input_tangents

  if not self._forwardprop_output_indices:
    # No forwardprop outputs: a single combined record is enough.
    tape.record_operation(
        op_name, outputs_to_watch, all_inputs, wrapped_backward)
    return

  # With forwardprop outputs, backprop and forwardprop are recorded separately
  # because they see different outputs and inputs.
  tape.record_operation_backprop_only(
      op_name, outputs_to_watch, inference_args, wrapped_backward)
  tape.record_operation_forwardprop_only(
      op_name, flat_outputs, all_inputs, wrapped_backward,
      self._forwardprop_output_indices)
class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for first-order gradients."""

  def __init__(self, func_graph, attrs, func_graph_deleter,
               forwardprop_input_indices, delayed_rewrite_functions,
               need_gradients_for_jvps):
    # The base class constructor already stores every argument (including
    # `func_graph_deleter` and `forwardprop_input_indices`), so nothing needs
    # to be assigned again here.
    super(_FirstOrderTapeGradientFunctions, self).__init__(
        func_graph, attrs, func_graph_deleter, forwardprop_input_indices,
        delayed_rewrite_functions, need_gradients_for_jvps)

  def _forward_and_backward_functions(self, inference_args, input_tangents):
    """Shortcut for when only first-order gradients are required.

    The returned backward function does not accept gradients with respect to
    side output of forward_function. This is fine as long as the user can't
    possibly request second order tape gradients, as when they've used a single
    non-persistent GradientTape. Since we don't need the backward function to
    take gradients with respect to side outputs, we can skip some potentially
    slow graph building.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      The tuple produced by `_build_functions_for_outputs` for the inference
      outputs only: `(forward_function, forward_graph, backward_function,
      forwardprop_output_indices, num_forwardprop_outputs)`.
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to the "real" outputs of forward_function and
          returns gradients with respect to the inputs.
    """
    # Restrict to the original inference outputs so no gradient placeholders
    # are created for side outputs.
    outputs = self._func_graph.outputs[:self._num_inference_outputs]
    return self._build_functions_for_outputs(
        outputs, inference_args, input_tangents)
class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for higher-order gradients."""

  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
  # generalizing if so.
  def _forward_and_backward_functions(self, inference_args, input_tangents):
    """Forward and backward functions suitable for higher-order gradients.

    In contrast to `_FirstOrderTapeGradientFunctions`, the backward function
    built here accepts gradients for every output of the returned forward
    function, side outputs included.

    Args:
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.

    Returns:
      A tuple `(forward_function, forward_graph, backward_function,
      output_indices, num_output_tangents)`:
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to all of its outputs, real and side. Returns
          gradients with respect to the inputs.

    Raises:
      AssertionError: If the final build changes the number of forward
        function outputs.
    """
    # Building gradients makes the backward function capture Tensors from the
    # forward function, and those become new side outputs of the forward
    # function. For higher-order gradients the backward function must in turn
    # accept gradients for these side outputs, so keep rebuilding until the
    # number of forward outputs stops growing. This is only needed for tape
    # gradients, where all forward outputs have to be declared in advance;
    # symbolic gradients with tf.gradients regenerate backward functions on
    # demand instead.
    settled_outputs = []
    while len(settled_outputs) < len(self._func_graph.outputs):
      settled_outputs = list(self._func_graph.outputs)
      self._build_functions_for_outputs(
          settled_outputs, inference_args, input_tangents)

    # Final build against the stabilized output list.
    built = self._build_functions_for_outputs(
        settled_outputs, inference_args, input_tangents)
    (forward_function,
     forward_graph,
     backward_function,
     output_indices,
     num_output_tangents) = built

    if len(self._func_graph.outputs) != len(settled_outputs):
      raise AssertionError(
          ("Unexpectedly added new outputs to the forward function when "
           "building the backward function: {}").format(
               self._func_graph.outputs[len(settled_outputs):]))
    return (forward_function, forward_graph, backward_function, output_indices,
            num_output_tangents)
class _ForwardBackwardCall(object):
  """Carries a function call's state from execution through to recording."""

  __slots__ = [
      "_functions", "_inference_args", "_input_tangents", "_tape_watching"
  ]

  def __init__(self, functions, inference_args, input_tangents, tape_watching):
    """Collects information about the function call.

    Args:
      functions: An object which produces forward and backward functions,
        either a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions
        object.
      inference_args: A flat list of Tensors, arguments to the inference
        function.
      input_tangents: A flat list of Tensors, jvps associated with
        `inference_args`.
      tape_watching: Boolean, with True indicating that recording is
        necessary.
    """
    self._tape_watching = tape_watching
    self._input_tangents = input_tangents
    self._inference_args = inference_args
    self._functions = functions

  def forward(self):
    """Builds or retrieves a forward function for this call.

    Returns:
      A `(forward_function, flat_args)` pair, where `flat_args` is the
      inference arguments followed by the input tangents.
    """
    flat_args = self._inference_args + self._input_tangents
    forward_function = self._functions.forward(
        self._inference_args, self._input_tangents)
    return forward_function, flat_args

  def record(self, flat_outputs):
    """Given outputs from the execution of `forward`, records the operation."""
    # We only record function calls which have outputs, and then only when a
    # tape is watching.
    if not self._tape_watching:
      return
    if isinstance(flat_outputs, ops.Operation):
      return
    if flat_outputs is None:
      return
    self._functions.record(
        flat_outputs, self._inference_args, self._input_tangents)
# Sentinel value used by ConcreteFunction's structured signature to
# indicate that a non-tensor parameter should use the value that was
# specified when the concrete function was created.
_BOUND_VALUE = object()
class ConcreteFunction(object):
  """Callable object encapsulating a function definition and its gradient.

  `ConcreteFunction` is a callable that encapsulates a function definition and
  is differentiable under `tf.GradientTape` objects.
  """

  def __init__(self,
               func_graph,
               attrs=None,
               shared_func_graph=True,
               function_spec=None):
    """Initialize a `ConcreteFunction`.

    Args:
      func_graph: An instance of FuncGraph: the function body to wrap.
      attrs: (optional) dict mapping names of attributes to their AttrValue
        values. Attributes in `attrs` will be included in this function's
        definition.
      shared_func_graph: If False, the ConcreteFunction takes ownership of
        `func_graph` and will break reference cycles when it is deleted. This
        makes the FuncGraph inoperable.
      function_spec: FunctionSpec for the original function.  If not specified,
        then this ConcreteFunction may only be called using the flat signature.

    Raises:
      AssertionError: If `attrs` carries the "implements" attribute while the
        function has resource inputs, captured inputs or captured closures.
      ValueError: If number of input_placeholders is not equal to the number
        of function inputs. (NOTE(review): not raised by the code in this
        constructor itself; presumably comes from a helper. Verify.)
    """
    # _arg_keywords and _num_positional_args define the flat signature.  They
    # are assigned after construction.
    self._arg_keywords = None
    self._num_positional_args = None

    self._func_graph = func_graph
    self._captured_inputs = self._func_graph.external_captures
    self._captured_closures = self._func_graph.deferred_external_captures

    structured_outputs = self._func_graph.structured_outputs
    self._ndarrays_list = (
        isinstance(structured_outputs, (list, tuple)) and structured_outputs and
        all(isinstance(o, np_arrays.ndarray) for o in structured_outputs))
    self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)

    # function_spec defines the structured signature.
    self._set_function_spec(function_spec)

    if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
      # The alternative is to silently drop "implements" tag
      # but it seems likely it would lead to hard to catch bugs.
      # Another alternative is to make func_body to preserve the order
      # of arguments if variables are present. Yet another option
      # is to automatically replace variables as arguments to functions
      # to v.read_value() whenever "implements" tag is present
      # Anytime we annotate existing function we probably want to wrap
      # it with safe read_value for backward compatibility.
      has_resource_vars = any(inp.dtype == dtypes.resource
                              for inp in self.inputs)

      # Raised explicitly (rather than via `assert`) so the check is not
      # stripped when Python runs with optimizations enabled.
      if any(
          (has_resource_vars, self._captured_inputs, self._captured_closures)
      ):
        raise AssertionError(
            'Function {name} has "{attr}={value}" attribute and thus can not '
            "depend on any tensors outside of its signature or modify "
            "variables. "
            "\n\nNote: variables are always captured and cause function "
            "re-tracing for every variable called.\n"
            "  inputs: {inputs}\n  captures: {captured}\n"
            "  closures: {closures}.\n\n"
            "To pass a variable to such function "
            "use variable.read_value().".format(
                name=func_graph.name,
                attr=IMPLEMENTS_ATTRIBUTE_NAME,
                value=attrs[IMPLEMENTS_ATTRIBUTE_NAME],
                inputs=self.inputs,
                captured=self._captured_inputs,
                closures=self._captured_closures))

    self._output_shapes = tuple(
        output.shape for output in self._func_graph.outputs)
    self._attrs = _parse_func_attrs(attrs or {})

    if shared_func_graph:
      self._garbage_collector = None
    else:
      self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph)

    # Pairs of forward and backward functions used for computing gradients.
    #
    # These each get a reference to the FuncGraph deleter since they use the
    # FuncGraph directly.
    self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions(
        func_graph, self._attrs, self._garbage_collector)
    self._first_order_tape_functions = {}
    self._higher_order_tape_functions = {}
    # Cache the inference function to avoid a (Python) function call when not
    # building gradients.
    self._inference_function = self._delayed_rewrite_functions.forward()
def _set_function_spec(self, function_spec):
  """Enables the structured signature by supplying a function_spec.

  The spec is remembered as the pre-initialized spec; the actual
  `_function_spec` is only derived from it once a structured input signature
  is available.

  Args:
    function_spec: The spec of the original function, or None.
  """
  self._pre_initialized_function_spec = function_spec
  self._function_spec = None
  if function_spec is None:
    return
  # Note: when ConcreteFunctions are built by recreate_function() in
  # function_deserialization.py, they don't have a structured_input_signature
  # yet.  In that case, _initialize_function_spec() gets called by
  # _setup_functions_structures() in load.py.
  if self.structured_input_signature is None:
    return
  self._initialize_function_spec()
def _initialize_function_spec(self):
  """Updates `self._function_spec` to include varargs and bound variables.

  Starting from the pre-initialized spec, this:
    * adds new positional arguments for any varargs (i.e., for args that are
      in `structured_input_signature`, but not in the original
      fullargspec.args);
    * replaces `defaults` and `kwonlydefaults` with `_BOUND_VALUE`, for all
      args and kwargs in `structured_input_signature`;
    * sets `varkw` and `varargs` to None.

  Does nothing when no pre-initialized spec is available.
  """
  base_spec = self._pre_initialized_function_spec
  if base_spec is None:
    return  # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
  assert not self._function_spec, "already initialized"

  declared_args = base_spec.fullargspec.args
  arg_specs, kwarg_specs = self.structured_input_signature
  first_vararg = len(base_spec.arg_names)
  # Positions beyond the declared argument names came from varargs; give each
  # a synthetic, 1-based positional name.
  vararg_names = [
      "<arg{}>".format(position + 1)
      for position in range(first_vararg, len(arg_specs))]

  bound_spec = tf_inspect.FullArgSpec(
      args=list(declared_args) + vararg_names,
      varargs=None,
      varkw=None,
      defaults=[_BOUND_VALUE] * len(arg_specs),
      kwonlyargs=sorted(kwarg_specs),
      kwonlydefaults={key: _BOUND_VALUE for key in kwarg_specs},
      annotations=base_spec.fullargspec.annotations)
  self._function_spec = FunctionSpec(
      bound_spec,
      base_spec.is_method,
      base_spec.input_signature,
      base_spec.is_pure,
      name=self._func_graph.name)
@property
def variables(self):
  """Sequence of variables for this function, as an immutable tuple."""
  graph_variables = self._func_graph.variables
  return tuple(graph_variables)
@property
def trainable_variables(self):
  """Sequence of trainable variables for this function, as a tuple."""
  graph_trainables = self._func_graph.trainable_variables
  return tuple(graph_trainables)
def __call__(self, *args, **kwargs):
  """Executes the wrapped function.

  ConcreteFunctions have two signatures:

  * The signature of the original function wrapped by this ConcreteFunction.
  * A flat signature, where each argument accepts a single Tensor.

  The original function signature is generally preferred, but the flat input
  signature is supported for backward compatibility.

  ### Original Function Signature

  When calling a ConcreteFunction with the signature of the original function,
  each argument must match the type or value that was used when the
  ConcreteFunction's graph was traced.  In particular:

  * Tensor arguments (including CompositeTensors, such as RaggedTensor) must
    have matching `TypeSpec`s.
  * Non-Tensor arguments (such as booleans or ints) must have equal values.
  * Nested arguments (such as lists, tuples, or dictionaries) must have the
    same nesting structure; and each nested value must have a matching type
    or value.

  The default value for any arguments that were traced with non-Tensor values
  is the value that was used in the trace.  Arguments that were traced with
  tensor arguments do not have a default value (even if the original function
  had a default value for that argument).

  ### Flat Signature

  When calling a ConcreteFunction with the flat signature, the arguments
  correspond to the flattened component tensors of the arguments that were
  used to construct the ConcreteFunction.  Parameter names are assigned based
  on `TensorSpec.name` (when specified) or the original argument names (with
  suffixes automatically added for nested arguments or composite tensors with
  multiple components).

  Args:
    *args: Positional arguments to the concrete function.
    **kwargs: Keyword arguments to the concrete function.

  Returns:
    The result of applying the TF function on the given Tensors.

  Raises:
    AssertionError: If this `ConcreteFunction` was not created through
      `get_concrete_function`.
    TypeError: If the arguments do not match the function's signature.
  """
  # All of the signature handling lives in `_call_impl`; this entry point only
  # forwards the collected positional and keyword arguments.
  result = self._call_impl(args, kwargs)
  return result
def _call_impl(self, args, kwargs, cancellation_manager=None):
  """Runs the function, choosing between the structured and flat signature.

  See `__call__` for a description of the two signatures.

  Args:
    args: Positional arguments to the concrete function.
    kwargs: Keyword arguments to the concrete function.
    cancellation_manager: (Optional.) Forwarded unchanged to whichever
      signature-specific call path is used.

  Returns:
    The value returned by the structured- or flat-signature call.

  Raises:
    TypeError: If the arguments match neither signature.  When a structured
      signature exists, the error raised by the structured attempt is the one
      that propagates.
  """
  with trace.Trace(self._func_graph.name, tf_function_call="concrete"):
    # Without a function spec there is no structured signature to try.
    if self._function_spec is None:
      return self._call_with_flat_signature(args, kwargs, cancellation_manager)

    # Prefer the structured signature; fall back to the flat signature only
    # if the structured one rejects the arguments.
    try:
      return self._call_with_structured_signature(args, kwargs,
                                                  cancellation_manager)
    except TypeError as structured_err:
      try:
        return self._call_with_flat_signature(args, kwargs,
                                              cancellation_manager)
      except TypeError:
        # Report the structured-signature failure, not the flat one.
        raise structured_err
def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
  """Executes the wrapped function with the flat signature.

  Positional arguments are taken first; the remaining entries of
  `self._arg_keywords` are then looked up by name in `kwargs`.

  Args:
    args: Positional arguments to the concrete function.
    kwargs: Keyword arguments to the concrete function.
    cancellation_manager: A `CancellationManager` that can be used to cancel
      function invocation.

  Returns:
    The result of applying the function on the Tensors/Variables contained in
    `args` and `kwargs`.

  Raises:
    TypeError: if `args` and `kwargs` do not match the flat signature of this
      `ConcreteFunction`.
  """
  max_positional = self._num_positional_args
  if len(args) > max_positional:
    raise TypeError(
        "{} takes {} positional arguments but {} were given".format(
            self._flat_signature_summary(), max_positional, len(args)))

  # Work on copies so the caller's containers are left untouched.
  flat_args = list(args)
  pending_kwargs = dict(kwargs)

  # Fill every argument that was not given positionally from the keywords.
  for keyword in self._arg_keywords[len(flat_args):]:
    try:
      flat_args.append(pending_kwargs.pop(compat.as_str(keyword)))
    except KeyError:
      already_specified = list(self._arg_keywords[:len(flat_args)])
      already_specified.extend(pending_kwargs.keys())
      summary = self._flat_signature_summary()
      missing = sorted(set(self._arg_keywords) - set(already_specified))
      raise TypeError("{} missing required arguments: {}".format(
          summary, ", ".join(missing)))

  # Anything still left in the keyword dict is either a duplicate of a
  # positional argument or not part of the signature at all.
  if pending_kwargs:
    positional_keywords = set(self._arg_keywords[:len(flat_args)])
    for leftover in pending_kwargs:
      if leftover in positional_keywords:
        raise TypeError("{} got two values for argument '{}'".format(
            self._flat_signature_summary(), leftover))
    raise TypeError("{} got unexpected keyword arguments: {}.".format(
        self._flat_signature_summary(), ", ".join(sorted(pending_kwargs))))

  # The flat signature only accepts tensors and resource variables.
  for index, value in enumerate(flat_args):
    if isinstance(
        value, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
      continue
    raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; "
                    "got {} ({})".format(self._flat_signature_summary(), index,
                                         type(value).__name__, str(value)))

  return self._call_flat(flat_args, self.captured_inputs, cancellation_manager)
def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
  """Executes the wrapped function with the structured signature.

  The arguments are canonicalized by the function spec, validated against
  the structured input signature, and then the flattened tensor arguments
  are passed to `_call_flat`.

  Args:
    args: Positional arguments to the concrete function.
    kwargs: Keyword arguments to the concrete function.
    cancellation_manager: A `CancellationManager` that can be used to cancel
      function invocation.

  Returns:
    The result of applying the function on the Tensors/Variables contained in
    `args` and `kwargs`.

  Raises:
    TypeError: if `args` and `kwargs` do not match the structured signature
      of this `ConcreteFunction`.
  """
  canonicalized = self._function_spec.canonicalize_function_inputs(
      *args, **kwargs)
  bound_args, bound_kwargs, _, filtered_flat_args = canonicalized

  # Validation order matters for which error is reported first.
  self._structured_signature_check_missing_args(bound_args, bound_kwargs)
  self._structured_signature_check_unexpected_args(bound_args, bound_kwargs)
  self._structured_signature_check_arg_types(bound_args, bound_kwargs)

  return self._call_flat(
      filtered_flat_args,
      captured_inputs=self.captured_inputs,
      cancellation_manager=cancellation_manager)
def _structured_signature_check_missing_args(self, args, kwargs):
  """Raises a TypeError if any args are missing.

  An argument counts as missing when its value is the `_BOUND_VALUE`
  placeholder while its spec contains a type spec.

  Args:
    args: Positional arguments, aligned with the positional specs.
    kwargs: Keyword arguments, keyed like the keyword specs.

  Raises:
    TypeError: If at least one such argument is found; the message lists the
      missing names in sorted order.
  """
  arg_specs, kwarg_specs = self.structured_input_signature

  missing = [
      self._function_spec.arg_names[position]
      for position, (value, spec) in enumerate(zip(args, arg_specs))
      if value is _BOUND_VALUE and _contains_type_spec(spec)
  ]
  missing.extend(
      key for key, value in kwargs.items()
      if value is _BOUND_VALUE and _contains_type_spec(kwarg_specs[key]))

  if not missing:
    return
  raise TypeError("{} missing required arguments: {}".format(
      self._structured_signature_summary(),
      ", ".join(sorted(missing))))
def _structured_signature_check_unexpected_args(self, args, kwargs):
  """Raises a TypeError if there are any extra args.

  Args:
    args: Positional arguments to check against the positional specs of
      `self.structured_input_signature`.
    kwargs: Keyword arguments to check against the keyword specs.

  Raises:
    TypeError: If more positional arguments are given than there are
      positional specs, or if any keyword is not a key of the keyword specs.
  """
  arg_specs, kwarg_specs = self.structured_input_signature
  if len(args) > len(arg_specs):
    raise TypeError(
        "{} takes {} positional arguments but {} were given".format(
            self._structured_signature_summary(),
            len(self._function_spec.arg_names), len(args)))

  # Compare the key sets rather than the dict sizes: an unknown keyword can
  # be present even when `len(kwargs) <= len(kwarg_specs)`, and would
  # otherwise surface later as a KeyError from `kwarg_specs[name]`.
  extra_args = set(kwargs) - set(kwarg_specs)
  if extra_args:
    # Sorted so that the message is deterministic.
    raise TypeError("{} got unexpected keyword arguments: {}".format(
        self._structured_signature_summary(), ", ".join(sorted(extra_args))))
def _structured_signature_check_arg_types(self, args, kwargs):
  """Raises a TypeError if any args have the wrong type.

  Each positional argument is checked against the positional spec at the same
  index, and each keyword argument against the keyword spec of the same name.

  Args:
    args: Positional arguments to validate.
    kwargs: Keyword arguments to validate.
  """
  positional_specs, keyword_specs = self.structured_input_signature
  check_one = self._structured_signature_check_arg_type

  for position, (value, spec) in enumerate(zip(args, positional_specs)):
    check_one(value, spec, self._function_spec.arg_names[position])

  for key, value in kwargs.items():
    check_one(value, keyword_specs[key], key)
def _structured_signature_check_arg_type(self, arg, spec, name):
  """Raise TypeError if `arg`'s type doesn't match `spec`.

  Args:
    arg: The argument value; `_BOUND_VALUE` placeholders are accepted as-is.
    spec: The expected type or value for this argument.
    name: The argument's name, used only in error messages.

  Raises:
    TypeError: If the nested structure of `arg` differs from `spec`, if a
      leaf whose spec is a `DenseSpec` is not a Tensor or resource variable,
      or if any other leaf is not equal to its spec.
  """
  if arg is _BOUND_VALUE:
    return

  # First make sure the overall nested structure matches.
  try:
    nest.assert_same_structure(arg, spec, expand_composites=True)
  except (ValueError, TypeError):
    # Decide how much detail to show: if the structures agree when
    # composites are left unexpanded, the raw values are the most useful
    # thing to print; otherwise print structure summaries.
    try:
      nest.assert_same_structure(arg, spec, expand_composites=False)
    except (ValueError, TypeError):
      expected, got = _structure_summary(spec), _structure_summary(arg)
    else:
      expected, got = spec, arg
    raise TypeError("{}: argument {} had incorrect type\n"
                    " expected: {}\n got: {}".format(
                        self._structured_signature_summary(), name, expected,
                        got))

  # Then check every leaf of the nested structure.
  flat_args = nest.flatten(arg, expand_composites=True)
  flat_specs = nest.flatten(spec, expand_composites=True)
  for leaf, leaf_spec in zip(flat_args, flat_specs):
    if isinstance(leaf_spec, tensor_spec.DenseSpec):
      # TODO(edloper): Consider calling convert_to_tensor on non-tensor
      #     values here.  That would match the behavior of
      #     _call_concrete_function() in function_deserialization.py.  If
      #     we do, then we need to change the nest assert_same_structure and
      #     flatten calls above to use shallow variants.
      tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
      if isinstance(leaf, tensor_types):
        continue
      raise TypeError(
          "{} expected a Tensor in {}, but got {} value {}".format(
              self._structured_signature_summary(), name,
              type(leaf).__name__, leaf))

    # Non-tensor leaves must equal the value used at trace time.
    if leaf is not _BOUND_VALUE and leaf != leaf_spec:
      raise TypeError("ConcreteFunction {} was constructed with {} value "
                      "{} in {}, but was called with {} value {}".format(
                          self._structured_signature_summary(),
                          type(leaf_spec).__name__, leaf_spec, name,
                          type(leaf).__name__, leaf))
def _call_flat(self, args, captured_inputs, cancellation_manager=None):
  """Executes the wrapped function.

  Args:
    args: a list of Tensors or Variables. Arguments from the Python function
      should be filtered before calling this method: objects aside from
      Tensors, CompositeTensors, and Variables are ignored. Any
      CompositeTensors should be expanded before calling this method.
    captured_inputs: the captured inputs that are also part of the input args
      to the actual execution. By default, it should be self._captured_inputs.
    cancellation_manager: (Optional.) A `CancellationManager` that can be
      used to cancel function invocation.

  Returns:
    The result of applying the TF function to `args`.

  Raises:
    ValueError: If `args` contains anything other than Tensors or Variables,
      or, when not executing eagerly, if a Tensor argument's shape is not
      compatible with the shape of the corresponding graph input.
  """
  ctx = context.context()
  executing_eagerly = ctx.executing_eagerly()

  # Copy saveable status of function's graph to current FuncGraph.
  default_graph = ops.get_default_graph()
  if default_graph.building_function and not self._func_graph.saveable:
    default_graph.mark_as_unsaveable(self._func_graph.saving_errors)

  # Report every variable of the function graph as accessed when a tape may
  # be recording or the default graph exposes `watch_variable`.
  if (tape.could_possibly_record() or
      hasattr(default_graph, "watch_variable")):
    for v in self._func_graph.variables:
      resource_variable_ops.variable_accessed(v)

  # Build the flat list of tensor inputs: variables contribute their handle
  # (once), tensors are passed through, anything else is rejected.
  tensor_inputs = []
  variables_used = set([])
  for i, arg in enumerate(args):
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      # We can pass a variable more than once, and in this case we need to
      # pass its handle only once.
      if id(arg.handle) in variables_used:
        continue
      resource_variable_ops.variable_accessed(arg)
      tensor_inputs.append(arg.handle)
      variables_used.add(id(arg.handle))
    elif isinstance(arg, ops.Tensor):
      tensor_inputs.append(arg)
      if not executing_eagerly:
        # If we're graph building, shape inference is on. We check for input
        # compatibility up front to avoid hard to debug incompatibilities
        # later.
        graph_input_shape = tensor_shape.TensorShape(
            self._func_graph.inputs[i].shape)
        if not graph_input_shape.is_compatible_with(arg.shape):
          # Name the offending argument by keyword when keywords are known,
          # otherwise by position.
          if self._arg_keywords:
            arg_name = "'{}'".format(self._arg_keywords[i])
          else:
            arg_name = "with index {}".format(i)
          raise ValueError(
              ("The argument {} (value {}) is not compatible with the shape "
               "this function was traced with. Expected shape {}, but got "
               "shape {}.\n\nIf you called get_concrete_function, you may "
               "need to pass a tf.TensorSpec(..., shape=...) with a less "
               "specific shape, having None on axes which can vary.").format(
                   arg_name, arg,
                   self._func_graph.inputs[i].shape,
                   arg.shape))
    else:
      raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; "
                       "on invocation of %s, the %d-th input (%s) was not a "
                       "Tensor." % (self._func_graph.name, i, str(arg)))
  # From here on `args` is the full flat input list: explicit inputs followed
  # by captured inputs.
  args = tensor_inputs + captured_inputs
  possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args)
  if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE
      and executing_eagerly):
    # No tape is watching; skip to running the function.
    return self._build_call_outputs(self._inference_function.call(
        ctx, args, cancellation_manager=cancellation_manager))
  forward_backward = self._select_forward_and_backward_functions(
      args,
      possible_gradient_type,
      executing_eagerly)
  forward_function, args_with_tangents = forward_backward.forward()
  if executing_eagerly:
    flat_outputs = forward_function.call(
        ctx, args_with_tangents, cancellation_manager=cancellation_manager)
  else:
    # When graph building, install this function's gradient function for the
    # call ops while the forward call is added. Note that the cancellation
    # manager is not passed on this path.
    with default_graph._override_gradient_function(  # pylint: disable=protected-access
        {"PartitionedCall": self._get_gradient_function(),
         "StatefulPartitionedCall": self._get_gradient_function()}):
      flat_outputs = forward_function.call(ctx, args_with_tangents)
  forward_backward.record(flat_outputs)
  return self._build_call_outputs(flat_outputs)
def _experimental_with_cancellation_manager(self, cancellation_manager):
  """Returns a callable that invokes a cancellable version of this function.

  The returned callable forwards its arguments to `_call_impl` together with
  the given `cancellation_manager`.

  Args:
    cancellation_manager: A `CancellationManager` object that can be used to
      cancel function invocation.

  Returns:
    A callable with the same signature as this concrete function.
  """

  # A closure (rather than a bound method) so that the manager is captured
  # per returned callable.
  def cancellable_call(*args, **kwargs):
    return self._call_impl(
        args, kwargs, cancellation_manager=cancellation_manager)

  return cancellable_call
@property
def name(self):
  """`ConcreteFunction` name.

  This is the `name` of the forward function returned by
  `self._delayed_rewrite_functions.forward()`.
  """
  forward_function = self._delayed_rewrite_functions.forward()
  return forward_function.name
@property
def graph(self):
  """Returns the graph from which this function was constructed."""
  func_graph = self._func_graph
  return func_graph
@property
def inputs(self):
  """Returns tensors in `self.graph` corresponding to arguments."""
  func_graph = self._func_graph
  return func_graph.inputs
@property
def structured_input_signature(self):
  """Returns structured signature for this concrete function.

  The value is read from the underlying function graph.

  Returns:
    A tuple `(args, kwargs)`, where:

      * `args` is a tuple that specifies the expected type or value for each
        positional argument.
      * `kwargs` is a dictionary that specifies the expected type or value
        for each keyword-only argument.

    The type or value for each argument is specified using one of the
    following:

      * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
        value is expected.
      * A Python value, such as an integer, indicating that an equal value
        is expected.
      * A nested structure of `tf.TypeSpec`s and Python values, indicating
        that a corresponding nested structure is expected.
  """
  func_graph = self._func_graph
  return func_graph.structured_input_signature
@property
def outputs(self):
  """Returns tensors in `self.graph` corresponding to returned tensors."""
  func_graph = self._func_graph
  return func_graph.outputs
@property
def structured_outputs(self):
  """Returns outputs in `self.graph` as returned by the original function."""
  func_graph = self._func_graph
  return func_graph.structured_outputs
@property
def captured_inputs(self):
  """Returns external Tensors captured by this function.

  self.__call__(*args) passes `args + self.captured_inputs` to the function.

  The result is `self._captured_inputs` followed by the flattened values
  obtained by calling each closure in `self._captured_closures`; the closures
  are invoked on every access.
  """
  closure_values = [closure() for closure in self._captured_closures]
  from_closures = nest.flatten(closure_values, expand_composites=True)
  return self._captured_inputs + from_closures
@property
def function_def(self):
  """Returns a `FunctionDef` object representing this function.

  This is the `definition` of the forward function returned by
  `self._delayed_rewrite_functions.forward()`.
  """
  forward_function = self._delayed_rewrite_functions.forward()
  return forward_function.definition
@property
def output_shapes(self):
  """The function's output shapes.

  Composite outputs are first replaced by their components; the result has
  the structure of that component-level output, with each leaf replaced by
  its `shape` attribute (or an unknown `TensorShape` if it has none).
  """

  def shape_of(component):
    return getattr(component, "shape", tensor_shape.TensorShape(None))

  components = composite_tensor.replace_composites_with_components(
      self._func_graph.structured_outputs)
  return nest.map_structure(shape_of, components, expand_composites=False)
@property
def output_dtypes(self):
  """The function's output dtypes.

  Composite outputs are first replaced by their components; each leaf of the
  resulting structure is replaced by its `dtype`, and `None` leaves stay
  `None`.
  """
  # TODO(akshayka): Consider removing this.

  def dtype_of(component):
    if component is None:
      return None
    return component.dtype

  components = composite_tensor.replace_composites_with_components(
      self._func_graph.structured_outputs)
  return nest.map_structure(dtype_of, components, expand_composites=False)
def add_to_graph(self, g=None):
  """Registers the function, adds it to the graph g or default graph.

  Args:
    g: If specified, registers the function with this graph. Defaults to the
      current context (either the default graph or the eager context).
  """
  # When graph building and no graph was given, fall back to the default
  # graph.  In case of eager execution, the function definition gets added
  # to the context during construction itself.
  if not (context.executing_eagerly() or g):
    g = ops.get_default_graph()
  forward_function = self._delayed_rewrite_functions.forward()
  forward_function.add_to_graph(g)
def add_gradient_functions_to_graph(self, g=None):
  """Add forward/backward functions to graph `g` or the current context.

  Args:
    g: (Optional.) The graph to add the functions to.  When not executing
      eagerly and `g` is falsy, the default graph is used.
  """
  if not (context.executing_eagerly() or g):
    g = ops.get_default_graph()

  rewrite_functions = self._delayed_rewrite_functions
  # The plain forward function is added first, then the forward/backward
  # pair.
  rewrite_functions.forward().add_to_graph(g)
  forward_function, backward_function = rewrite_functions.forward_backward()
  for function in (forward_function, backward_function):
    function.add_to_graph(g)
def _get_gradient_function(self):
  """Returns gradient function. It will be lazily created at first call.

  The returned callable is the `_rewrite_forward_and_call_backward` attribute
  of `self._delayed_rewrite_functions`.
  """
  rewrite_functions = self._delayed_rewrite_functions
  return rewrite_functions._rewrite_forward_and_call_backward  # pylint: disable=protected-access
def _select_forward_and_backward_functions(
    self, args, possible_gradient_type, executing_eagerly):
  """Selects forward and backward functions based on the calling context.

  The forward function computes the "real" function outputs, `self._outputs`,
  and any extra values needed by the corresponding backward function.

  Forward/backward function objects for first-order and higher-order tape
  gradients are cached on `self`, keyed by whether gradients are needed for
  JVPs and by the indices of inputs with tangents.

  Args:
    args: A flat list of Tensors with all of the inputs to the forward
      function (including user-specified and captured inputs).
    possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*.
    executing_eagerly: Boolean, the value of context.executing_eagerly().

  Returns:
    An object with a `forward` method returning a tuple of (forward_function :
    _EagerDefinedFunction, augmented_arguments : List), and a corresponding
    `record` method which takes outputs from the forward function and records
    the operation. forward_function should be called with augmented_arguments.
  """
  # Tangents are only packed from `args` when executing eagerly; otherwise an
  # empty `TangentInfo` is used.
  if executing_eagerly:
    input_tangents = forwardprop_util.pack_tangents(args)
  else:
    input_tangents = forwardprop_util.TangentInfo()
  need_gradients_for_jvps = tape.should_record_backprop(
      input_tangents.tangents)
  # Allows re-use of forward and backward function pairs depending on the
  # tapes and forward accumulators watching its inputs.
  cache_key = (need_gradients_for_jvps, input_tangents.indices)
  if (possible_gradient_type
      == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER):
    if input_tangents.indices or executing_eagerly:
      # There is a single non-persistent tape active, so the user can only
      # request first-order gradients from a tape. We can spend less time
      # graph building since we know this.
      #
      # We may still end up computing higher-order gradients, but that'd be
      # through `tf.gradients`, which can re-write the forward pass and so
      # needs no preparation here.
      functions = self._first_order_tape_functions.get(cache_key, None)
      if functions is None:
        functions = _FirstOrderTapeGradientFunctions(
            self._func_graph, self._attrs, self._garbage_collector,
            forwardprop_input_indices=input_tangents.indices,
            delayed_rewrite_functions=self._delayed_rewrite_functions,
            need_gradients_for_jvps=need_gradients_for_jvps)
        self._first_order_tape_functions[cache_key] = functions
      return _ForwardBackwardCall(
          functions, args, input_tangents.tangents, tape_watching=True)
    else:
      # We can avoid computing second-order gradients in some cases by doing a
      # delayed rewrite when graph building. Since we know we'll only compute
      # first-order tape gradients, the delayed rewrite is safe: we won't need
      # to tell the tape about side outputs.
      #
      # TODO(allenl): This case is really dirty. It would be better if we
      # could temporarily pop all of the current tapes to avoid
      # accidentally taking second-order gradients.
      return _ForwardBackwardCall(
          self._delayed_rewrite_functions, args, input_tangents.tangents,
          tape_watching=True)
  elif (possible_gradient_type
        == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER):
    # Either there's a persistent tape watching, or there are multiple nested
    # tapes. Either way, the user may request higher-order gradients. We'll
    # spend a bit more time and make sure higher-order gradients are correct.
    functions = self._higher_order_tape_functions.get(
        cache_key, None)
    if functions is None:
      functions = _HigherOrderTapeGradientFunctions(
          self._func_graph, self._attrs, self._garbage_collector,
          forwardprop_input_indices=input_tangents.indices,
          delayed_rewrite_functions=self._delayed_rewrite_functions,
          need_gradients_for_jvps=need_gradients_for_jvps)
      self._higher_order_tape_functions[cache_key] = functions
    return _ForwardBackwardCall(functions, args, input_tangents.tangents,
                                tape_watching=True)
  # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no
  # tape is recording.
  return _ForwardBackwardCall(
      self._delayed_rewrite_functions, args, input_tangents.tangents,
      tape_watching=False)
def _build_call_outputs(self, result):
  """Maps the fdef output list to actual output structure.

  Args:
    result: Output lists defined by FunctionDef.

  Returns:
    The actual call output: `result` itself when the function graph has no
    structured outputs, ndarray conversions of `result` when one of the
    ndarray flags is set and `result` is non-empty, and otherwise the
    structured outputs with every non-`None` leaf replaced by the next entry
    of `result`.
  """
  # TODO(jlchu): call C++ version in function.cc when speed is improved
  structured = self._func_graph.structured_outputs
  if structured is None:
    return result

  if result:
    if self._ndarrays_list:
      return [np_arrays.tensor_to_ndarray(o) for o in result]
    if self._ndarray_singleton:
      return np_arrays.tensor_to_ndarray(result[0])

  # Replace outputs with results, skipping over any 'None' values.  `result`
  # only has entries for the non-None leaves, so it needs its own cursor.
  flat_outputs = nest.flatten(structured, expand_composites=True)
  next_result = 0
  for position, template in enumerate(flat_outputs):
    if template is None:
      continue
    custom_gradient.copy_handle_data(self.outputs[next_result],
                                     result[next_result])
    flat_outputs[position] = result[next_result]
    next_result += 1

  return nest.pack_sequence_as(structured, flat_outputs,
                               expand_composites=True)
@property
def _as_name_attr_list(self):
  """Returns a `NameAttrList` representing this function.

  The list is named after `self.name` and receives a copy of every entry in
  `self._attrs`.
  """
  name_attr_list = attr_value_pb2.NameAttrList(name=self.name)
  for attr_name, attr_value in self._attrs.items():
    name_attr_list.attr[attr_name].CopyFrom(attr_value)
  return name_attr_list
def _structured_signature_summary(self, default_values=False):
  """Returns a string summarizing this function's structured signature.

  Args:
    default_values: If true, then include default values in the signature.
      The "default" shown for an argument is its spec, and it is only shown
      for arguments whose spec contains no type spec.

  Returns:
    A `string` of the form `name(arg, ..., *, kwarg, ...)`.
  """
  # Note: we can't just use self._function_spec.signature_summary(), because
  # that would show "_BOUND_VALUE" as the default value for all arguments.
  assert self._function_spec is not None
  arg_specs, kwarg_specs = self.structured_input_signature

  # If an explicit input_signature is provided to @tf.function, then any
  # arguments with defaults that are not covered by that explicit signature
  # are simply dropped from the signature.
  # TODO(b/159639913) Look into whether dropping arguments with default values
  # from the signature is the right thing to do.
  pieces = list(self._function_spec.arg_names)[:len(arg_specs)]

  if default_values:
    pieces = [
        label if _contains_type_spec(spec) else label + "={}".format(spec)
        for label, spec in zip(pieces, arg_specs)
    ]

  if kwarg_specs:
    # A bare "*" separates positional from keyword-only arguments.
    pieces.append("*")
    for name, spec in kwarg_specs.items():
      if default_values and not _contains_type_spec(spec):
        pieces.append(name + "={}".format(spec))
      else:
        pieces.append(name)

  return "{}({})".format(self._func_graph.name, ", ".join(pieces))
def _flat_signature_summary(self):
  """Returns a string summarizing this function's flat signature.

  Positional arguments that have no keyword name are shown with a
  placeholder of the form `<argN>` (one-based).

  Returns:
    A string of the form `name(arg, ...)`.
  """
  assert self._arg_keywords is not None
  assert self._num_positional_args is not None
  # Work on a copy: extending `self._arg_keywords` in place would permanently
  # add the "<argN>" placeholders to the instance's keyword list (which is
  # also used to bind arguments), and would fail if it were a tuple.
  arg_names = list(self._arg_keywords)
  if self._num_positional_args > len(arg_names):
    arg_names.extend(
        "<arg{}>".format(i + 1)
        for i in range(len(arg_names), self._num_positional_args))
  return "{}({})".format(self._func_graph.name, ", ".join(arg_names))
def pretty_printed_signature(self, verbose=True):
  """Returns a string summarizing the signature of this concrete function.

  Args:
    verbose: If false, only the one-line structured signature summary (with
      default values) is returned.  If true, the summary is followed by an
      "Args:" section describing every argument whose spec contains a type
      spec, and a "Returns:" section describing the structured outputs.

  Returns:
    A `string`.
  """
  if not verbose:
    return self._structured_signature_summary(default_values=True)

  def pretty_print_spec(spec):
    """Returns a string describing the spec for a single argument."""
    if isinstance(spec, tensor_spec.TensorSpec):
      return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
    elif nest.is_sequence(spec):
      # Show the nesting structure with numbered markers, then describe each
      # marked leaf on its own line.
      pieces = nest.flatten(spec, expand_composites=False)
      markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
      structure = nest.pack_sequence_as(spec, markers)
      # Ensure dictionaries are sorted by key (for determinism)
      result = pprint.pformat(structure, width=10000)
      for (marker, piece) in zip(markers, pieces):
        result += "\n {}: {}".format(marker, pretty_print_spec(piece))
      return result
    else:
      return repr(spec)

  lines = [self._structured_signature_summary(default_values=True)]
  arg_specs, kwarg_specs = self.structured_input_signature
  names = list(self._function_spec.arg_names)

  # If an explicit input_signature is provided to @tf.function, then any
  # arguments with defaults that are not covered by that explicit signature
  # are simply dropped from the signature.
  # TODO(b/159639913) Look into whether dropping arguments with default values
  # from the signature is the right thing to do.
  names = names[:len(arg_specs)]

  # Pair each name with its own spec.  Positional names are paired with the
  # positional specs only, and keyword specs are looked up by key in the same
  # (sorted) order as their names, so a keyword argument is never described
  # with another argument's spec.
  named_specs = list(zip(names, arg_specs))
  named_specs.extend((key, kwarg_specs[key]) for key in sorted(kwarg_specs))

  # note: we can skip bound args, since we already displayed their bound
  # value in the signature summary.
  arg_details = [
      " {}: {}".format(name, pretty_print_spec(spec))
      for (name, spec) in named_specs
      if _contains_type_spec(spec)
  ]

  if arg_details:
    lines.append(" Args:")
    lines.extend(arg_details)
  lines.append(" Returns:")

  def spec_from_value(value):
    # For loaded function, structured_outputs are already specs.
    if isinstance(value, type_spec.TypeSpec):
      return value
    return type_spec.type_spec_from_value(value)

  lines.append(" {}".format(
      pretty_print_spec(
          nest.map_structure(spec_from_value, self.structured_outputs))))

  return "\n".join(lines)
def __repr__(self):
  """Returns a debug representation that includes a signature summary.

  The structured signature is used when a function spec is available, the
  flat signature when the flat-signature metadata is available, and the
  default `object` representation otherwise.
  """
  if self._function_spec is not None:
    summary = self.pretty_printed_signature(verbose=False)
  elif (self._num_positional_args is not None and
        self._arg_keywords is not None):
    summary = self._flat_signature_summary()
  else:
    return object.__repr__(self)
  return "<ConcreteFunction {} at 0x{:X}>".format(summary, id(self))
def __str__(self):
  """Returns the verbose pretty-printed signature, or `repr` without a spec."""
  if self._function_spec is None:
    return self.__repr__()
  return "ConcreteFunction {}".format(self.pretty_printed_signature())
# Register the core tensor classes with `_pywrap_utils` under these string
# names.  This runs once, at import time.
# NOTE(review): presumably the registry lets the native utilities recognize
# these Python types by name -- confirm against `_pywrap_utils`.
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
_pywrap_utils.RegisterType("IndexedSlices", ops.IndexedSlices)
def _deterministic_dict_values(dictionary):
  """Returns the values of `dictionary` as a tuple, ordered by sorted key."""
  ordered_keys = sorted(dictionary)
  return tuple([dictionary[key] for key in ordered_keys])
class FunctionSpec(object):
"""Specification of how to bind arguments to a function."""
@staticmethod
def from_function_and_signature(python_function,
                                input_signature,
                                is_pure=False,
                                experimental_follow_type_hints=False,
                                experimental_compile=None):
  """Create a FunctionSpec instance given a python function and signature.

  When `python_function` unwraps to a `functools.partial`, positional
  arguments that the partial already binds by keyword are removed from the
  argument spec (together with their defaults), and the keyword-only
  arguments and defaults are cleared.

  Args:
    python_function: a function to inspect
    input_signature: a signature of the function (None, if variable)
    is_pure: if True all input arguments (including variables and constants)
      will be converted to tensors and no variable changes allowed.
    experimental_follow_type_hints: see `tf.function`
    experimental_compile: see `tf.function`

  Returns:
    instance of FunctionSpec
  """
  fullargspec = tf_inspect.getfullargspec(python_function)
  _, unwrapped = tf_decorator.unwrap(python_function)

  # Treat a wrapped partial function as a special case. For all arguments that
  # were overridden with keywords in the partial:
  # - remove the corresponding arguments,
  # - remove the corresponding keywords.
  #
  # TODO(b/131153379): Consider Python3's fullargspec.kwonlyargs and
  # fullargspec.kwonlydefaults.
  #
  # The rewrite is only needed when there are defaults of either kind (the
  # second condition covers the Python3 case with kwonlydefaults).
  if (isinstance(unwrapped, functools.partial) and
      (fullargspec.defaults or fullargspec.kwonlydefaults)):
    new_defaults = fullargspec.defaults
    new_args = fullargspec.args
    if fullargspec.defaults:
      # To be able to canonicalize the function properly, we want to ignore
      # default values that are overridden via a partial kwarg. For example:
      #
      #   def func(a, b, c, d=5, e=7):
      #     return a, b, c, d, e
      #   p_func = tf.function(functools.partial(func, 10, e=9))
      #
      # Here we want to drop from the defaults the parameter `e`. If we
      # forwarded the call to the partial function with a default for `e`
      # we would get an error for passing two values for one parameter.
      #
      # Note that this has a limitation: we can only override parameters at
      # the end of the parameter list.
      #
      # In this case we want to end up with 3 arguments (b, c, d) and 1
      # default value (5): every argument that the partial binds by keyword
      # is filtered out together with its default.
      old_args = fullargspec.args
      old_defaults = fullargspec.defaults

      # Left-pad the defaults with a sentinel so that each argument can be
      # zipped with "its" default (or the sentinel if it has none).
      no_default = object()
      padding = (no_default,) * (len(old_args) - len(old_defaults))
      partial_keywords = unwrapped.keywords

      # Keep only arguments and defaults that were not kwargs of partial.
      kept = [(arg, default)
              for arg, default in zip(old_args, padding + old_defaults)
              if arg not in partial_keywords]
      new_args = [arg for arg, _ in kept]
      # Keep only real default values.
      new_defaults = [
          default for _, default in kept if default is not no_default
      ]
    fullargspec = tf_inspect.FullArgSpec(
        args=new_args,
        varargs=fullargspec.varargs,
        varkw=fullargspec.varkw,
        defaults=new_defaults,
        kwonlyargs=[],
        kwonlydefaults={},
        annotations=fullargspec.annotations)

  is_method = tf_inspect.ismethod(python_function)

  # Get the function's name.  Remove functools.partial wrappers if necessary.
  innermost = python_function
  while isinstance(innermost, functools.partial):
    innermost = innermost.func
  name = getattr(innermost, "__name__", "f")

  return FunctionSpec(
      fullargspec,
      is_method,
      input_signature,
      is_pure=is_pure,
      experimental_compile=experimental_compile,
      experimental_follow_type_hints=experimental_follow_type_hints,
      name=name)
def __init__(self,
             fullargspec,
             is_method,
             input_signature,
             is_pure=False,
             experimental_follow_type_hints=False,
             name=None,
             experimental_compile=None):
  """Constructs a FunctionSpec describing a python function.

  Args:
    fullargspec: `tf_inspect.FullArgSpec` object describing the function.
    is_method: True if the function is a method.
    input_signature: a signature of the function (None, if variable)
    is_pure: if True all input arguments (including variables and constants)
      will be converted to tensors and no variable changes allowed.
    experimental_follow_type_hints: see `tf.function`.
    name: Name of the function
    experimental_compile: see `tf.function`.

  Raises:
    ValueError: If `input_signature` is given and the function has
      keyword-only arguments without defaults.
    TypeError: If `input_signature` is given but is neither a tuple nor a
      list.
  """
  self._fullargspec = fullargspec
  self._is_method = is_method
  self._is_pure = is_pure
  self._experimental_compile = experimental_compile
  self._experimental_follow_type_hints = experimental_follow_type_hints

  # TODO(edloper): Include name when serializing for SavedModel?
  self._name = name or "f"

  if self._is_method:
    # Remove `self`: default arguments shouldn't be matched to it.
    # TODO(b/127938157): Should this error out if there is no arg to
    # be removed?
    args = fullargspec.args[1:]
  else:
    args = fullargspec.args

  # A cache mapping from argument name to index, for canonicalizing
  # arguments that are called in a keyword-like fashion.
  self._args_to_indices = {arg: i for i, arg in enumerate(args)}
  self._arg_names = args

  # A cache mapping from arg index to default value, for canonicalization.
  # Defaults belong to the trailing arguments, hence the offset.
  default_values = fullargspec.defaults or []
  offset = len(args) - len(default_values)
  self._arg_indices_to_default_values = {
      offset + index: default
      for index, default in enumerate(default_values)
  }

  if input_signature is None:
    self._input_signature = None
    # Always define the attribute so that the `flat_input_signature`
    # property returns None instead of raising AttributeError.
    self._flat_input_signature = None
  else:
    if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
      raise ValueError("Cannot define a TensorFlow function from a Python "
                       "function with keyword-only arguments when "
                       "input_signature is provided.")

    if not isinstance(input_signature, (tuple, list)):
      raise TypeError("input_signature must be either a tuple or a "
                      "list, received " + str(type(input_signature)))

    self._input_signature = tuple(input_signature)
    self._flat_input_signature = tuple(nest.flatten(input_signature,
                                                    expand_composites=True))
@property
def fullargspec(self):
  """The full argument spec object this `FunctionSpec` was built with."""
  spec = self._fullargspec
  return spec
@property
def is_method(self):
  """The `is_method` flag this `FunctionSpec` was constructed with."""
  flag = self._is_method
  return flag
@property
def args_to_indices(self):
  """Dict mapping each positional argument name to its index."""
  mapping = self._args_to_indices
  return mapping
@property
def kwargs_to_include(self):
  """Returns the `_kwargs_to_include` attribute of this object."""
  # NOTE(review): `__init__` as visible in this file never assigns
  # `_kwargs_to_include`, so this access looks like it would raise
  # AttributeError unless something else sets it -- confirm whether this
  # property is still used.
  included = self._kwargs_to_include
  return included
@property
def input_signature(self):
  """The input signature as a tuple, or None if none was provided."""
  signature = self._input_signature
  return signature
@property
def flat_input_signature(self):
  """The flattened form of the input signature, as stored by `__init__`."""
  flat_signature = self._flat_input_signature
  return flat_signature
@property
def is_pure(self):
  """The `is_pure` flag this `FunctionSpec` was constructed with."""
  flag = self._is_pure
  return flag
@property
def experimental_compile(self):
  """The `experimental_compile` value passed at construction."""
  value = self._experimental_compile
  return value
@property
def arg_names(self):
  """Positional argument names (without the first one for methods)."""
  names = self._arg_names
  return names
@property
def vararg_name(self):
  """The `varargs` field of the full argument spec."""
  spec = self._fullargspec
  return spec.varargs
@property
def varkw_name(self):
  """The `varkw` field of the full argument spec."""
  spec = self._fullargspec
  return spec.varkw
def signature_summary(self, default_values=False):
  """Returns a string summarizing this function's signature.

  Args:
    default_values: If true, then include default values in the signature.

  Returns:
    A `string` of the form `name(arg, ..., *, kwonlyarg, ...)`.
  """
  args = list(self._arg_names)
  if default_values:
    for (i, default) in self._arg_indices_to_default_values.items():
      args[i] += "={}".format(default)
  if self._fullargspec.kwonlyargs:
    # `kwonlydefaults` may be None when no keyword-only argument has a
    # default; treat that the same as an empty mapping.
    kwonlydefaults = self._fullargspec.kwonlydefaults or {}
    # A bare "*" separates positional from keyword-only arguments.
    args.append("*")
    for arg_name in self._fullargspec.kwonlyargs:
      args.append(arg_name)
      if default_values and arg_name in kwonlydefaults:
        args[-1] += "={}".format(kwonlydefaults[arg_name])
  return "{}({})".format(self._name, ", ".join(args))
def _convert_variables_to_tensors(self, args, kwargs):
  """Passes every argument through `ops.convert_to_tensor`.

  Args:
    args: Iterable of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    A `(args, kwargs)` pair holding a new tuple and a new dict with the
    converted values.
  """
  tensor_args = tuple(ops.convert_to_tensor(value) for value in args)
  tensor_kwargs = {
      key: ops.convert_to_tensor(value) for key, value in kwargs.items()
  }
  return tensor_args, tensor_kwargs
def _convert_annotated_args_to_tensors(self, args, kwargs):
"""Attempts to autobox arguments annotated as tf.Tensor."""
if self.input_signature is not None:
return
args = list(args)
for i, arg in enumerate(args):
# See
# https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
if i < len(self._fullargspec.args):
arg_annotation = self._fullargspec.annotations.get(
self._fullargspec.args[i])
# TODO(rahulkamat): Change to TensorLike (here ans below).
if arg_annotation == ops.Tensor:
args[i] = ops.convert_to_tensor(arg)
else:
varargs_annotation = self._fullargspec.annotations.get(
self._fullargspec.varargs)
if varargs_annotation == ops.Tensor:
args[i] = ops.convert_to_tensor(arg)
for kw, v in kwargs.items():
if kw in self._fullargspec.kwonlyargs:
kwonlyarg_annotation = self._fullargspec.annotations.get(kw)
if kwonlyarg_annotation == ops.Tensor:
kwargs[kw] = ops.convert_to_tensor(v)
elif self._fullargspec.varkw is not None:
varkw_annotation = self._fullargspec.annotations.get(
self._fullargspec.varkw)
if kw in self._fullargspec.args:
arg_annotation = self._fullargspec.annotations.get(kw)
if arg_annotation == ops.Tensor:
kwargs[kw] = ops.convert_to_tensor(v)
elif varkw_annotation == ops.Tensor:
kwargs[kw] = ops.convert_to_tensor(v)
return tuple(args), kwargs
def canonicalize_function_inputs(self, *args, **kwargs):
  """Canonicalizes `args` and `kwargs`.

  Canonicalize the inputs to the Python function using a `FunctionSpec`
  instance. In particular, we parse the varargs and kwargs that the
  original function was called with into a tuple corresponding to the
  Python function's positional (named) arguments and a dictionary
  corresponding to its kwargs. Missing default arguments are added.

  If this `FunctionSpec` has an input signature, then it is used to convert
  arguments to tensors; otherwise, any inputs containing numpy arrays are
  converted to tensors.

  Args:
    *args: The varargs this object was called with.
    **kwargs: The keyword args this function was called with.

  Returns:
    A canonicalized ordering of the inputs, as well as full and filtered
    (Tensors and Variables only) versions of their concatenated flattened
    representations, represented by a tuple in the form (args, kwargs,
    flat_args, filtered_flat_args). Here: `args` is a full list of bound
    arguments, and `kwargs` contains only true keyword arguments, as opposed
    to named arguments called in a keyword-like fashion.

  Raises:
    TypeError: If, with an input signature, too many positional arguments
      are given or a keyword cannot be matched with a positional argument
      covered by the signature; if an argument receives two values; or if
      required arguments are missing.
    ValueError: If the inputs do not conform to the input signature (raised
      by `_convert_inputs_to_signature`).
  """
  if self._is_pure:
    args, kwargs = self._convert_variables_to_tensors(args, kwargs)
  if self._experimental_follow_type_hints:
    args, kwargs = self._convert_annotated_args_to_tensors(args, kwargs)
  if self._input_signature is not None:
    # With a signature, every input must map onto one of its entries.
    if len(args) > len(self._input_signature):
      raise TypeError("{} takes {} positional arguments (as specified by the "
                      "input_signature) but {} were given".format(
                          self.signature_summary(),
                          len(self._input_signature), len(args)))
    for arg in six.iterkeys(kwargs):
      index = self._args_to_indices.get(arg, None)
      if index is None:
        raise TypeError("{} got unexpected keyword argument `{}`".format(
            self.signature_summary(), arg))
      if index >= len(self._input_signature):
        raise TypeError(
            "{} got keyword argument `{}` that was not included in "
            "input_signature".format(self.signature_summary(), arg))

  if not kwargs:
    # Fast path: only positional values were supplied, so just append the
    # defaults for every trailing named argument.
    inputs = args
    if self._arg_indices_to_default_values:
      try:
        inputs += tuple(
            self._arg_indices_to_default_values[i]
            for i in range(len(args), len(self._arg_names)))
      except KeyError:
        # At least one trailing argument has no default; report all of them.
        missing_args = [
            self._arg_names[i]
            for i in range(len(args), len(self._arg_names))
            if i not in self._arg_indices_to_default_values
        ]
        raise TypeError("{} missing required arguments: {}".format(
            self.signature_summary(), ", ".join(missing_args)))

    if self._fullargspec.kwonlydefaults:
      kwargs.update(self._fullargspec.kwonlydefaults)
  else:
    # Maps from index of arg to its corresponding value, according to `args`
    # and `kwargs`; seeded with the default values for the named args that
    # aren't in `args`.
    arg_indices_to_values = {
        index: default for index, default in six.iteritems(
            self._arg_indices_to_default_values) if index >= len(args)
    }
    consumed_args = []
    for arg, value in six.iteritems(kwargs):
      index = self._args_to_indices.get(arg, None)
      if index is not None:
        if index < len(args):
          raise TypeError("{} got two values for argument '{}'".format(
              self.signature_summary(), arg))
        arg_indices_to_values[index] = value
        consumed_args.append(arg)
    for arg in consumed_args:
      # After this loop, `kwargs` will only contain keyword_only arguments,
      # and all positional_or_keyword arguments have been moved to `inputs`.
      kwargs.pop(arg)
    inputs = args + _deterministic_dict_values(arg_indices_to_values)

    if kwargs and self._input_signature is not None:
      raise TypeError(
          "{} got unexpected keyword arguments: {}\n(Cannot define a "
          "TensorFlow function from a Python function with keyword arguments "
          "when input_signature is provided.)".format(
              self.signature_summary(), ", ".join(kwargs)))

    if self._fullargspec.kwonlydefaults:
      # Explicitly passed keyword-only values take precedence over defaults.
      for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
        kwargs.setdefault(kwarg, default)

  if self._input_signature is None:
    inputs, flat_inputs, filtered_flat_inputs = _convert_numpy_inputs(inputs)
    kwargs, flat_kwargs, filtered_flat_kwargs = _convert_numpy_inputs(kwargs)
    return (inputs, kwargs, flat_inputs + flat_kwargs,
            filtered_flat_inputs + filtered_flat_kwargs)
  else:
    assert not kwargs
    inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature(
        inputs, self._input_signature, self._flat_input_signature)
    return inputs, {}, flat_inputs, filtered_flat_inputs
def _as_ndarray(value):
"""Converts value to an ndarray, assumes _is_ndarray(value)."""
# TODO(tomhennigan) Support __array_interface__ too.
return value.__array__()
def _is_ndarray(value):
"""Tests whether the given value is an ndarray (and not a TF tensor/var)."""
# TODO(tomhennigan) Support __array_interface__ too.
return hasattr(value, "__array__") and not (
isinstance(value, ops.Tensor)
or isinstance(value, resource_variable_ops.BaseResourceVariable)
or hasattr(value, "_should_act_as_resource_variable")
# For legacy reasons we do not automatically promote Numpy strings.
or isinstance(value, np.str_)
# NumPy dtypes have __array__ as unbound methods.
or isinstance(value, type)
# CompositeTensors should be flattened instead.
or isinstance(value, composite_tensor.CompositeTensor))
def _convert_numpy_inputs(inputs):
  """Replaces array-like leaves of `inputs` with constant tensors.

  Args:
    inputs: A (possibly nested) structure of function arguments.

  Returns:
    A tuple `(inputs, flat_inputs, filtered_flat_inputs)`: the structure with
    array-likes converted (the original object when nothing was converted),
    its flattened form with composites expanded, and the subset of flattened
    values that are Tensors or resource variables.

  Raises:
    TypeError: If a value's `__array__` does not return an `np.ndarray`.
  """
  # We assume that any CompositeTensors have already converted their components
  # from numpy arrays to Tensors, so we don't need to expand composites here for
  # the numpy array conversion. Instead, we do so because the flattened inputs
  # are eventually passed to ConcreteFunction()._call_flat, which requires
  # expanded composites.
  flat_inputs = nest.flatten(inputs, expand_composites=True)

  # Check for NumPy arrays in arguments and convert them to Tensors.
  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  filtered_flat_inputs = []
  converted_any = False
  for position, leaf in enumerate(flat_inputs):
    if isinstance(leaf,
                  (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
      filtered_flat_inputs.append(leaf)
      continue
    # The checks below are the inlined equivalent of `_is_ndarray(leaf)`.
    if not hasattr(leaf, "__array__"):
      continue
    if (hasattr(leaf, "_should_act_as_resource_variable") or
        isinstance(leaf, (np.str_, type, composite_tensor.CompositeTensor))):
      continue
    array = _as_ndarray(leaf)
    if not isinstance(array, np.ndarray):
      raise TypeError("The output of __array__ must be an np.ndarray "
                      "(got {} from {}).".format(type(array), type(leaf)))
    flat_inputs[position] = constant_op.constant(array)
    filtered_flat_inputs.append(flat_inputs[position])
    converted_any = True

  if not converted_any:
    return inputs, flat_inputs, filtered_flat_inputs
  repacked = nest.pack_sequence_as(
      structure=inputs, flat_sequence=flat_inputs,
      expand_composites=True)
  return (repacked, flat_inputs, filtered_flat_inputs)
def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Convert inputs to pass into a function with an explicit signature.

  Args:
    inputs: The positional inputs to the Python function.
    input_signature: The (possibly nested) signature the inputs must follow.
    flat_input_signature: The flattened form of `input_signature`.

  Returns:
    A tuple `(inputs, flat_inputs, filtered_flat_inputs)`, where the last
    element keeps only Tensors and resource variables.

  Raises:
    ValueError: If the inputs do not match the signature's structure, cannot
      be converted to tensors, or are incompatible with the signature.
  """

  def format_error_message(inputs, input_signature):
    rendered_inputs = ",\n ".join(str(i) for i in inputs)
    rendered_signature = ",\n ".join(str(i) for i in input_signature)
    return "".join([
        " inputs: (\n", " ", rendered_inputs, ")\n",
        " input_signature: (\n", " ", rendered_signature, ")"
    ])

  try:
    # check_types=False: lists are converted to tuples for `tf.data`.
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:len(input_signature)],
        expand_composites=True,
        check_types=False)
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature:\n%s" %
                     format_error_message(inputs, input_signature))

  need_packing = False
  for position, (candidate, spec) in enumerate(
      zip(flatten_inputs, flat_input_signature)):
    if not isinstance(spec, tensor_spec.TensorSpec):
      continue
    if _pywrap_utils.IsTensor(candidate):
      continue
    try:
      flatten_inputs[position] = ops.convert_to_tensor(
          candidate, dtype_hint=spec.dtype)
      need_packing = True
    except ValueError:
      raise ValueError("When input_signature is provided, all inputs to "
                       "the Python function must be convertible to "
                       "tensors:\n%s" %
                       format_error_message(inputs, input_signature))

  if not all(spec.is_compatible_with(other)
             for spec, other in zip(flat_input_signature, flatten_inputs)):
    raise ValueError("Python inputs incompatible with input_signature:\n%s" %
                     format_error_message(inputs, input_signature))

  if need_packing:
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs,
        expand_composites=True)

  flat_inputs = nest.flatten(inputs, expand_composites=True)
  tensor_like = [
      t for t in flat_inputs
      if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
  ]
  return (inputs, flat_inputs, tensor_like)
class FunctionCache(object):
  """A lightweight container for cached functions.

  Holds the set of missed call-context keys plus three ordered caches, and a
  garbage collector object for each ordered cache.
  """

  __slots__ = [
      "missed", "primary", "arg_relaxed_specs", "arg_relaxed",
      "_garbage_collectors"
  ]

  def __init__(self):
    # The set of functions that have been missed; entries are CacheKey with
    # input_signature `None` (e.g. a "call context key")
    self.missed = set()
    # The primary cache, mapping a fully shaped CacheKey to a function.
    self.primary = collections.OrderedDict()
    # A cache key lookup, mapping a CacheKey generated without shape info to a
    # flat list of `TypeSpec`s with relaxed shapes (one for each flattened
    # argument). Arguments that are not Tensors or `CompositeTensor`s contain a
    # `None` for the corresponding relaxed spec.
    self.arg_relaxed_specs = collections.OrderedDict()
    # The secondary cache, mapping a CacheKey generated without shape info to a
    # function.
    self.arg_relaxed = collections.OrderedDict()
    # All OrderedDicts require manual garbage collection.
    collected_caches = (self.primary, self.arg_relaxed, self.arg_relaxed_specs)
    self._garbage_collectors = [
        _FunctionGarbageCollector(cache) for cache in collected_caches]

  def all_values(self):
    """A set of all `ConcreteFunction` instances held by this cache."""
    functions = set(self.primary.values())
    functions.update(self.arg_relaxed.values())
    return functions
class Function(object):
"""Wrapper class for the graph functions defined for a Python function.
See the documentation for `defun` for more information on the semantics of
defined functions.
`Function` class is thread-compatible meaning that minimal usage of defuns
(defining and calling) is thread-safe, but if users call other methods or
invoke the base `python_function` themselves, external synchronization is
necessary.
In addition, Function is not reentrant, so recursive functions need to call
the wrapped function, not the wrapper.
"""
def __init__(self,
             python_function,
             name,
             input_signature=None,
             attributes=None,
             autograph=True,
             autograph_options=None,
             experimental_relax_shapes=False,
             capture_by_value=None,
             experimental_compile=None,
             experimental_follow_type_hints=False):
  """Initializes a `Function`.

  Args:
    python_function: the function to be wrapped.
    name: the name given to it.
    input_signature: a possibly nested sequence of `TensorSpec` objects
      specifying the input signature of this function. If `None`, a separate
      function is instantiated for each inferred input signature.
    attributes: dict, extra keyword arguments that will be added as attribute
      of the function.
    autograph: whether to use autograph to compile
      `python_function`. See https://www.tensorflow.org/guide/autograph for
      more information.
    autograph_options: Experimental knobs to control behavior
      `when autograph=True`. See https://www.tensorflow.org/guide/autograph
      for more information.
    experimental_relax_shapes: When true, argument shapes may be relaxed to
      avoid unnecessary retracing.
    capture_by_value: Experimental. Whether to capture resource variables by
      value or reference. If None, will inherit from a parent context or
      default to False.
    experimental_compile: Force-compile the function with XLA, cf.
      def_function.Function doc on experimental_compile.
    experimental_follow_type_hints: See the documentation for `tf.function`.

  Raises:
    ValueError: if `input_signature` is not None and the `python_function`'s
      argspec has keyword arguments.
  """
  self._python_function = python_function
  # The function is treated as pure when `attributes` contains the
  # IMPLEMENTS_ATTRIBUTE_NAME key.
  pure_function = attributes and IMPLEMENTS_ATTRIBUTE_NAME in attributes
  self._function_spec = FunctionSpec.from_function_and_signature(
      python_function,
      input_signature,
      is_pure=pure_function,
      experimental_follow_type_hints=experimental_follow_type_hints)
  self._name = name
  self._autograph = autograph
  self._autograph_options = autograph_options
  self._experimental_relax_shapes = experimental_relax_shapes
  self._function_cache = FunctionCache()
  self._function_attributes = attributes or {}
  self._capture_by_value = capture_by_value
  # Incremented each time `_create_graph_function` runs.
  self.tracing_count = 0
  if self.input_signature is not None:
    # Precomputed once so `_cache_key` can reuse it for every call.
    self._hashable_input_signature = _make_input_signature_hashable(
        self.flat_input_signature)

  self._lock = threading.Lock()
  # _descriptor_cache maps an instance of a class to an instance-specific
  # `Function`, used to make sure defun-decorated methods create different
  # functions for each instance.
  self._descriptor_cache = weakref.WeakKeyDictionary()
  self._experimental_compile = experimental_compile
  self._experimental_follow_type_hints = experimental_follow_type_hints
def __call__(self, *args, **kwargs):
  """Runs the graph function matching these inputs, tracing it if needed."""
  # Selection (and any tracing) happens under the lock; execution does not.
  with self._lock:
    concrete_function, tensor_inputs = self._maybe_define_function(
        args, kwargs)
  # pylint: disable=protected-access
  return concrete_function._call_flat(
      tensor_inputs, captured_inputs=concrete_function.captured_inputs)
  # pylint: enable=protected-access
@property
def python_function(self):
  """Returns the wrapped Python function passed to the constructor."""
  return self._python_function  # pylint: disable=protected-access
@property
def function_spec(self):
  """Returns the `FunctionSpec` built for the wrapped Python function."""
  return self._function_spec
@property
def input_signature(self):
  """Returns the input signature held by this function's `FunctionSpec`."""
  return self._function_spec.input_signature
@property
def flat_input_signature(self):
  """Returns the flattened input signature held by the `FunctionSpec`."""
  return self._function_spec.flat_input_signature
def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
  """Returns a concrete function which cleans up its graph function.

  When a non-empty input signature is set, the supplied arguments are ignored
  and the function is defined from the signature alone.
  """
  if self.input_signature:
    args = kwargs = None
  with self._lock:
    concrete_function, _ = self._maybe_define_function(args, kwargs)
  return concrete_function
def _get_concrete_function_internal(self, *args, **kwargs):
  """Bypasses error checking when getting a graph function."""
  concrete_function = (
      self._get_concrete_function_internal_garbage_collected(*args, **kwargs))
  # The caller receives this concrete function and may hold on to its
  # FuncGraph without holding the ConcreteFunction itself. Releasing the
  # collector means the reference cycles are not cleaned up manually and are
  # left to Python's garbage collector instead.
  collector = concrete_function._garbage_collector  # pylint: disable=protected-access
  collector.release()
  return concrete_function
def _get_concrete_function_garbage_collected(self, *args, **kwargs):
  """Returns a `ConcreteFunction` specialized to inputs and execution context.

  Unlike `get_concrete_function(...)`, the graph will be deleted when the
  returned function is deleted.  It's useful to avoid creating a reference
  cycle when you know for sure that the graph will be no longer used without
  the returned function.

  Args:
    *args: inputs to specialize on.
    **kwargs: inputs to specialize on.

  Returns:
    The graph function, with `_arg_keywords` and `_num_positional_args` set
    from the user-specified names of its non-captured graph inputs.

  Raises:
    ValueError: If an input signature is set and keyword arguments are
      given, or the positional arguments do not match the signature's
      structure, types, or specs.
  """
  if self.input_signature:
    if kwargs:
      raise ValueError("Cannot define a TensorFlow function from a Python "
                       "function with keyword arguments when "
                       "input_signature is provided.")
    if args:
      # If args are provided, they must match the input signature.
      if not is_same_structure(self.input_signature, args):
        raise ValueError("Structure of Python function inputs does not match "
                         "input_signature.")
      flat_inputs = nest.flatten(args, expand_composites=True)
      if any(not isinstance(arg, (ops.Tensor, tensor_spec.DenseSpec,
                                  resource_variable_ops.BaseResourceVariable))
             for arg in flat_inputs):
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be Tensors, Variables, "
                         "tf.TensorSpec or tf.VariableSpec objects.")
      if any(not spec.is_compatible_with(other)
             for spec, other in zip(self.flat_input_signature, flat_inputs)):
        raise ValueError("Python inputs incompatible with input_signature: "
                         "inputs (%s), input_signature (%s)" %
                         (str(args), str(self.input_signature)))
    # The validated arguments are dropped; the function is defined from the
    # input signature alone.
    args, kwargs = None, None
  with self._lock:
    graph_function, _ = self._maybe_define_function(args, kwargs)
    seen_names = set()
    captured = object_identity.ObjectIdentitySet(
        graph_function.graph.internal_captures)
    # pylint: disable=protected-access
    graph_function._arg_keywords = []
    prefix_counts = {}
    # pylint: enable=protected-access
    num_positional = 0
    for arg in graph_function.graph.inputs:
      # Stop at the first captured input; only the inputs before it are
      # counted as positional arguments.
      if arg in captured:
        break
      num_positional += 1
      user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
      proposal = user_arg_name
      # Make the keyword unique by appending an increasing numeric suffix.
      while proposal in seen_names:
        index = prefix_counts.get(user_arg_name, 1)
        proposal = "{}_{}".format(user_arg_name, index)
        prefix_counts[user_arg_name] = index + 1
      seen_names.add(proposal)
      graph_function._arg_keywords.append(proposal)  # pylint: disable=protected-access
    # Anything can be a positional argument, in the same order as .inputs
    graph_function._num_positional_args = num_positional  # pylint: disable=protected-access
    return graph_function
def get_concrete_function(self, *args, **kwargs):
  """Returns a `ConcreteFunction` specialized to inputs and execution context.

  Delegates to `_get_concrete_function_garbage_collected` and then releases
  the returned function's garbage collector.

  Args:
    *args: inputs to specialize on.
    **kwargs: inputs to specialize on.
  """
  concrete_function = self._get_concrete_function_garbage_collected(
      *args, **kwargs)
  collector = concrete_function._garbage_collector  # pylint: disable=protected-access
  collector.release()
  return concrete_function
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
# `instance` here is the instance that this `Function` was accessed through
# e.g., for
#
# class Foo(object):
#
# @function.defun
# def bar(self):
# ...
#
# foo = Foo()
# foo.bar() # `foo.bar` is a `Function` instance
#
# then `instance` will be `foo` (and `owner` will be `Foo`). We create a
# new instance of `Function` here to allow different instances each
# to create variables once, thereby allowing methods to be decorated with
# defun. Keeps a cache to avoid retracing the function every time the
# descriptor is accessed.
if instance not in self._descriptor_cache:
if instance is None:
return self
# If there is no instance-specific `Function` in the cache, we construct
# an instance-specific `Function` that uses a weak reference to the
# instance (so that the instance will be correctly gc'd).
# And finally add the wrapped function to the description cache
self._descriptor_cache[instance] = class_method_to_instance_method(
self, instance)
# Return the cached `Function` for the instance
return self._descriptor_cache[instance]
def _cache_key(self,
               args,
               kwargs,
               cache_key_context,
               include_tensor_ranks_only=False):
  """Computes the cache key given inputs and execution context.

  Args:
    args: Positional inputs; unused when an input signature is set.
    kwargs: Keyword inputs; unused when an input signature is set.
    cache_key_context: The 6-tuple produced by `_cache_key_context`.
    include_tensor_ranks_only: Forwarded to the argument encoder; must be
      false when an input signature is set.

  Returns:
    A `CacheKey` built from the hashable input signature and the context.
  """
  if self.input_signature is not None:
    # The key depends only on the precomputed signature, not on the inputs.
    del args, kwargs
    assert not include_tensor_ranks_only
    hashable_input_signature = self._hashable_input_signature
  else:
    inputs = (args, kwargs) if kwargs else args
    encoded = pywrap_tfe.TFE_Py_EncodeArg(inputs, include_tensor_ranks_only)
    hashable_input_signature = _make_input_signature_hashable(encoded)

  (parent_graph, device_functions, colocation_stack, in_cross_replica_context,
   variable_policy, xla_context_id) = cache_key_context

  return CacheKey(
      hashable_input_signature,
      parent_graph,
      device_functions,
      colocation_stack,
      in_cross_replica_context,
      variable_policy,
      xla_context_id)
def _cache_key_context(self):
  """Returns execution context.

  Returns:
    A tuple `(parent_graph, device_functions, colocation_stack,
    in_cross_replica_context, variable_policy, xla_context_id)` describing
    the calling context; `_cache_key` folds it into the cache key.
  """
  ctx = context.context()

  # Don't need to open an init_scope if the _cache_key call is in eager mode
  # already.
  executing_eagerly = ctx.executing_eagerly()
  parent_graph = None
  xla_context_id = 0
  if not executing_eagerly:
    # We want to force function retracing for each different
    # XLAControlFlowContext, so add `xla_context_id` to the cache key.
    xla_context = _enclosing_xla_context()
    if xla_context is not None and \
          xla_context.RequiresUniqueFunctionRetracing():
      xla_context_id = id(xla_context)

    with ops.init_scope():
      # The graph, or whether we're executing eagerly, should be a part of the
      # cache key so we don't improperly capture tensors such as variables.
      executing_eagerly = ctx.executing_eagerly()
      parent_graph = None if executing_eagerly else ops.get_default_graph()

  # pylint: disable=protected-access
  default_graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategy_stack = default_graph._distribution_strategy_stack
  uses_distribution_strategy = (
      strategy_stack and
      strategy_stack[-1].strategy.extended._retrace_functions_for_each_device
  )
  if executing_eagerly:
    colocation_stack = ()
    if uses_distribution_strategy:
      device_functions = (pydev.merge_device(ctx.device_name),)
    else:
      device_functions = ()
  else:
    colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
    if (uses_distribution_strategy
        or func_graph_module.device_stack_has_callable(
            default_graph._device_function_stack)):
      # Putting the device in the cache key ensures that call-site device
      # annotations are respected.
      device_functions = tuple(default_graph._device_functions_outer_to_inner)
    else:
      device_functions = ()

  in_cross_replica_context = False
  try:
    in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    # An empty strategy stack, or an entry without `replica_context`, leaves
    # the flag at False.
    pass

  if save_context.in_save_context():
    variable_policy = (
        save_context.get_save_options().experimental_variable_policy)
  else:
    variable_policy = None

  return (parent_graph, device_functions, colocation_stack,
          in_cross_replica_context, variable_policy, xla_context_id)
def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
  """Create a `ConcreteFunction` from `args` and `kwargs`.

  Increments `tracing_count`, derives a name for every positional input, and
  traces the Python function into a new `ConcreteFunction`.
  """
  self.tracing_count += 1

  spec = self._function_spec
  if self.input_signature is not None:
    arglen = len(self.input_signature)
  else:
    arglen = len(args)
  named_args = spec.arg_names[:arglen]
  extra_count = arglen - len(spec.arg_names)
  vararg_name = spec.vararg_name
  # Inputs beyond the named arguments get names of the form
  # ["arg_0", "arg_1", ...], where arg is the function's vararg name.
  extra_names = ["%s_%d" % (vararg_name, i) for i in range(extra_count)]
  arg_names = named_args + extra_names

  func_graph = func_graph_module.func_graph_from_py_func(
      self._name,
      self._python_function,
      args,
      kwargs,
      self.input_signature,
      autograph=self._autograph,
      autograph_options=self._autograph_options,
      arg_names=arg_names,
      override_flat_arg_shapes=override_flat_arg_shapes,
      capture_by_value=self._capture_by_value)
  # shared_func_graph=False tells the ConcreteFunction to clean up its graph
  # once it goes out of scope. This is not the default behavior since it gets
  # used in some places (like Keras) where the FuncGraph lives longer than the
  # ConcreteFunction.
  return ConcreteFunction(
      func_graph,
      self._function_attributes,
      function_spec=self.function_spec,
      shared_func_graph=False)
def _define_function_with_shape_relaxation(self, args, kwargs, flat_args,
                                           filtered_flat_args,
                                           cache_key_context):
  """Define a function, relaxing arg shapes to avoid unnecessary retracing.

  Args:
    args: Canonicalized positional arguments.
    kwargs: Canonicalized keyword arguments.
    flat_args: Flattened `(args, kwargs)` values, used to rebuild composite
      arguments against relaxed specs.
    filtered_flat_args: The Tensor/Variable subset of the flattened inputs,
      returned unchanged when a cached relaxed function is reused.
    cache_key_context: The context tuple from `_cache_key_context`.

  Returns:
    A `(graph_function, filtered_flat_args)` pair.

  Raises:
    RuntimeError: If the cached relaxed specs and the current argument specs
      differ in length.
  """
  flat_no_comp = nest.flatten((args, kwargs), expand_composites=False)

  any_composite_args = any(
      isinstance(x, composite_tensor.CompositeTensor) for x in flat_no_comp)

  # Build a cache key where TensorShapes include only rank information (and
  # not information about the size of each dimension).
  if not any_composite_args:
    rank_only_cache_key = self._cache_key(
        args, kwargs, cache_key_context, include_tensor_ranks_only=True)
  else:
    # For the rank-only cache key, replace any composite tensors with
    # shape-relaxed TypeSpecs.
    (cache_key_args, cache_key_kwargs) = nest.map_structure(
        _shape_relaxed_type_for_composite_tensor, (args, kwargs))
    rank_only_cache_key = self._cache_key(
        cache_key_args,
        cache_key_kwargs,
        cache_key_context,
        include_tensor_ranks_only=True)

  arg_specs = [_type_spec_for(x) for x in flat_no_comp]
  relaxed_arg_specs = self._function_cache.arg_relaxed_specs.get(
      rank_only_cache_key, None)
  relaxed_arg_function = self._function_cache.arg_relaxed.get(
      rank_only_cache_key, None)

  # Reuse the cached relaxed function when every pair passes
  # `_is_type_subset`.
  if (relaxed_arg_function is not None
      and all(_is_type_subset(x, y) for (x, y) in
              zip(relaxed_arg_specs, arg_specs))):
    return relaxed_arg_function, filtered_flat_args

  if relaxed_arg_specs is None:
    relaxed_arg_specs = arg_specs
  else:
    if len(arg_specs) != len(relaxed_arg_specs):
      raise RuntimeError("Expected arg_specs len to match "
                         "relaxed_arg_specs len: %d vs. %d"
                         % (len(arg_specs), len(relaxed_arg_specs)))
    # Combine each current spec with the cached one; None entries stay None.
    relaxed_arg_specs = [
        x if x is None else x.most_specific_compatible_type(y)
        for (x, y) in zip(arg_specs, relaxed_arg_specs)]
  self._function_cache.arg_relaxed_specs[rank_only_cache_key] = (
      relaxed_arg_specs)
  relaxed_arg_shapes = [
      x if x is None else x.shape
      for x in nest.flatten(relaxed_arg_specs, expand_composites=True)]

  if any_composite_args:
    # Rebuild composite tensors with the relaxed TypeSpecs.  For example,
    # if a tf.data iterator is passed as an argument, then we need to relax
    # the TensorShapes in its element_spec.
    (relaxed_arg_specs, relaxed_kwarg_specs) = nest.pack_sequence_as(
        (args, kwargs), relaxed_arg_specs, expand_composites=False)
    (args, kwargs) = nest.pack_sequence_as(
        (relaxed_arg_specs, relaxed_kwarg_specs),
        flat_args,
        expand_composites=True)

  graph_function = self._create_graph_function(
      args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
  self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function

  # `args`/`kwargs` may have been rebuilt above, so the filtered inputs are
  # recomputed here.
  return (graph_function, [
      t for t in nest.flatten((args, kwargs), expand_composites=True)
      if isinstance(t, (ops.Tensor,
                        resource_variable_ops.BaseResourceVariable))
  ])
def _maybe_define_function(self, args, kwargs):
  """Gets a function for these inputs, defining it if necessary.

  `args` and `kwargs` can be None if this `Function` was created with an
  `input_signature`.

  Caller must hold self._lock.

  Args:
    args: The varargs for the Python function.
    kwargs: The keyword args for the Python function.

  Returns:
    A graph function corresponding to the input signature implied by args and
    kwargs, as well as filtered flattened inputs (only Tensors and Variables)
    that the object should be called with.

  Raises:
    ValueError: If inputs are incompatible with the input signature.
    TypeError: If the function inputs include non-hashable objects
    RuntimeError: If there's an internal bug (inconsistency) in handling
      shape relaxation retracing.
  """
  if self.input_signature is None or args is not None or kwargs is not None:
    args, kwargs, flat_args, filtered_flat_args = \
        self._function_spec.canonicalize_function_inputs(*args, **kwargs)
  else:
    # No inputs were supplied and an input signature is set, so there is
    # nothing to canonicalize.
    flat_args, filtered_flat_args = [None], []

  cache_key_context = self._cache_key_context()
  cache_key = self._cache_key(args, kwargs, cache_key_context)

  try:
    hash(cache_key)
  except TypeError as e:
    raise TypeError(
        "Arguments supplied to `defun`-generated functions must be"
        " hashable.  Original error: %s" % e)

  # Primary cache hit: reuse the function traced for this exact key.
  graph_function = self._function_cache.primary.get(cache_key, None)
  if graph_function is not None:
    return graph_function, filtered_flat_args

  with monitoring.MonitoredTimer(_graph_building_time_counter.get_cell()):
    with trace.Trace("tf.function-graph_building"):
      logging.vlog(1,
                   "Creating new FuncGraph for Python function %r (key: %r)",
                   self._python_function, cache_key)
      logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
                   args, kwargs)

      # The call context key is the cache key with its input signature
      # cleared.
      # pylint: disable=protected-access
      call_context_key = cache_key._replace(input_signature=None)
      # pylint: enable=protected-access

      ag_status = (
          ag_ctx.Status.ENABLED
          if self._autograph else ag_ctx.Status.DISABLED)
      with ag_ctx.ControlStatusCtx(
          status=ag_status, options=self._autograph_options):

        # Build a function with shape relaxation retracing if:
        # 1. shape relaxation is explicitly enabled
        # and 2. there's no provided input signature
        # and 3. there's been a cache miss for this calling context
        if (self._experimental_relax_shapes and
            self.input_signature is None and
            call_context_key in self._function_cache.missed):
          return self._define_function_with_shape_relaxation(
              args, kwargs, flat_args, filtered_flat_args, cache_key_context)

        self._function_cache.missed.add(call_context_key)
        graph_function = self._create_graph_function(args, kwargs)
        self._function_cache.primary[cache_key] = graph_function

        return graph_function, filtered_flat_args
def register(func, *args, **kwargs):
  """Register a specialization of a `Function` into the graph.

  The function is not called with the inputs; only its definition (and its
  gradient functions) are added to the graph. Registering with different
  inputs results in multiple versions of the function in the graph.

  Args:
    func: the `Function` instance that generated by a @defun
    *args: input arguments for the Python function.
    **kwargs: input keyword arguments for the Python function.

  Returns:
    a `ConcreteFunction` object specialized to inputs and execution context.

  Raises:
    ValueError: When the input function is not a defun wrapped python function.
  """
  if isinstance(func, Function):
    specialized = func.get_concrete_function(*args, **kwargs)
    specialized.add_to_graph()
    specialized.add_gradient_functions_to_graph()
    return specialized
  raise ValueError("Only defun function is allowed to be registered. "
                   "Got type: %s" % type(func))
def validate_signature(signature):
  """Checks that every flattened entry of `signature` is a `DenseSpec`.

  Args:
    signature: A possibly nested sequence of spec objects.

  Raises:
    TypeError: If any flattened entry is not a `tensor_spec.DenseSpec`.
  """
  flat_specs = nest.flatten(signature, expand_composites=True)
  for spec in flat_specs:
    if isinstance(spec, tensor_spec.DenseSpec):
      continue
    raise TypeError("Invalid input_signature {}; input_signature must be "
                    "a possibly nested sequence of TensorSpec objects."
                    .format(signature))
def defun(func=None,
          input_signature=None,
          autograph=True,
          experimental_autograph_options=None,
          experimental_relax_shapes=False):
  """Compiles a Python function into a callable TensorFlow graph.

  `defun` (short for "define function") compiles a Python function
  composed of TensorFlow operations into a callable that executes a `tf.Graph`
  containing those operations. The callable produced by `defun` contains only
  the subgraph of TensorFlow operations that were executed when the Python
  function was called with a particular input signature, defined as a list
  of the shapes and dtypes of the Python function's Tensor-valued arguments and
  the values of its non-Tensor Python objects.

  When eager execution is enabled, the ability to create graphs from Python
  functions makes it possible to incrementally trade off debuggability and
  interactivity for performance. Functions compiled with `defun` cannot be
  inspected with `pdb`; however, executing a graph
  generated by `defun` sometimes takes less time and memory than eagerly
  executing the corresponding Python function, since specifying computations as
  graphs allows for optimizations like automatic buffer reuse and
  parallelization among ops. Note that executing a `defun`-compiled function
  incurs a small constant overhead, so eagerly executing sufficiently small
  Python functions might take less time than executing their corresponding
  `defun`-generated graphs.

  For a Python function to be compatible with `defun`, all of its arguments must
  be hashable Python objects or lists thereof. The function itself may not
  modify the list/map structure of its arguments. Additionally, it must return
  zero or more `tf.Tensor` objects. If the Python function returns
  a `tf.Variable`, its compiled version will return the value of that variable
  as a `tf.Tensor`.

  Executing a graph generated by `defun` respects device annotations (i.e.,
  all `with tf.device` directives present in a Python function will also be
  present in its corresponding graph), but it is not yet possible to execute the
  generated graphs across multiple machines.

  _Example Usage_

  ```python
  import tensorflow as tf

  tf.compat.v1.enable_eager_execution()

  # A simple example.
  def f(x, y):
    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)

  g = tf.contrib.eager.defun(f)

  x = tf.constant([[2.0, 3.0]])
  y = tf.constant([[3.0, -2.0]])

  # `f` and `g` will return the same value, but `g` will be executed as a
  # TensorFlow graph.
  assert f(x, y).numpy() == g(x, y).numpy()

  # `defun` is capable of compiling Python functions that close over Python
  # objects, including Tensors and Variables.
  @tf.contrib.eager.defun
  def h():
    return f(x, y)

  assert (h().numpy() == f(x, y).numpy()).all()

  # `defun` automatically lifts variables out of the graphs it creates,
  # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and
  # `tf.keras.Model` objects.
  class MyModel(tf.keras.Model):

    def __init__(self, keep_probability=0.2):
      super(MyModel, self).__init__()
      self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
      self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
      self.keep_probability = keep_probability

    @tf.contrib.eager.defun
    def call(self, inputs, training=True):
      x = self.dense2(self.dense1(inputs))
      if training:
        return tf.nn.dropout(x, self.keep_probability)
      else:
        return x

  model = MyModel()
  model(x, training=True)  # executes a graph, with dropout
  model(x, training=False) # executes a graph, without dropout

  # `defun`-compiled functions are differentiable.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
  with tf.GradientTape() as tape:
    outputs = model(x)
  gradient = tape.gradient(outputs, model.trainable_variables)
  optimizer.apply_gradients((grad, var) for grad, var in zip(gradient,
                            model.trainable_variables))
  ```

  When using `defun`, there are subtleties regarding inputs, Python control
  flow, and variable creation that one should be aware of. For concreteness, let
  `f` be a Python function that returns zero or more `tf.Tensor` objects and
  let `F = defun(f)`. `F` builds a graph for each unique input signature it
  sees, Python control flow is baked into graphs, and operations related to
  variable initialization are automatically lifted out of the graphs that `F`
  generates and placed in the eager context if executing eagerly or into an
  outer graph otherwise.

  _Input Signatures_

  By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph
  for every unique sequence of the shapes and dtypes of Tensor arguments and
  the values of Python objects it is invoked with. For example, calling
  `F(tf.random.uniform([2]))` will execute a different graph than
  `F(tf.random.uniform([3]))` because the two inputs have different shapes.
  The first time that `F(*args, **kwargs)` is called with a particular sequence
  of Tensor shapes and dtypes and Python values, it constructs a graph by
  tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
  input signature inferred from `(*args, **kwargs)` and cached for future reuse.

  NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
  before being passed to `f`, and are treated as Tensors for caching. This
  allows a function to be called multiple times with NumPy arrays having
  different values but the same shape and dtype without re-tracing each time.

  `tf.contrib.eager.defun` caches graphs for your convenience, letting you
  define TensorFlow functions without explicitly specifying their signatures.
  However, this policy is conservative and potentially expensive; for example,
  when different invocations of your function have differently-shaped Tensor
  inputs, this policy might generate more graph functions than necessary. To
  eliminate such costs, `tf.contrib.eager.defun` allows you to supply an
  optional `input_signature` argument specifying the shapes and dtypes of the
  inputs. In particular, the shapes may be partially unspecified, with `None`s
  in the unknown dimensions. When an input signature is provided,
  `tf.contrib.eager.defun` will only instantiate a single graph for the
  decorated Python function. The following is an example:

  ```python
  import tensorflow as tf

  # The first `TensorSpec` below describes the shape and dtype of `words`,
  # and the second describes the shape and dtype of `another_tensor`. Note that
  # the last dimension of the `words` `TensorSpec` is left unspecified.
  @tf.contrib.eager.defun(input_signature=[
    tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32),
    tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32)
  ])
  def my_sequence_model(words, another_tensor):
    ...

  # Note how the third dimension of the first input can vary freely.
  words = tf.random.uniform([50, 300, 10])
  second_input = tf.random.uniform([300, 100])
  my_sequence_model(words, second_input)

  words = tf.random.uniform([50, 300, 20])
  my_sequence_model(words, second_input)

  # Passing an input with an incompatible shape will raise an error.
  words = tf.random.uniform([50, 100, 20])
  my_sequence_model(words, second_input)  # <---- This will raise an error.

  ```

  Python functions that are compiled with an `input_signature` must only accept
  Tensors as arguments and must not take unnamed keyword arguments (**kwargs).

  _Tracing_

  Be aware that because `F` only logs TensorFlow operations, all the other
  Python code that `f` executes will only shape the _construction_ of the graphs
  that `F` executes: the Python code won't be executed when the graphs
  themselves are executed, though it will be executed every time the Python
  function is traced (and a given Python function might be traced multiple
  times, once for each input signature it is invoked with). For example, whereas
  the Python function

  ```python
  import tensorflow as tf
  import numpy as np

  tf.compat.v1.enable_eager_execution()

  def add_noise():
    return tf.eye(5) + np.random.randn(5, 5)
  ```

  will return a different output every time it is invoked, the compiled function
  `compiled = tf.contrib.eager.defun(add_noise)` will return the same value
  every time it is called, since a particular random offset generated by NumPy
  will be inserted into the graph as a TensorFlow constant. The solution is to
  replace the call to `np.random.randn` with `tf.random.normal((5, 5))`.

  _Python Side-Effects_

  A corollary of the previous discussion on tracing is the following: If a
  Python function `f` has Python side-effects, then executing `f` multiple times
  will not necessarily be semantically equivalent to executing `F =
  tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
  that `defun` only captures the subgraph of TensorFlow operations that is
  constructed when `f` is called in a graph-building context.

  _Python Control Flow_

  The structure of many machine learning computations depends upon whether one
  is training or validating, and it is common to nest specialized logic under
  `if training:` blocks. By mapping each input signature to a unique graph,
  `defun` lets users transparently compile such code, as the following code
  snippet demonstrates:

  ```python
  import tensorflow as tf

  tf.compat.v1.enable_eager_execution()

  @tf.contrib.eager.defun
  def lossy_matmul(W, x, training=True):
    outputs = tf.matmul(W, x)
    if training:
      outputs = tf.nn.dropout(outputs, keep_probability=0.2)
    return outputs

  W = tf.random.normal((3, 5))
  x = tf.random.normal((5, 1))

  # Executes a graph that applies dropout.
  lossy_outputs = lossy_matmul(W, x, training=True)

  # Executes a graph that does not apply dropout.
  exact_outputs = lossy_matmul(W, x, training=False)
  ```

  _TensorFlow Control Flow_

  When `autograph` is `True`, data-dependent control flow is allowed as well.
  Control flow statements that depend on `Tensor` values are staged into
  corresponding TensorFlow ops. For example, the following code will work as
  expected:

  ```python
  @tf.contrib.eager.defun
  def dynamic_rnn_loop(cell, seq):
    state, output = cell.zero_state()
    for input in seq:
      state, output = cell(input, state)
    return output
  ```

  For more information see `tf.autograph`.

  _Variables_

  TensorFlow operations related to variable creation and initialization are
  automatically lifted out of the graphs generated by `defun`. In practice, this
  implies that variable creation and initialization only happen the first time
  `F` is called, and that variables are reused every time thereafter. Many
  TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the
  first time they are called and reuse them thereafter. Automatic variable
  lifting makes it possible to compile these APIs without extra effort, at the
  cost of introducing a discrepancy between the semantics of executing Python
  functions and their corresponding compiled functions. For example:

  ```python
  import tensorflow as tf

  tf.compat.v1.enable_eager_execution()

  def fn():
    x = tf.Variable(0.0)
    x.assign_add(1.0)
    return x.read_value()

  # `fn` is a Python function, so x is created, initialized, and destroyed upon
  # every invocation
  assert fn().numpy() == fn().numpy() == 1.0

  compiled = tf.contrib.eager.defun(fn)

  # Compiling `fn` with `defun` hoists all variables outside of the generated
  # graph, so initialization happens exactly once.
  assert compiled().numpy() == 1.0
  assert compiled().numpy() == 2.0
  ```

  Finally, because each input signature is bound to a unique graph, if your
  Python function constructs `tf.Variable` objects, then each graph constructed
  for that Python function will reference a unique set of variables. To
  circumvent this problem, we recommend against compiling Python functions that
  create `tf.Variable` objects. Instead, Python functions should either
  lexically close over `tf.Variable` objects or accept them as arguments,
  preferably encapsulated in an object-oriented container. If you must create
  variables inside your Python function and you want each graph generated for it
  to reference the same set of variables, add logic to your Python function that
  ensures that variables are only created the first time it is called and are
  reused for every subsequent invocation; note that this is precisely what
  `tf.keras.layers.Layer` objects do, so we recommend using them to represent
  variable-bearing computations whenever possible.

  Args:
    func: function to be compiled. If `func` is None, returns a
      decorator that can be invoked with a single argument - `func`. The
      end result is equivalent to providing all the arguments up front.
      In other words, defun(input_signature=...)(func) is equivalent to
      defun(func, input_signature=...). The former allows
      the following use case:
        @tf.contrib.eager.defun(input_signature=...)
        def foo(...):
          ...
    input_signature: A possibly nested sequence of
      `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of
      the Tensors that will be supplied to this function. If `None`, a separate
      function is instantiated for each inferred input signature.  If a
      signature is specified, every input to `func` must be a `Tensor`, and
      `func` cannot accept `**kwargs`.
    autograph: Whether `func` should be compiled before
      constructing the graph. See https://www.tensorflow.org/guide/autograph
      for more information.
    experimental_autograph_options: Experimental knobs (in the form of a tuple
      of tensorflow.autograph.Feature values) to control behavior when
      autograph=True.
    experimental_relax_shapes: When true, argument shapes may be relaxed to
      avoid unnecessary retracing.

  Returns:
     If `func` is not None, returns a callable that will execute the compiled
     function (and return zero or more `tf.Tensor` objects).
     If `func` is None, returns a decorator that, when invoked with a single
     `func` argument, returns a callable equivalent to the case above.

  Raises:
    TypeError: If `input_signature` is neither `None` nor a sequence of
      `tf.contrib.eager.TensorSpec` objects.
  """
  # Thin public wrapper: everything is forwarded to `defun_with_attributes`,
  # leaving that function's remaining parameters at their defaults.
  return defun_with_attributes(
      func=func,
      input_signature=input_signature,
      autograph=autograph,
      experimental_autograph_options=experimental_autograph_options,
      experimental_relax_shapes=experimental_relax_shapes)
def defun_with_attributes(func=None,
                          input_signature=None,
                          attributes=None,
                          autograph=True,
                          experimental_autograph_options=None,
                          experimental_compile=None,
                          experimental_relax_shapes=False,
                          experimental_follow_type_hints=False):
  """Compiles a Python function into a callable TensorFlow graph.

  This function supports adding extra function attributes. See detailed
  documentation in defun(). Currently this is not exposed in public API since we
  don't expect user to directly use attributes, and attribute won't work by
  itself. This assumption might change in future.

  Args:
    func: function to be compiled.
    input_signature: same as defun()'s input_signature.
    attributes: A dictionary of arguments which will be added to function def as
      attributes. Currently only support primitive types as value, and only
      allowlisted attribute name is allowed. Unallowlisted attribute name or
      unsupported value will result into ValueError. `func_name` is also one of
      the allowlisted argument which is a python string, and sets the name for
      this `ConcreteFunction` in the graph. The dictionary passed in is not
      modified; `func_name` is removed from a private copy.
    autograph: same as defun()'s autograph.
    experimental_autograph_options: same as defun()'s
      experimental_autograph_options.
    experimental_compile: forwarded unchanged to `Function` as
      `experimental_compile` (`defun()` itself does not expose this argument).
    experimental_relax_shapes: same as defun()'s experimental_relax_shapes
    experimental_follow_type_hints: see `tf.function`.

  Returns:
    Same as the return value of defun, with attributes added to the function in
    graph.

  Raises:
    TypeError: If `input_signature` is not None and is rejected by
      `validate_signature`.
  """
  if input_signature is not None:
    validate_signature(input_signature)

  # TODO(apassos): deal with captured global state. Deal with control flow.
  def decorated(function):
    # Pop `func_name` from a copy so that neither the caller's dictionary nor
    # later applications of this same decorator lose the requested name.
    function_attributes = dict(attributes) if attributes else attributes
    if function_attributes and "func_name" in function_attributes:
      name = function_attributes.pop("func_name")
    else:
      # Only consult `__name__` when no explicit name was given; callables
      # without one (e.g. `functools.partial` objects) fall back to "function".
      name = getattr(function, "__name__", "function")
    return tf_decorator.make_decorator(
        function,
        Function(
            function,
            name,
            input_signature=input_signature,
            attributes=function_attributes,
            autograph=autograph,
            autograph_options=experimental_autograph_options,
            experimental_compile=experimental_compile,
            experimental_relax_shapes=experimental_relax_shapes,
            experimental_follow_type_hints=experimental_follow_type_hints))

  # This code path is for the `foo = tfe.defun(foo, ...)` use case
  if func is not None:
    return decorated(func)

  # This code path is for the
  #
  # @tfe.defun(...)
  # def foo(...):
  #    ...
  #
  # use case, which is equivalent to `foo = tfe.defun(...)(foo)`
  return decorated
# When a method is bound to objects of this type, it allows AutoGraph to
# recover a weak reference to the original method's self pointer, so that it
# can execute it consistently with class_method_to_instance_method's
# bound_method_wrapper.
# TODO(b/119246461): This is not pretty. Use a descriptor instead?
class TfMethodTarget(object):
  """Binding target for methods replaced by function and defun.

  Holds a callable that yields the bound object (it is invoked with no
  arguments wherever the object is needed) together with a weak reference to
  the original Python function.
  """

  __slots__ = ("weakrefself_target__", "weakrefself_func__")

  def __init__(self, target, original_python_function):
    # `target` is stored as given and later *called* to obtain the object, so
    # it is expected to behave like a `weakref.ref`.
    self.weakrefself_target__ = target
    self.weakrefself_func__ = weakref.ref(original_python_function)

  @property
  def target(self):
    """The object obtained by dereferencing the stored target."""
    return self.weakrefself_target__()

  @property
  def target_class(self):
    """The target itself if it is a class, otherwise the target's class."""
    referent = self.weakrefself_target__()
    # A class referent means the wrapped callable is a class method.
    return referent if tf_inspect.isclass(referent) else referent.__class__

  def call(self, args, kwargs):
    """Invokes the original function with the target as its first argument."""
    fn = self.weakrefself_func__()
    if tf_inspect.ismethod(fn):
      fn = six.get_unbound_function(fn)
    receiver = self.weakrefself_target__()
    return fn(receiver, *args, **kwargs)
def class_method_to_instance_method(original_function, instance):
  """Constructs a new `Function` with `self` bound.

  Args:
    original_function: the function object whose `python_function` is to be
      bound. It must expose `_name`, `_autograph`, `_function_spec` and
      `python_function`.
    instance: the object to bind as `self`; only a weak reference to it is
      kept by the wrapper built here.

  Returns:
    The result of `tf_decorator.make_decorator` applied to
    `original_function.python_function` and a new object of
    `type(original_function)` wrapping the bound method.
  """
  instance_ref = weakref.ref(instance)

  # Note: while we could bind to a weakref proxy instead, that causes the
  # bound method to be unhashable.
  placeholder_method = types_lib.MethodType(
      original_function.python_function,
      TfMethodTarget(instance_ref, original_function.python_function))

  # original_function is expected to be of one of the two `Function` types
  # (defined either in function.py or def_function.py).
  assert hasattr(original_function, "_name")
  assert hasattr(original_function, "_autograph")
  assert hasattr(original_function, "_function_spec")
  assert hasattr(original_function, "python_function")

  # Filled in once the wrapper exists; the wrapper reaches itself through this
  # weak reference.
  wrapper_ref = None

  def bound_method_wrapper(*args, **kwargs):
    """Wraps either a dummy MethodType or a converted AutoGraph function."""
    # __wrapped__ allows AutoGraph to swap in a converted function.
    wrapper = wrapper_ref()
    replacement = wrapper.__wrapped__

    if replacement is not wrapper.__original_wrapped__:
      # If __wrapped__ was replaced, then it is always an unbound function.
      # However, the replacer is still responsible for attaching self properly.
      # TODO(mdan): Is it possible to do it here instead?
      return replacement(*args, **kwargs)

    # If __wrapped__ was not replaced, then call original_function.
    # TODO(mdan): For better consistency, use the wrapper's call().
    unbound = original_function.python_function
    if tf_inspect.ismethod(unbound):
      unbound = six.get_unbound_function(unbound)
    return unbound(instance_ref(), *args, **kwargs)

  wrapper_ref = weakref.ref(bound_method_wrapper)

  # pylint: disable=protected-access
  # We make a dummy MethodType object to generate the correct bound method
  # signature. The actual call is to a function with a weak reference to
  # `instance`.
  function_type = type(original_function)
  decorated_method = tf_decorator.make_decorator(placeholder_method,
                                                 bound_method_wrapper)
  instance_func = function_type(
      decorated_method,
      name=original_function._name,
      autograph=original_function._autograph,
      input_signature=original_function.input_signature,
      experimental_relax_shapes=original_function._experimental_relax_shapes,
      experimental_compile=original_function._experimental_compile)
  # pylint: enable=protected-access

  # And we wrap the function with tf_decorator so inspection works correctly
  return tf_decorator.make_decorator(original_function.python_function,
                                     instance_func)
class _FunctionGarbageCollector(object):
  """Cleans up cycles when a defun goes out of scope.

  On deletion, the cache handed to the constructor is emptied and then passed
  to `memory.dismantle_ordered_dict`.
  """

  __slots__ = ["_cache"]

  def __init__(self, cache):
    self._cache = cache

  def __del__(self):
    # Nothing can be done once the helper modules have been set to None
    # (presumably during interpreter shutdown -- TODO confirm).
    modules_gone = func_graph_module is None or memory is None
    if modules_gone:
      return
    try:
      cache = self._cache
      while cache:
        cache.popitem()
      memory.dismantle_ordered_dict(cache)
    except:  # pylint: disable=bare-except
      # Best-effort cleanup: errors raised while finalizing are ignored.
      pass
class ConcreteFunctionGarbageCollector(object):
  """Cleans up reference cycles when a `ConcreteFunction` goes out of scope.

  On deletion, the held graph is passed to
  `func_graph_module.dismantle_func_graph`, unless `release()` was called
  first.
  """

  __slots__ = ["_func_graph"]

  def __init__(self, func_graph):
    self._func_graph = func_graph

  def release(self):
    """Call off the FuncGraph deletion."""
    self._func_graph = None

  def __del__(self):
    # Skip when the helper modules have been set to None (presumably during
    # interpreter shutdown -- TODO confirm).
    if func_graph_module is None or memory is None:
      return
    # Skip when the deletion was called off via `release()`.
    if self._func_graph is None:
      return
    try:
      func_graph_module.dismantle_func_graph(self._func_graph)
    except:  # pylint: disable=bare-except
      # Best-effort cleanup: errors raised while finalizing are ignored.
      pass
class _Marker(object):
  """Markers used to pretty-print nested args in function signatures.

  The `repr` of a marker is the `str` of the wrapped value, so a marker
  holding a string prints without surrounding quotes inside containers.
  """

  __slots__ = ["_s"]

  def __init__(self, s):
    self._s = s

  def __repr__(self):
    wrapped = self._s
    return str(wrapped)
def _structure_summary(structure):
  """Displays a summary of the nesting structure of the given value.

  Every leaf returned by `nest.flatten(structure)` is replaced by a `_Marker`
  carrying a type name, and the string form of the re-packed structure is
  returned.

  Args:
    structure: the (possibly nested) value to summarize.

  Returns:
    A string with the same nesting as `structure` and type names as leaves.
  """

  def type_name(x):
    # A TypeSpec is reported as the type of value it describes.
    described = x.value_type if isinstance(x, type_spec.TypeSpec) else type(x)
    return described.__name__

  markers = []
  for leaf in nest.flatten(structure):
    markers.append(_Marker(type_name(leaf)))
  packed = nest.pack_sequence_as(structure, markers)
  return str(packed)
def _contains_type_spec(value):
  """Returns True if any leaf of `nest.flatten(value)` is a `TypeSpec`."""
  for leaf in nest.flatten(value):
    if isinstance(leaf, type_spec.TypeSpec):
      return True
  return False