| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| # pylint: disable=unidiomatic-typecheck |
| """Defun decorator for defining graph-mode functions.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import enum # pylint: disable=g-bad-import-order |
| import functools |
| import itertools |
| import threading |
| import types as types_lib |
| import weakref |
| |
| import numpy as np |
| import six |
| |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.core.framework import function_pb2 |
| from tensorflow.python import pywrap_tensorflow |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import execute |
| from tensorflow.python.eager import tape |
| from tensorflow.python.eager.graph_only_ops import graph_placeholder |
| from tensorflow.python.framework import c_api_util |
| from tensorflow.python.framework import composite_tensor |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import device as pydev |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import error_interpolation |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import func_graph as func_graph_module |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import default_gradient |
| from tensorflow.python.ops import functional_ops |
| from tensorflow.python.ops import gradients_util |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import function_utils |
| from tensorflow.python.util import lazy_loader |
| from tensorflow.python.util import memory |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import object_identity |
| from tensorflow.python.util import tf_decorator |
| from tensorflow.python.util import tf_inspect |
| |
| # Loaded lazily due to a circular dependency (roughly |
# tf.function->autograph->dataset->tf.function).
| # TODO(b/133251390): Use a regular import. |
| ag_ctx = lazy_loader.LazyLoader( |
| "ag_ctx", globals(), |
| "tensorflow.python.autograph.core.ag_ctx") |
| |
| |
| FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name" |
| BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name" |
| |
| |
class CacheKey(
    collections.namedtuple("CacheKey", [
        "input_signature", "parent_graph", "device_functions",
        "colocation_stack", "in_cross_replica_context"
    ])):
  """Named tuple used to key the function cache."""

  def __hash__(self):
    """Provide a hash even if the input signature objects aren't hashable."""
    return hash((self._hash_fix(self.input_signature), self.parent_graph,
                 self.device_functions, self.colocation_stack,
                 self.in_cross_replica_context))

  def _hash_fix(self, elem):
    """Returns a hashable stand-in for `elem`, recursing into containers."""
    # Descend into tuples so nested unhashable entries are also fixed up.
    if isinstance(elem, tuple):
      return tuple(self._hash_fix(i) for i in elem)

    # A plain set is itself unhashable; convert to a frozenset after fixing
    # each member. (Returning a mutable set here would make the enclosing
    # hash() call in __hash__ fail with a TypeError.)
    if isinstance(elem, set):
      return frozenset(self._hash_fix(i) for i in elem)

    # If the element is not hashable, assume it is a weakref to a variable and
    # return the dtype & shape. Else, simply return the element
    try:
      hash(elem)
    except TypeError:
      v = elem()
      return (v.__class__, tensor_spec.TensorSpec(v.shape, v.dtype))

    return elem
| |
| |
# Alias the protected namedtuple `_replace` under a public name so callers can
# derive modified cache keys without touching a protected member.
CacheKey.replace = CacheKey._replace  # pylint: disable=protected-access
| |
| |
def _flat_shape_list(*params):
  """Return a flat list of TensorShapes, one for each tensor[spec] in `*params`.

  If `params` contains `CompositeTensors`, then they are expanded to their
  components `Tensors`.

  Args:
    *params: Set of nested entries containing Tensors, TensorSpec, and
      non-tensors.

  Returns:
    A list of entries containing either `None` or `TensorShape`.
  """
  shapes = []
  for entry in nest.flatten(params, expand_composites=True):
    if isinstance(entry, (ops.Tensor, tensor_spec.TensorSpec)):
      shapes.append(tensor_shape.TensorShape(entry.shape))
    else:
      shapes.append(None)
  return shapes
| |
| |
| def _shape_less_specific_than(relaxed, to_check): |
| """Checks if `relaxed` is less specific than `to_check`. |
| |
| This is an asymmetric check, unlike `TensorShape.is_compatible_with`. If |
| `to_check` has a dimension with an undefined shape, `relaxed` must also have |
| an undefined shape for that dimension. |
| |
| Args: |
| relaxed: A `TensorShape` to check against. |
| to_check: A second `TensorShape`. |
| |
| Returns: |
| True if `to_check` represents a set of shapes which is a subset of |
| `relaxed`'s shapes and False otherwise. |
| """ |
| if to_check.dims is not None and relaxed.dims is not None: |
| if to_check.rank != relaxed.rank: |
| return False |
| for check_dim, relaxed_dim in zip(to_check.dims, relaxed.dims): |
| if check_dim.value is None and relaxed_dim.value is not None: |
| return False |
| if not relaxed_dim.is_compatible_with(check_dim): |
| return False |
| return True |
| |
| |
| def _compatible_shapes(flat_relaxed, flat_to_check): |
| """Check if lists of TensorShapes contain compatible shapes. |
| |
| Checks that each `flat_relaxed` shape covers a superset of the shapes of the |
| corresponding `flat_to_check` shape. |
| |
| Args: |
| flat_relaxed: List of TensorShape or None. |
| flat_to_check: List of TensorShape or None. |
| |
| Returns: |
| A python bool. |
| |
| Raises: |
| RuntimeError: |
| if `len(flat_relaxed) != len(flat_to_check)`. |
| RuntimeError: |
| if `flat_relaxed[i] is None != flat_to_check[i] is None` for any `i`. |
| """ |
| |
| if len(flat_relaxed) != len(flat_to_check): |
| raise RuntimeError("Expected shape lists of identical lengths, but saw: " |
| "%s and %s" % (flat_relaxed, flat_to_check)) |
| def is_compatible(relaxed, to_check): |
| """Internal help function. |
| |
| Args: |
| relaxed: TensorShape or None. |
| to_check: TensorShape or None. |
| |
| Returns: |
| Python bool. |
| |
| Raises: |
| RuntimeError: If `relaxed is None != to_check is None`. |
| """ |
| # If both x and y are None, there is no shape to compare. Otherwise check |
| # if they are compatible with each other. Either way, both input signatures |
| # must have have Tensors in the same entries. If not, raise an assertion |
| # error. |
| if relaxed is None != to_check is None: |
| raise RuntimeError( |
| "Expected signature type matches between flattened input shapes " |
| "%s and %s; but saw that (%s is None) != (%s is None)" |
| % (flat_relaxed, flat_to_check, relaxed, to_check)) |
| return relaxed is None or _shape_less_specific_than(relaxed, to_check) |
| return all(is_compatible(relaxed, to_check) |
| for relaxed, to_check in zip(flat_relaxed, flat_to_check)) |
| |
| |
def common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: A TensorShape or None (None means "the input was not a Tensor").
    y: A second TensorShape or None.

  Returns:
    None if both inputs are None, otherwise a TensorShape keeping only the
    dimensions that are fully known and identical in both inputs.

  Raises:
    RuntimeError: if exactly one of `x`/`y` is None.
    TypeError: if a non-None input is not a TensorShape.
  """
  # NOTE: the parentheses are required. Without them Python parses this as a
  # chained comparison that is always False, silently disabling the check.
  if (x is None) != (y is None):
    raise RuntimeError(
        "Cannot find a common shape when LHS shape is None but RHS shape "
        "is not (or vice versa): %s vs. %s" % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tensor_shape.TensorShape):
    raise TypeError("Expected x to be a TensorShape but saw %s" % (x,))
  if not isinstance(y, tensor_shape.TensorShape):
    raise TypeError("Expected y to be a TensorShape but saw %s" % (y,))
  if x.rank != y.rank or x.rank is None:
    return tensor_shape.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    # Keep a dimension only when it is fully known and equal on both sides;
    # otherwise relax it to unknown (None).
    if (dim_x != dim_y
        or tensor_shape.dimension_value(dim_x) is None
        or tensor_shape.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tensor_shape.dimension_value(dim_x))
  return tensor_shape.TensorShape(dims)
| |
| |
def is_same_structure(structure1,
                      structure2,
                      check_values=False):
  """Check two structures for equality, optionally of types and of values."""
  try:
    nest.assert_same_structure(structure1, structure2, expand_composites=True)
  except (ValueError, TypeError):
    return False
  if not check_values:
    return True
  flat1 = nest.flatten(structure1, expand_composites=True)
  flat2 = nest.flatten(structure2, expand_composites=True)
  # Compare types first so the equality check below can't hit
  # AttributeErrors on mismatched leaf types.
  for item1, item2 in zip(flat1, flat2):
    if type(item1) != type(item2):
      return False
  return flat1 == flat2
| |
| |
def _parse_func_attrs(attributes):
  """Convert the keyword arguments into function_def attributes.

  Currently only support primitive types: bool, int, float and string.

  Args:
    attributes: the dictionary of attributes.
  Returns:
    A dict of attributes where the key is the name of attribute and the value
    is the AttrValue proto.
  Raises:
    ValueError: If the kwargs contains unwhitelisted name or unsupported value
      types.
  """
  def _to_attr_value(key, value):
    """Wraps one primitive python value into an AttrValue proto."""
    if isinstance(value, attr_value_pb2.AttrValue):
      return value
    # bool type check has to happen before int since bool is a subclass of int.
    if isinstance(value, bool):
      return attr_value_pb2.AttrValue(b=value)
    if isinstance(value, int):
      return attr_value_pb2.AttrValue(i=value)
    if isinstance(value, float):
      return attr_value_pb2.AttrValue(f=value)
    if isinstance(value, (str, bytes, six.text_type)):
      return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
    raise ValueError("Unsupported attribute type for %s with type %s" %
                     (key, type(value)))

  return {key: _to_attr_value(key, value)
          for key, value in attributes.items()}
| |
| |
class _InterpolateFunctionError(object):
  """Context Manager that interpolates the exception from 'top_level_func'."""

  def __init__(self, top_level_func):
    self._func = top_level_func

  def __enter__(self):
    pass

  def __exit__(self, typ, exc, tb):
    # Only rewrite TensorFlow op errors; let everything else propagate as-is.
    if not exc or not isinstance(exc, errors.OpError):
      return False
    message = compat.as_text(exc.message)
    _, tags = error_interpolation.parse_message(message)
    g = None
    func_stack = []
    # Walk the function_node tags to reconstruct the call stack and locate
    # the innermost graph we can interpolate the message against.
    for tag in tags:
      if tag.type != "function_node":
        continue
      # TODO(mdan): Tests should cover this.
      if tag.name == compat.as_str(self._func.name):
        g = self._func.graph
      elif g:
        next_func = g._get_function(tag.name)
        # isinstance is False for None, so no separate None check is needed.
        if isinstance(next_func, _EagerDefinedFunction):
          g = next_func.graph
      func_stack.append(g.name if g else "<unknown>")
    if g:
      message = error_interpolation.interpolate(message, g)
      message += "\n\nFunction call stack:\n"
      message += " -> ".join(func_stack)
      message += "\n"
    exc._message = message  # pylint: disable=protected-access
    return False
| |
| |
def _forward_name(n):
  """Returns the name for a generated forward defun named n."""
  return "__forward_{}_{}".format(n, ops.uid())
| |
| |
def _backward_name(n):
  """Returns the name for a generated backward defun named n."""
  return "__backward_{}_{}".format(n, ops.uid())
| |
| |
def _inference_name(n):
  """Returns the name for a forward-but-no-gradient defun named n."""
  return "__inference_{}_{}".format(n, ops.uid())
| |
| |
| class _EagerDefinedFunctionDeleter(object): |
| """Unregister function from eager context.""" |
| |
| def __init__(self, name): |
| self.name = name |
| |
| def __del__(self): |
| try: |
| context.remove_function(self.name) |
| except TypeError: |
| # Suppress some exceptions, mainly for the case when we're running on |
| # module deletion. Things that can go wrong include the context module |
| # already being unloaded, self._handle._handle_data no longer being |
| # valid, and so on. Printing warnings in these cases is silly |
| # (exceptions raised from __del__ are printed as warnings to stderr). |
| pass # 'NoneType' object is not callable when the handle has been |
| # partially unloaded. |
| except AttributeError: |
| pass # 'NoneType' object has no attribute 'eager_mode' when context has |
| # been unloaded. Will catch other module unloads as well. |
| |
| |
| # TODO(apassos) get rid of this by splitting framework.function._DefinedFunction |
| # so it doesn't have the definition-generating logic and is just a container for |
| # an already-defined function. |
class _EagerDefinedFunction(object):
  """Callable with the interface of `framework.function._DefinedFunction`.

  `_EagerDefinedFunction` encapsulates a function definition and its properties,
  and it provides a method for calling the encapsulated function. Some Ops
  take functions as attributes, which have type `func`; an instance of this
  class may be provided as the value of these `func` attributes.
  """

  def __init__(self, name, graph, inputs, outputs, attrs):
    """Initializes an eager defined function.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    # The input placeholders are arguments, not part of the function body;
    # every other op in the graph is included.
    input_ops = set(arg.op for arg in inputs)
    operations = [op for op in graph.get_operations() if op not in input_ops]

    # Use user-specified output names only when the graph provides a complete,
    # duplicate-free mapping for `outputs`; otherwise fall back to auto-naming.
    graph_output_names = graph._output_names  # pylint: disable=protected-access
    if (graph_output_names is not None and
        all(ops.tensor_id(t) in graph_output_names for t in outputs)):
      output_names = [
          compat.as_bytes(graph_output_names[ops.tensor_id(t)]) for t in outputs
      ]
      if len(set(output_names)) != len(output_names):
        # There are duplicate names for some reason, probably an invalid
        # signature. Revert to auto-naming.
        output_names = []
    else:
      output_names = []
    # Build the C-level TF_Function object from the graph's operations.
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        output_names,
        [o._c_op for o in graph.control_outputs],  # pylint: disable=protected-access
        [],  # control_output_names
        None,
        compat.as_str(""))

    # Attach each attribute to the C function as a serialized AttrValue.
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(name), serialized)

    # TODO(apassos) avoid creating a FunctionDef (especially to grab the
    # signature), but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    # The C API may uniquify the requested name; keep the actual one.
    self.name = compat.as_bytes(function_def.signature.name)
    # Register with the outermost eager context (if any) so the function can
    # be invoked by name; the deleter unregisters it on garbage collection.
    with ops.init_scope():
      if context.executing_eagerly():
        context.ensure_initialized()
        context.add_function(fn)
        self._function_deleter = _EagerDefinedFunctionDeleter(self.name)
        self._registered_on_context = True
    self.definition = function_def
    self.signature = function_def.signature
    self._num_outputs = len(self.signature.output_arg)
    self._output_types = [o.type for o in self.signature.output_arg]
    self._output_shapes = [o.shape for o in outputs]
    self._control_captures = graph.control_captures
    # Shallow copy outputs since ConcreteFunction may mutate it.
    self._func_graph_outputs = list(outputs)
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = c_api_util.ScopedTFFunction(fn)
    self._grad_func = None
    self.graph = graph
    self._stateful_ops = tuple(op for op in operations if op._is_stateful)  # pylint: disable=protected-access

  def add_to_graph(self, g=None):
    """Adds this function (and its dependencies) to graph `g` or the context.

    Args:
      g: A Graph to register the function into, or None to register with the
        eager context instead.
    """
    # pylint: disable=protected-access
    if not g and context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      # NOTE(review): this branch dereferences `g`, so calling with g=None
      # while building a graph would fail; callers appear to always pass a
      # graph in that case -- confirm before relying on the default.
      if self.name not in g._functions:
        g._add_function(self)
      # Also register any functions already registered on this function's
      # graph, since they may be called transitively.
      for f in self.graph._functions.values():
        if f.name not in g._functions:
          g._add_function(f)
    # pylint: enable=protected-access

  @property
  def stateful_ops(self):
    # Ops from the function body flagged stateful at construction time.
    return self._stateful_ops

  def call(self, ctx, args, cancellation_manager=None):
    """Calls this function with `args` as inputs.

    `ConcreteFunction` execution respects device annotations only if the
    function won't be compiled with xla.

    Args:
      ctx: a Context object
      args: a list of arguments to supply this function with.
      cancellation_manager: a `CancellationManager` object that can be used to
        cancel function execution.

    Returns:
      The outputs of the function call.

    Raises:
      ValueError: if the number of arguments is incorrect.
    """
    if len(args) != len(self.signature.input_arg):
      raise ValueError(
          "Arguments and signature arguments do not match. "
          "got: %s, expected: %s " %
          (len(args), len(list(self.signature.input_arg))))

    # Fall back to a rewriter-disabled config when the caller did not provide
    # a serialized ConfigProto.
    function_call_options = ctx.function_call_options
    if function_call_options.config_proto_serialized is None:
      config = function_utils.get_disabled_rewriter_config()
    else:
      config = function_call_options.config_proto_serialized
    executor_type = function_call_options.executor_type or ""

    executing_eagerly = ctx.executing_eagerly()
    if executing_eagerly:
      with _InterpolateFunctionError(self):
        if cancellation_manager is None:
          outputs = execute.execute(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=("executor_type", executor_type, "config_proto", config),
              ctx=ctx)
        else:
          outputs = execute.execute_with_cancellation(
              str(self.signature.name),
              num_outputs=self._num_outputs,
              inputs=args,
              attrs=("executor_type", executor_type, "config_proto", config),
              ctx=ctx,
              cancellation_manager=cancellation_manager)
      # Replace empty list with None
      outputs = outputs or None
    else:
      # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
      # creates `PartitionedCallOp` kernels by default, or remove the previous
      # branch if a TPU kernel is registered for `PartitionedCall`.
      with _InterpolateFunctionError(self):
        with ops.control_dependencies(self._control_captures):
          # The caller must use record_operation to record this operation in the
          # eager case, so we enforce the same requirement for the non-eager
          # case by explicitly pausing recording. We don't have a gradient
          # registered for PartitionedCall, so recording this operation confuses
          # forwardprop code (GradientTape manages to ignore it).
          with tape.stop_recording():
            outputs = functional_ops.partitioned_call(
                args=args,
                f=self,
                tout=self._output_types,
                executing_eagerly=executing_eagerly,
                config=config,
                executor_type=executor_type)

    if executing_eagerly:
      return outputs
    else:
      # TODO(b/128924522): This additional set_shape should not be
      # necessary. ShapeRefiner likely needs to inspect handle_data. Remove this
      # once that's done.
      for i, shape in enumerate(self._output_shapes):
        outputs[i].set_shape(shape)
      for i, func_graph_output in enumerate(self._func_graph_outputs):
        custom_gradient.copy_handle_data(func_graph_output, outputs[i])
      return outputs
| |
class _DelayedRewriteGradientFunctions(object):
  """Caches forward/backward functions with a delayed forward rewrite."""

  def __init__(self, func_graph, attrs, func_graph_deleter):
    """Construct an inference function and initialize caches.

    Args:
      func_graph: The FuncGraph whose inputs/outputs define the function.
      attrs: dict mapping attribute names to AttrValue protos.
      func_graph_deleter: Object held here so `func_graph`'s cleanup is tied
        to this cache's lifetime.
    """
    # A map from the number of forward function outputs with accepted gradients
    # to forward and backward functions, used to cache non-tape backward
    # function generation.
    self._cached_function_pairs = {}
    self._func_graph = func_graph
    self._inference_function = _EagerDefinedFunction(
        _inference_name(self._func_graph.name), self._func_graph,
        self._func_graph.inputs, self._func_graph.outputs, attrs)
    self._attrs = attrs
    self._gradient_name = None
    # Note that the FuncGraph is mutated later, so we need to inspect it now to
    # figure out the user-specified outputs of the inference function.
    self._num_inference_outputs = len(self._func_graph.outputs)
    self._func_graph_deleter = func_graph_deleter

  def forward_backward(self, num_doutputs=None):
    """A possibly-cached pair of forward and backward functions.

    Args:
      num_doutputs: Number of forward-function outputs the backward function
        accepts gradients for; defaults to the number of inference outputs.

    Returns:
      A `(forward_function, backward_function)` pair.
    """
    if num_doutputs is None:
      num_doutputs = self._num_inference_outputs
    forward_backward = self._cached_function_pairs.get(num_doutputs)
    if forward_backward is not None:
      return forward_backward
    forward, backward = self._construct_forward_backward(num_doutputs)
    self._cached_function_pairs[num_doutputs] = (forward, backward)
    return forward, backward

  def _construct_forward_backward(self, num_doutputs):
    """Constructs a pair of forward and backward functions.

    Args:
      num_doutputs: The constructed backprop function will take output gradients
        for the first `num_doutputs` outputs of the forward function. Defaults
        to the number of outputs for the inference function, but when
        higher-order gradients are computed this will increase to include side
        outputs.

    Returns:
      A pair of (forward_function, backward_function):
        forward_function: A re-generated inference function (an
          _EagerDefinedFunction) to account for new side outputs, if any extra
          were required when building the backward pass.
        backward_function: A ConcreteFunction that takes `num_doutputs`
          arguments and returns gradients with respect to inputs of the forward
          function.
    """
    # Only trainable outputs receive incoming gradients.
    trainable_outputs = [
        output for output in self._func_graph.outputs[:num_doutputs]
        if gradients_util.IsTrainable(output)]

    # One gradient placeholder per trainable output, matching its
    # gradient shape/dtype.
    signature = []
    for t in trainable_outputs:
      signature.append(
          tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t)))

    def _backprop_function(*grad_ys):
      # Symbolic gradients of the forward graph's inputs w.r.t. its
      # trainable outputs, given output gradients `grad_ys`.
      return gradients_util._GradientsHelper(  # pylint: disable=protected-access
          trainable_outputs,
          self._func_graph.inputs,
          grad_ys=grad_ys,
          src_graph=self._func_graph)

    # Trace the backprop function inside the forward graph's scope so captures
    # from the forward graph resolve correctly.
    with self._func_graph.as_default():
      backwards_graph = func_graph_module.FuncGraph(
          _backward_name(self._func_graph.name))
      func_graph_module.func_graph_from_py_func(
          name=backwards_graph.name,
          python_func=_backprop_function,
          args=[], kwargs={},
          signature=signature,
          func_graph=backwards_graph)
      backwards_graph_captures = backwards_graph.external_captures
      # Captures that come from the forward graph (rather than e.g. constants)
      # must become forward-function outputs so the backward can consume them.
      captures_from_forward = [
          c for c in backwards_graph_captures if
          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]

    forward_function_name = _forward_name(self._func_graph.name)

    # Append any newly required side outputs to the forward graph, avoiding
    # duplicates.
    existing_outputs = object_identity.ObjectIdentitySet(
        self._func_graph.outputs)
    for capture in captures_from_forward:
      if capture not in existing_outputs:
        existing_outputs.add(capture)
        self._func_graph.outputs.append(capture)
    # Cross-link the pair through function attributes so each can find the
    # other by name.
    backward_function_attr = _parse_func_attrs(
        {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
    backward_function_attr.update(self._attrs)

    backward_function = ConcreteFunction(
        backwards_graph, attrs=backward_function_attr)
    forward_function_attr = _parse_func_attrs({
        BACKWARD_FUNCTION_ATTRIBUTE_NAME:
        backward_function.name})
    forward_function_attr.update(self._attrs)

    # Re-generate the forward function now that the graph may carry extra
    # side outputs.
    forward_function = _EagerDefinedFunction(
        forward_function_name, self._func_graph, self._func_graph.inputs,
        self._func_graph.outputs, forward_function_attr)
    return forward_function, backward_function

  def _rewrite_forward_and_call_backward(self, op, *doutputs):
    """Add outputs to the forward call and feed them to the grad function.

    Args:
      op: The forward (inference) call operation to rewrite in place.
      *doutputs: Gradients with respect to `op`'s outputs; entries may be
        None for outputs with no incoming gradient.

    Returns:
      Gradients with respect to the forward function's inputs.
    """
    forward_function, backwards_function = self.forward_backward(len(doutputs))
    if not backwards_function.outputs:
      # No trainable path: nothing to compute.
      return []
    forward_function.add_to_graph(op.graph)

    # pylint: disable=protected-access
    # Rewrite an inference call op to be a forward call op
    op._set_func_attr("f", forward_function.name)
    op._set_type_list_attr("Tout", forward_function._output_types)
    op._add_outputs(
        forward_function._output_types[len(op.outputs):],
        forward_function._output_shapes[len(op.outputs):])
    for i in range(len(op.outputs)):
      func_graph_output = forward_function._func_graph_outputs[i]
      custom_gradient.copy_handle_data(func_graph_output, op.outputs[i])
    # pylint: enable=protected-access

    # Map forward-graph tensors captured by the backward function to the
    # corresponding outputs of the (rewritten) call op.
    capture_mapping = dict(
        zip([ops.tensor_id(t) for t in self._func_graph.outputs], op.outputs))
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in backwards_function.captured_inputs
    ]

    # Replace Nones with zeros since we're calling a graph function which
    # expects numeric inputs.
    cleaned_doutputs = []
    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
      if gradients_util.IsTrainable(placeholder):
        if doutput is not None:
          cleaned_doutputs.append(doutput)
        else:
          cleaned_doutputs.append(default_gradient.zeros_like(placeholder))

    # Compute the gradients using the side outputs
    return backwards_function._call_flat(  # pylint: disable=protected-access
        cleaned_doutputs, remapped_captures)

  def register(self):
    """Registers a delayed-rewrite gradient with a unique name (idempotent).

    The gradient rewrites an inference call op to a forward call op, but does
    not modify a pre-existing forward call op. It then computes the gradient
    from the output's gradients and the side outputs of the forward op.

    Returns:
      The name under which gradient was registered.
    """
    if self._gradient_name:
      # Already registered; reuse the existing name.
      return self._gradient_name
    self._gradient_name = "PartitionedCall-%s" % ops.uid()

    @ops.RegisterGradient(self._gradient_name)
    def _registered_grad_fn(op, *doutputs):  # pylint: disable=unused-variable
      return self._rewrite_forward_and_call_backward(op, *doutputs)
    return self._gradient_name

  @property
  def forward(self):
    """A forward function with only user-specified outputs.

    The call operation for the returned inference function can be rewritten into
    a forward function. This only happens if the backward function (from the
    `backward` method) ends up being used to compute gradients.

    This approach avoids constructing unnecessary graphs, but it only works if
    we are calling this function when not executing eagerly.

    Returns:
      An _EagerDefinedFunction.
    """
    return self._inference_function

  def backward(self, outputs):
    """Fetch a backward function for `outputs` from the forward function."""
    def _backward_function(*args):
      # Rewrites the forward call op lazily, only when gradients are
      # actually requested.
      call_op = outputs[0].op
      return self._rewrite_forward_and_call_backward(call_op, *args)
    return _backward_function, outputs
| |
| |
| class _TapeGradientFunctions(object): |
| """Caches forward and backward functions compatible with eager gradients. |
| |
| In contrast to the delayed-rewrite approach in |
| `_DelayedRewriteGradientFunctions` which only works with delayed execution, |
| the forward function generated by this class has a fixed set of outputs which |
| may be preserved by a tape in order to compute gradients later. |
| |
| This class is abstract; its child classes differ in how many side outputs of |
| the forward function their backward function accepts gradients for, which |
| determines whether higher-order tape gradients are possible. |
| """ |
| |
| def __init__(self, func_graph, attrs, func_graph_deleter): |
| self._func_graph = func_graph |
| self._attrs = attrs |
| self._forward = None |
| self._backward = None |
| self._num_outputs = len(func_graph.outputs) |
| self._func_graph_deleter = func_graph_deleter |
| |
| def _build_functions_for_outputs(self, outputs): |
| """Forward+backward functions where the backward function sees `outputs`.""" |
| # First figure out which of `outputs` are trainable. We'll accept gradients |
| # for each of these in the backward function. |
| handles_to_variables = self._func_graph.variable_captures |
| trainable_outputs = [] |
| for output in outputs: |
| if gradients_util.IsTrainable(output): |
| # Swap in the Variable object for resource handles if we can so |
| # sparse gradients work. |
| output = handles_to_variables.get(ops.tensor_id(output), output) |
| trainable_outputs.append(output) |
| |
| backwards_graph = func_graph_module.FuncGraph( |
| _backward_name(self._func_graph.name)) |
| # Keep track of the forward graph so that if the backwards graph |
| # tries to capture tensors those will be correctly captured first in |
| # the forward graph. This is an edge case that can only happen with |
| # tf.custom_gradient. |
| backwards_graph._forward_func_graph = self._func_graph # pylint: disable=protected-access |
| with backwards_graph.as_default(): |
| gradients_wrt_outputs = [] |
| for output in trainable_outputs: |
| gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( |
| output) |
| gradients_wrt_outputs.append( |
| graph_placeholder(gradient_dtype, gradient_shape)) |
| gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access |
| trainable_outputs, |
| self._func_graph.inputs, |
| grad_ys=gradients_wrt_outputs, |
| src_graph=self._func_graph) |
| |
| captures_from_forward = [ |
| c for c in backwards_graph.external_captures |
| if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph |
| ] |
| existing_outputs = object_identity.ObjectIdentitySet( |
| self._func_graph.outputs) |
| for capture in captures_from_forward: |
| if capture not in existing_outputs: |
| existing_outputs.add(capture) |
| self._func_graph.outputs.append(capture) |
| |
| forward_function_name = _forward_name(self._func_graph.name) |
| backward_function_attr = _parse_func_attrs( |
| {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name}) |
| backward_function_attr.update(self._attrs) |
| |
| # The ordering of `backwards_graph.inputs` is important: inputs of |
| # `backward_function` correspond to outputs (including |
| # side outputs) of `self._tape_forward_function`. |
| backwards_graph.inputs = ( |
| gradients_wrt_outputs + backwards_graph.internal_captures) |
| backwards_graph.outputs.extend( |
| grad |
| for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True) |
| if grad is not None) |
| backwards_graph.structured_outputs = gradients_wrt_inputs |
| backward_function = ConcreteFunction( |
| backwards_graph, attrs=backward_function_attr) |
| |
| forward_function_attr = _parse_func_attrs({ |
| BACKWARD_FUNCTION_ATTRIBUTE_NAME: |
| backward_function.name}) |
| forward_function_attr.update(self._attrs) |
| |
| forward_function = _EagerDefinedFunction( |
| forward_function_name, self._func_graph, self._func_graph.inputs, |
| self._func_graph.outputs, |
| forward_function_attr) |
| return forward_function, backward_function |
| |
| @property |
| def forward(self): |
| """Construct or fetch a forward function with side-outputs. |
| |
| When graph building without a tape active, symbolic gradients rely on |
| regenerating the backward function for higher-order gradients (to account |
| for new side outputs of the rewritten forward function call). Thus there is |
| no fixed backward function for this case. However, when a tape is active |
| (eager or graph building), we generate fixed backward and forward functions |
| at forward function call time. |
| |
| This difference between the tape and non-tape cases is to avoid building |
| unneeded backward functions while graph building (where we may or may not |
| eventually need gradients). |
| |
| Returns: |
| A forward _EagerDefinedFunction. |
| """ |
| if self._forward is None: |
| self._forward, self._backward = ( |
| self._forward_and_backward_functions()) |
| return self._forward |
| |
  def backward(self, outputs):
    """Create a backward function given `outputs` from the forward function.

    Args:
      outputs: Tensors actually produced by a call to the forward function
        (real outputs followed by side outputs). Used to remap the backward
        function's captures onto this call's live tensors and to decide which
        outputs the tape should record.

    Returns:
      A tuple `(backward_function_wrapper, recorded_outputs)` suitable for
      passing to `tape.record_operation`.
    """
    # The backward graph captured symbolic tensors from the forward graph.
    # Remap each such capture (matched by tensor id against the forward
    # graph's outputs) to the corresponding tensor produced by this call.
    capture_mapping = dict(
        zip([ops.tensor_id(t) for t in self._func_graph.outputs], outputs))
    remapped_captures = [
        capture_mapping.get(ops.tensor_id(capture), capture)
        for capture in self._backward.captured_inputs
    ]
    # We may need to use zeros_like to get a zero for variant Tensors with
    # unconnected gradients. We do that in advance so we don't have to hold on
    # to the outputs themselves, which may not be needed otherwise.
    variant_zeros_like = {}
    # Number of non-captured inputs the backward function accepts, i.e. how
    # many output gradients it expects to receive.
    backward_function_inputs = (
        len(self._backward.inputs) - len(self._backward.captured_inputs))
    recorded_outputs = []
    trainable_recorded_outputs = 0
    skip_positions = []
    for output_index, output in enumerate(outputs):
      if trainable_recorded_outputs < backward_function_inputs:
        recorded_outputs.append(output)
      if gradients_util.IsTrainable(output):
        trainable_recorded_outputs += 1
      else:
        # Non-trainable outputs have no gradient placeholder in the backward
        # function; remember their positions so the wrapper can skip the
        # corresponding incoming gradient args.
        skip_positions.append(output_index)
      if output.dtype == dtypes.variant:
        variant_zeros_like[output_index] = default_gradient.zeros_like(output)

    def _backward_function_wrapper(*args):
      """Process output gradients and call the backward function."""
      if not self._backward.outputs:
        return []
      processed_args = []
      input_index = 0
      for output_index, arg in enumerate(args):
        if output_index in skip_positions:
          continue
        if arg is None:
          # We're calling a (non-polymorphic) ConcreteFunction, so we need to
          # have a Tensor value for each Tensor we thought would be trainable
          # based on its dtype, even if it ended up being unconnected.
          input_placeholder = self._backward.inputs[
              input_index]
          if input_placeholder.dtype == dtypes.variant:
            arg = variant_zeros_like[output_index]
          else:
            arg = array_ops.zeros(
                *default_gradient.shape_and_dtype(input_placeholder))
        processed_args.append(arg)
        input_index += 1
        # Stop once every expected output gradient has been filled in.
        if input_index >= backward_function_inputs:
          break
      return self._backward._call_flat(  # pylint: disable=protected-access
          processed_args, remapped_captures)

    return _backward_function_wrapper, recorded_outputs
| |
| |
class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for first-order gradients."""

  def __init__(self, func_graph, attrs, func_graph_deleter):
    super(_FirstOrderTapeGradientFunctions, self).__init__(
        func_graph, attrs, func_graph_deleter)
    self._func_graph_deleter = func_graph_deleter
    # Outputs beyond this count are side outputs added for gradient
    # computation rather than "real" inference outputs.
    self._num_inference_outputs = len(func_graph.outputs)

  def _forward_and_backward_functions(self):
    """Shortcut for when only first-order gradients are required.

    The returned backward function does not accept gradients with respect to
    side output of forward_function. This is fine as long as the user can't
    possibly request second order tape gradients, as when they've used a single
    non-persistent GradientTape. Since we don't need the backward function to
    take gradients with respect to side outputs, we can skip some potentially
    slow graph building.

    Returns:
      A tuple of (forward_function, backward_function):
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to the "real" outputs of forward_function and
          returns gradients with respect to the inputs.
    """
    inference_outputs = self._func_graph.outputs[:self._num_inference_outputs]
    return self._build_functions_for_outputs(inference_outputs)
| |
| |
class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions):
  """Caches tape-friendly functions for higher-order gradients."""

  # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider
  # generalizing if so.
  def _forward_and_backward_functions(self):
    """Forward and backward functions suitable for higher-order gradients.

    Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by
    this method accepts gradients for all of the outputs of the returned forward
    function, including side outputs.

    Returns:
      A tuple of (forward_function, backward_function):
        forward_function: Takes the same inputs as the inference function, but
          returns side outputs used by backward_function in addition to the
          inference function's outputs.
        backward_function: Takes side outputs from forward_function and
          gradients with respect to all of its outputs, real and side. Returns
          gradients with respect to the inputs.
    """
    # Building a backward pass captures forward-pass tensors, which become new
    # side outputs of the forward function. Each new side output then needs
    # its own output gradient for higher-order tape gradients to work, so we
    # keep rebuilding (in throwaway graphs) until the output count stops
    # growing. Only tape gradients need this up-front fixed point: symbolic
    # tf.gradients regenerates backward functions on demand instead.
    num_processed = 0
    outputs = []
    while num_processed < len(self._func_graph.outputs):
      new_outputs = self._func_graph.outputs[num_processed:]
      outputs = list(self._func_graph.outputs)
      num_processed = len(outputs)
      self._build_functions_for_outputs(new_outputs)
    forward_function, backward_function = (
        self._build_functions_for_outputs(outputs))
    if len(self._func_graph.outputs) != len(outputs):
      raise AssertionError(
          ("Unexpectedly added new outputs to the forward function when "
           "building the backward function: {}").format(
               self._func_graph.outputs[len(outputs):]))
    return forward_function, backward_function
| |
| |
| class _PossibleTapeGradientTypes(enum.Enum): |
| """Represents the output of TFE_Py_TapeSetPossibleGradientTypes.""" |
| NONE = 0 |
| FIRST_ORDER = 1 |
| HIGHER_ORDER = 2 |
| |
| |
class ConcreteFunction(object):
  """Callable object encapsulating a function definition and its gradient.

  `ConcreteFunction` is a callable that encapsulates a function definition and
  is differentiable under `tf.GradientTape` objects.
  """

  def __init__(self, func_graph, attrs=None, signature=None,
               shared_func_graph=True):
    """Initialize a `ConcreteFunction`.

    Args:
      func_graph: An instance of FuncGraph: the function body to wrap.
      attrs: (optional) dict mapping names of attributes to their AttrValue
        values. Attributes in `attrs` will be included in this function's
        definition.
      signature: a nested sequence of `TensorSpec` objects specifying the input
        signature of this function.
      shared_func_graph: If False, the ConcreteFunction takes ownership of
        `func_graph` and will break reference cycles when it is deleted. This
        makes the FuncGraph inoperable.

    Raises:
      ValueError: If number of input_placeholders is not equal to the number
        of function inputs.
    """
    # These two are filled in elsewhere (presumably by the machinery behind
    # `get_concrete_function` — see the AssertionError in `_call_impl`); until
    # then keyword-style calling is unavailable.
    self._arg_keywords = None
    self._num_positional_args = None
    self._func_graph = func_graph
    self._captured_inputs = self._func_graph.external_captures
    self._captured_closures = self._func_graph.deferred_external_captures
    self._output_shapes = tuple(
        output.shape for output in self._func_graph.outputs)
    attrs = _parse_func_attrs(attrs or {})
    self._signature = signature

    if shared_func_graph:
      self._garbage_collector = None
    else:
      # We own `func_graph`; the collector breaks its reference cycles when
      # this ConcreteFunction is deleted.
      self._garbage_collector = ConcreteFunctionGarbageCollector(
          func_graph)

    # Pairs of forward and backward functions used for computing gradients.
    #
    # These each get a reference to the FuncGraph deleter since they use the
    # FuncGraph directly.
    self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions(
        func_graph, attrs, self._garbage_collector)
    self._first_order_tape_functions = _FirstOrderTapeGradientFunctions(
        func_graph, attrs, self._garbage_collector)
    self._higher_order_tape_functions = _HigherOrderTapeGradientFunctions(
        func_graph, attrs, self._garbage_collector)

  def __call__(self, *args, **kwargs):
    """Executes the wrapped function.

    Args:
      *args: Tensors or Variables. Positional arguments are only accepted when
        they correspond one-to-one with arguments of the traced Python function.
      **kwargs: Tensors or Variables specified by name. When
        `get_concrete_function` was called to create this `ConcreteFunction`,
        each Tensor input was given a name, defaulting to the name of the Python
        function's argument but possibly overridden by the `name=` argument to
        `tf.TensorSpec`. These names become the argument names for the concrete
        function.

    Returns:
      The result of applying the TF function on the given Tensors.

    Raises:
      AssertionError: If this `ConcreteFunction` was not created through
        `get_concrete_function`.
      ValueError: If arguments contains anything other than Tensors or
        Variables.
      TypeError: For invalid positional/keyword argument combinations.
    """
    return self._call_impl(args, kwargs)

  def _call_impl(self, args, kwargs, cancellation_manager=None):
    """See `__call__` for details."""
    if self._arg_keywords is None or self._num_positional_args is None:
      # No keyword metadata was attached. Functions carrying an explicit
      # signature (e.g. built via wrap_function) still support purely
      # positional calls; anything else is an internal-API misuse.
      if self._signature is not None:
        if kwargs:
          raise NotImplementedError(
              "Keyword arguments not supported when calling a "
              "wrap_function-decorated function.")
        return self._call_flat(args, self.captured_inputs)
      raise AssertionError(
          "Tried to call a concrete function obtained from an internal API "
          "through the public interface. Use get_concrete_function instead.")
    if len(args) > self._num_positional_args:
      raise TypeError(
          ("Expected at most {} positional arguments (and the rest keywords, "
           "of {}), got {}. When calling a concrete function, positional "
           "arguments may not be bound to Tensors within nested structures."
          ).format(self._num_positional_args, self._arg_keywords, args))
    args = list(args)
    # Bind remaining (non-positional) arguments by keyword, in declared order.
    for keyword in self._arg_keywords[len(args):]:
      try:
        args.append(kwargs.pop(compat.as_str(keyword)))
      except KeyError:
        specified_keywords = (list(self._arg_keywords[:len(args)])
                              + list(kwargs.keys()))
        raise TypeError(
            "Expected argument names {} but got values for {}. Missing: {}."
            .format(
                list(self._arg_keywords),
                specified_keywords,
                list(set(self._arg_keywords) - set(specified_keywords))))
    if kwargs:
      # Leftover kwargs are either duplicates of positionally-bound arguments
      # or simply unknown names — report which.
      positional_arg_keywords = set(self._arg_keywords[:len(args)])
      for unused_key in kwargs:
        if unused_key in positional_arg_keywords:
          raise TypeError("Got two values for keyword '{}'.".format(unused_key))
      raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
          list(kwargs.keys()), list(self._arg_keywords)))
    return self._call_flat(args, self.captured_inputs, cancellation_manager)

  def _filtered_call(self, args, kwargs):
    """Executes the function, filtering arguments from the Python function.

    Objects aside from Tensors, CompositeTensors, and Variables are ignored.
    CompositeTensors are expanded into their components.

    Args:
      args: Canonicalized positional arguments of the Python function.
      kwargs: Canonicalized keyword arguments of the Python function.

    Returns:
      The result of applying the function on the Tensors/Variables contained in
      `args` and `kwargs`.
    """
    return self._call_flat(
        (t for t in nest.flatten((args, kwargs), expand_composites=True)
         if isinstance(t, (ops.Tensor,
                           resource_variable_ops.BaseResourceVariable))),
        self.captured_inputs)

  def _call_flat(self, args, captured_inputs, cancellation_manager=None):
    """Executes the wrapped function.

    Args:
      args: a list of Tensors or Variables. Any CompositeTensors should be
        expanded before calling this method.
      captured_inputs: the captured inputs that are also part of the input args
        to the actual execution. By default, it should be self._captured_inputs.
      cancellation_manager: (Optional.) A `CancellationManager` that can be
        used to cancel function invocation.

    Returns:
      The result of applying the TF function to `args`.

    Raises:
      ValueError: If `args` contains anything other than Tensors or Variables.
    """
    args = list(args)
    ctx = context.context()
    executing_eagerly = ctx.executing_eagerly()

    # Copy saveable status of function's graph to current FuncGraph.
    default_graph = ops.get_default_graph()
    if default_graph.building_function and not self._func_graph.saveable:
      default_graph.mark_as_unsaveable(self._func_graph.saving_errors)

    # CompositeTensors must have been expanded by the caller (see
    # `_filtered_call`); by this point only Tensors/Variables are legal.
    if any(isinstance(a, composite_tensor.CompositeTensor) for a in args):
      raise AssertionError("Expected all args to be Tensors or Variables; "
                           "but got CompositeTensor: %r" % args)

    # Notify any active tape (or a graph-level variable watcher) that this
    # function's variables are being accessed.
    if (tape.could_possibly_record() or
        hasattr(ops.get_default_graph(), "watch_variable")):
      for v in self._func_graph.variables:
        resource_variable_ops.variable_accessed(v)

    tensor_inputs = []
    variables_used = object_identity.ObjectIdentitySet([])
    for i, arg in enumerate(args):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # We can pass a variable more than once, and in this case we need to
        # pass its handle only once.
        if arg.handle in variables_used:
          continue
        resource_variable_ops.variable_accessed(arg)
        tensor_inputs.append(arg.handle)
        variables_used.add(arg.handle)
      elif isinstance(arg, ops.Tensor):
        tensor_inputs.append(arg)
        if not executing_eagerly:
          # If we're graph building, shape inference is on. We check for input
          # compatibility up front to avoid hard to debug incompatibilities
          # later.
          graph_input_shape = tensor_shape.TensorShape(
              self._func_graph.inputs[i].shape)
          if not graph_input_shape.is_compatible_with(arg.shape):
            if self._arg_keywords:
              arg_name = "'{}'".format(self._arg_keywords[i])
            else:
              arg_name = "with index {}".format(i)
            raise ValueError(
                ("The argument {} (value {}) is not compatible with the shape "
                 "this function was traced with. Expected shape {}, but got "
                 "shape {}.\n\nIf you called get_concrete_function, you may "
                 "need to pass a tf.TensorSpec(..., shape=...) with a less "
                 "specific shape, having None on axes which can vary.").format(
                     arg_name, arg,
                     self._func_graph.inputs[i].shape,
                     arg.shape))
      elif (self._signature is not None and
            isinstance(self._signature[i], tensor_spec.TensorSpec)):
        # Non-Tensor values are acceptable only when a signature pins down the
        # dtype they should be converted to.
        tensor_inputs.append(
            ops.convert_to_tensor(arg, self._signature[i].dtype))
      else:
        raise ValueError("All inputs to `ConcreteFunction`s must be Tensors; "
                         "on invocation of %s, the %d-th input (%s) was not a "
                         "Tensor." % (self._func_graph.name, i, str(arg)))
    # Full flat argument list for the underlying function: user-provided
    # tensors followed by the captures.
    args = tensor_inputs + captured_inputs
    forward_backward = self._select_forward_and_backward_functions(args)
    forward_function = forward_backward.forward
    if executing_eagerly:
      flat_outputs = forward_function.call(
          ctx, args, cancellation_manager=cancellation_manager)
    else:
      # When graph building, register the delayed-rewrite gradient so the
      # emitted call ops pick it up via the override map.
      gradient_name = self._delayed_rewrite_functions.register()
      with ops.get_default_graph().gradient_override_map(
          {"PartitionedCall": gradient_name,
           "StatefulPartitionedCall": gradient_name}):
        flat_outputs = forward_function.call(ctx, args)
    if isinstance(flat_outputs, ops.Operation) or flat_outputs is None:
      # We only record function calls which have outputs.
      return self._build_call_outputs(flat_outputs)
    backward_function, to_record = forward_backward.backward(flat_outputs)
    tape.record_operation(forward_function.signature.name,
                          to_record, args, backward_function)
    return self._build_call_outputs(flat_outputs)

  def _experimental_with_cancellation_manager(self, cancellation_manager):
    """Returns a callable that invokes a cancelable version of this function.

    Args:
      cancellation_manager: A `CancellationManager` object that can be used to
        cancel function invocation.

    Returns:
      A callable with the same signature as this concrete function.
    """

    def cancellable_call(*args, **kwargs):
      return self._call_impl(
          args, kwargs, cancellation_manager=cancellation_manager)

    return cancellable_call

  @property
  def name(self):
    """`ConcreteFunction` name."""
    return self._delayed_rewrite_functions.forward.name

  @property
  def graph(self):
    """Returns the graph from which this function was constructed."""
    return self._func_graph

  @property
  def inputs(self):
    """Returns tensors in `self.graph` corresponding to arguments."""
    return self._func_graph.inputs

  @property
  def structured_input_signature(self):
    """Returns structured signature of the original function."""
    return self._func_graph.structured_input_signature

  @property
  def outputs(self):
    """Returns tensors in `self.graph` corresponding to returned tensors."""
    return self._func_graph.outputs

  @property
  def structured_outputs(self):
    """Returns outputs in `self.graph` as returned by the original function."""
    return self._func_graph.structured_outputs

  @property
  def captured_inputs(self):
    """Returns external Tensors captured by this function.

    self.__call__(*args) passes `args + self.captured_inputs` to the function.
    """
    # Deferred captures are zero-argument callables; evaluate them now so
    # the returned list contains concrete values.
    from_closures = nest.flatten([x() for x in self._captured_closures],
                                 expand_composites=True)
    return self._captured_inputs + from_closures

  @property
  def function_def(self):
    """Returns a `FunctionDef` object representing this function."""
    return self._delayed_rewrite_functions.forward.definition

  @property
  def output_shapes(self):
    """The function's output shapes."""
    return nest.map_structure(
        lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)),
        composite_tensor.replace_composites_with_components(
            self._func_graph.structured_outputs),
        expand_composites=False)

  @property
  def output_dtypes(self):
    # TODO(akshayka): Consider removing this.
    return nest.map_structure(
        lambda x: x.dtype if x is not None else None,
        composite_tensor.replace_composites_with_components(
            self._func_graph.structured_outputs),
        expand_composites=False)

  def add_to_graph(self, g=None):
    """Registers the function, adds it to the graph g or default graph.

    Args:
      g: If specified, registers the function with this graph. Defaults to the
        current context (either the default graph or the eager context).
    """
    # If we are not executing eagerly, adds the function to default graph if no
    # graph is specified.
    # In case of eager execution, function definition gets added to context
    # during construction itself.

    if not context.executing_eagerly() and not g:
      g = ops.get_default_graph()
    self._delayed_rewrite_functions.forward.add_to_graph(g)

  def add_gradient_functions_to_graph(self, g=None):
    """Add forward/backward functions to graph `g` or the current context."""
    if not context.executing_eagerly() and not g:
      g = ops.get_default_graph()
    self._delayed_rewrite_functions.forward.add_to_graph(g)
    forward_function, backward_function = (
        self._delayed_rewrite_functions.forward_backward())
    forward_function.add_to_graph(g)
    backward_function.add_to_graph(g)

  def _register_delayed_rewrite_gradient(self):
    """Registers a delayed-rewrite gradient function and returns the name."""
    return self._delayed_rewrite_functions.register()

  def _select_forward_and_backward_functions(self, args):
    """Selects forward and backward functions based on the calling context.

    The forward function computes the "real" function outputs, `self._outputs`,
    and any extra values needed by the corresponding backward function.

    Args:
      args: A flat list of Tensors with all of the inputs to the forward
        function (including user-specified and captured inputs).

    Returns:
      An object with a `forward` property containing an _EagerDefinedFunction,
      and a corresponding `backward` method which takes outputs from the forward
      function and returns a backward function.
    """
    possible_gradient_type = _PossibleTapeGradientTypes(
        pywrap_tensorflow.TFE_Py_TapeSetPossibleGradientTypes(args))
    if possible_gradient_type == _PossibleTapeGradientTypes.FIRST_ORDER:
      if context.executing_eagerly():
        # There is a single non-persistent tape active, so the user can only
        # request first-order gradients from a tape. We can spend less time
        # graph building since we know this.
        #
        # We may still end up computing higher-order gradients, but that'd be
        # through `tf.gradients`, which can re-write the forward pass and so
        # needs no preparation here.
        return self._first_order_tape_functions
      else:
        # We can avoid computing second-order gradients in some cases by doing a
        # delayed rewrite when graph building. Since we know we'll only compute
        # first-order tape gradients, the delayed rewrite is safe: we won't need
        # to tell the tape about side outputs.
        #
        # TODO(allenl): This case is really dirty. It would be better if we
        # could temporarily pop all of the current tapes to avoid
        # accidentally taking second-order gradients.
        return self._delayed_rewrite_functions
    elif possible_gradient_type == _PossibleTapeGradientTypes.HIGHER_ORDER:
      # Either there's a persistent tape watching, or there are multiple nested
      # tapes. Either way, the user may request higher-order gradients. We'll
      # spend a bit more time and make sure higher-order gradients are correct.
      return self._higher_order_tape_functions
    # else possible_gradient_type == _PossibleTapeGradientTypes.NONE, meaning no
    # tape is recording.
    return self._delayed_rewrite_functions

  def _build_call_outputs(self, result):
    """Maps the fdef output list to actual output structure.

    Args:
      result: Output lists defined by FunctionDef.
    Returns:
      The actual call output.
    """
    if self._func_graph.structured_outputs is None:
      return result

    # Replace outputs with results, skipping over any 'None' values.
    outputs_list = nest.flatten(self._func_graph.structured_outputs,
                                expand_composites=True)
    j = 0
    for i, o in enumerate(outputs_list):
      if o is not None:
        outputs_list[i] = result[j]
        j += 1
    ret = nest.pack_sequence_as(self._func_graph.structured_outputs,
                                outputs_list, expand_composites=True)
    return ret
| |
| |
# Register the Python-side tensor classes with the C extension under fixed
# names. NOTE(review): the consumers of these registrations live in the C
# layer (pywrap bindings) and are not visible from this file — presumably
# type checks in the fast-path/tape machinery; confirm before relying on it.
pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
| |
| |
| def _deterministic_dict_values(dictionary): |
| return tuple(dictionary[key] for key in sorted(dictionary)) |
| |
| |
class FunctionSpec(object):
  """Specification of how to bind arguments to a function.

  Wraps a function's argspec (possibly adjusted for `functools.partial`
  wrapping) plus an optional `input_signature`, and canonicalizes call-site
  `*args`/`**kwargs` into a deterministic `(args, kwargs)` form.
  """

  @staticmethod
  def from_function_and_signature(python_function, input_signature):
    """Create a FunctionSpec instance given a python function and signature.

    Args:
      python_function: The Python callable to inspect. May be wrapped in
        decorators and/or a `functools.partial`.
      input_signature: A nested sequence of `TensorSpec` objects, or None.

    Returns:
      A `FunctionSpec` describing how to canonicalize calls to
      `python_function`.
    """
    fullargspec = tf_inspect.getfullargspec(python_function)
    # Treat a wrapped partial function as a special case. For all arguments that
    # were overridden with keywords in the partial:
    # - remove the corresponding arguments,
    # - remove the corresponding keywords.
    _, unwrapped = tf_decorator.unwrap(python_function)
    # TODO(b/131153379): Consider Python3's fullargspec.kwonlyargs and
    # fullargspec.kwonlydefaults.
    if isinstance(unwrapped, functools.partial):
      # Also consider the Python3 case with kwonlydefaults.
      if fullargspec.defaults or fullargspec.kwonlydefaults:
        new_defaults = fullargspec.defaults
        new_args = fullargspec.args
        if fullargspec.defaults:
          # To be able to canonicalize the function properly, we want to ignore
          # default values that are overridden via a partial kwarg. For example:
          #
          #   def func(a, b, c, d=5, e=7):
          #     return a, b, c, d, e
          #   p_func = functools.partial(tf.function(func, 10, e=9))
          #
          # Here we want to drop from the defaults the parameter `e`. If we
          # forwarded the call to the partial function with a default for `e`
          # we would get an error for passing two values for one parameter.
          #
          # Note that this has a limitation: we can only override parameters at
          # the end of the parameter list.
          #
          # In this case we want to end up with 3 arguments (b, c, d) and 1
          # default value (5). We do this by constructing a mask where 0 stands
          # for a value that was overridden by a partial kwarg. The seemingly
          # complicated logic below does just that - for arguments (b, c, d, e)
          # we would get a mask (1, 1, 1, 0).
          old_args = fullargspec.args
          old_defaults = fullargspec.defaults

          # Sentinel marking positions that have no default value.
          no_default = object()
          num_args_without_defaults = len(old_args) - len(old_defaults)
          left_padding = tuple([no_default] * num_args_without_defaults)

          args_with_defaults = zip(old_args, left_padding + old_defaults)

          # Create a mask where 0 stands for args that had a partial kwarg
          # defined.
          non_keyword_defaults_mask = [
              0 if key in unwrapped.keywords else 1 for key in old_args
          ]
          # Keep only arguments and defaults that were not kwargs of partial.
          new_args_with_defaults = list(
              itertools.compress(args_with_defaults, non_keyword_defaults_mask))
          # Keep all args.
          new_args = [arg for arg, _ in new_args_with_defaults]
          # Keep only real default values.
          new_defaults = [
              default for _, default in new_args_with_defaults
              if default is not no_default
          ]
        fullargspec = tf_inspect.FullArgSpec(
            args=new_args,
            varargs=fullargspec.varargs,
            varkw=fullargspec.varkw,
            defaults=new_defaults,
            kwonlyargs=[],
            kwonlydefaults={},
            annotations=fullargspec.annotations)
    is_method = tf_inspect.ismethod(python_function)
    return FunctionSpec(fullargspec, is_method, [], {}, input_signature)

  def __init__(self, fullargspec, is_method, args_to_prepend, kwargs_to_include,
               input_signature):
    """Initializes a FunctionSpec.

    Args:
      fullargspec: A `FullArgSpec` describing the function's arguments.
      is_method: True if the function is a method (its first argument, `self`,
        is then excluded from argument matching).
      args_to_prepend: Positional arguments to associate with this spec,
        exposed via the `args_to_prepend` property.
      kwargs_to_include: Keyword arguments to associate with this spec,
        exposed via the `kwargs_to_include` property.
      input_signature: A nested sequence of `TensorSpec` objects specifying
        the input signature of the function, or None.

    Raises:
      ValueError: If `input_signature` is provided but the function has
        keyword-only arguments.
      TypeError: If `input_signature` is neither a tuple nor a list.
    """
    self._fullargspec = fullargspec
    self._is_method = is_method
    # Bug fix: these used to be `del`eted without ever being stored, which
    # made the `args_to_prepend`/`kwargs_to_include` properties raise
    # AttributeError on access. Store them so the properties work.
    self._args_to_prepend = args_to_prepend
    self._kwargs_to_include = kwargs_to_include
    self._default_values = fullargspec.defaults

    if self._is_method:
      # Remove `self`: default arguments shouldn't be matched to it.
      # TODO(b/127938157): Should this error out if there is no arg to
      # be removed?
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    self.arg_names = args
    self.vararg_name = fullargspec.varargs

    # A cache mapping from arg index to default value, for canonicalization.
    offset = len(args) - len(self._default_values or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(self._default_values or [])
    }
    if input_signature is None:
      self._input_signature = None
      # Bug fix: previously left unset in this branch, making the
      # `flat_input_signature` property raise AttributeError.
      self._flat_input_signature = None
    else:
      if fullargspec.kwonlyargs:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")

      if not isinstance(input_signature, (tuple, list)):
        raise TypeError("input_signature must be either a tuple or a "
                        "list, received " + str(type(input_signature)))

      self._input_signature = tuple(input_signature)
      self._flat_input_signature = tuple(nest.flatten(input_signature,
                                                      expand_composites=True))

  @property
  def fullargspec(self):
    """The `FullArgSpec` this spec was constructed from."""
    return self._fullargspec

  @property
  def is_method(self):
    """True if the wrapped function is a method."""
    return self._is_method

  @property
  def args_to_prepend(self):
    """Positional arguments recorded at construction time."""
    return self._args_to_prepend

  @property
  def kwargs_to_include(self):
    """Keyword arguments recorded at construction time."""
    return self._kwargs_to_include

  @property
  def input_signature(self):
    """Tuple of `TensorSpec`s for the inputs, or None."""
    return self._input_signature

  @property
  def flat_input_signature(self):
    """Flattened form of `input_signature`, or None."""
    return self._flat_input_signature

  def canonicalize_function_inputs(self, *args, **kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.

    Args:
      *args: The varargs this object was called with.
      **kwargs: The keyword args this function was called with.

    Returns:
      A canonicalized ordering of the inputs represented by a tuple in the
      form (args, kwargs). Here: `args` is a full list of bound arguments, and
      `kwargs` contains only true keyword arguments, as opposed to named
      arguments called in a keyword-like fashion.

    Raises:
      ValueError: If a keyword in `kwargs` cannot be matched with a positional
        argument when an input signature is specified, or when the inputs
        do not conform to the input signature.
    """
    if self._input_signature is not None:
      # With a signature, only arguments it covers may be passed at all.
      if len(args) > len(self._input_signature):
        raise TypeError(
            "When input_signature is provided, only pass arguments "
            "covered by it. Received %d argument(s)." % len(args))
      for arg in six.iterkeys(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is None:
          raise TypeError(
              "Function got an unexpected keyword argument %s" % arg)
        if index >= len(self._input_signature):
          raise TypeError(
              "When input_signature is provided, only pass arguments "
              "covered by it. Received argument %s." % arg)

    if not kwargs:
      # Fast path: extend the positional args with trailing defaults.
      inputs = args
      default_keys = sorted(self._arg_indices_to_default_values.keys())
      if default_keys:
        assert min(default_keys) <= len(
            args), "Not enough arguments (%s, %s, %s)" % (args, default_keys,
                                                          self.arg_names)
      for index in default_keys:
        if index >= len(args):
          inputs += (self._arg_indices_to_default_values[index],)
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= len(args)
      }
      consumed_args = []
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          arg_indices_to_values[index] = value
          consumed_args.append(arg)
        elif self._input_signature is not None:
          raise ValueError("Cannot define a TensorFlow function from a Python "
                           "function with keyword arguments when "
                           "input_signature is provided.")
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain true keyword arguments, as
        # opposed to named arguments called in a keyword-like fashion.
        kwargs.pop(arg)
      inputs = args + _deterministic_dict_values(arg_indices_to_values)

    if self._input_signature is None:
      inputs = _convert_numpy_inputs(inputs)
      return inputs, kwargs
    else:
      assert not kwargs
      inputs = _convert_inputs_to_signature(
          inputs,
          self._input_signature,
          self._flat_input_signature)
      return inputs, {}
| |
| |
def _convert_numpy_inputs(inputs):
  """Convert numpy array inputs to tensors."""
  # We assume that any CompositeTensors have already converted their components
  # from numpy arrays to Tensors, so we don't need to expand composites here.
  flat = nest.flatten(inputs, expand_composites=False)

  # Replace every ndarray leaf with an equivalent constant Tensor.
  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  # NOTE: the exact-type check (rather than isinstance) is deliberate; see the
  # file-level pylint disable for unidiomatic-typecheck.
  converted_any = False
  for position, leaf in enumerate(flat):
    if type(leaf) == np.ndarray:
      flat[position] = constant_op.constant(leaf)
      converted_any = True

  if not converted_any:
    # Nothing changed; hand back the caller's structure untouched.
    return inputs
  return nest.pack_sequence_as(
      structure=inputs, flat_sequence=flat, expand_composites=False)
| |
| |
def _convert_inputs_to_signature(inputs, input_signature, flat_input_signature):
  """Convert inputs to pass into a function with an explicit signature.

  Args:
    inputs: canonicalized positional inputs for the function call.
    input_signature: the (possibly nested) signature structure to match.
    flat_input_signature: the flattened form of `input_signature`.

  Returns:
    `inputs` with non-Tensor leaves converted to tensors per the signature,
    repacked to the structure of `input_signature` when conversion occurred.

  Raises:
    ValueError: if the inputs cannot be structurally matched against, converted
      to, or are incompatible with the signature.
  """

  def format_error_message(inputs, input_signature):
    # Renders both structures side by side for error messages.
    return ("  inputs: (\n" + "    " +
            ",\n    ".join([str(i) for i in inputs]) + ")\n" +
            "  input_signature: (\n" + "    " +
            ",\n    ".join([str(i) for i in input_signature]) + ")")

  try:
    # TODO(b/124370185): Use all elements as inputs to throw an error if there
    # are ignored arguments. Calling with arguments that are not part of the
    # signature should throw an error.
    flatten_inputs = nest.flatten_up_to(
        input_signature,
        inputs[:len(input_signature)],
        expand_composites=True)
  except ValueError:
    raise ValueError("Structure of Python function inputs does not match "
                     "input_signature:\n%s" %
                     format_error_message(inputs, input_signature))

  need_packing = False
  for index, (value, spec) in enumerate(zip(flatten_inputs,
                                            flat_input_signature)):
    if not pywrap_tensorflow.IsTensor(value):
      try:
        # Convert non-Tensor leaves using the signature's dtype as a hint.
        flatten_inputs[index] = ops.convert_to_tensor(
            value, dtype_hint=spec.dtype)
        need_packing = True
      except ValueError:
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be convertible to "
                         "tensors:\n%s" %
                         format_error_message(inputs, input_signature))

  # Compatibility is checked only after all leaves have been converted.
  if any(not spec.is_compatible_with(other) for spec, other in zip(
      flat_input_signature,
      flatten_inputs)):
    raise ValueError("Python inputs incompatible with input_signature:\n%s" %
                     format_error_message(inputs, input_signature))

  if need_packing:
    inputs = nest.pack_sequence_as(
        structure=input_signature,
        flat_sequence=flatten_inputs,
        expand_composites=True)

  return inputs
| |
| |
class FunctionCache(object):
  """A lightweight container for cached functions."""

  def __init__(self):
    # Call-context keys (CacheKey with input_signature `None`) that have
    # missed the primary cache at least once.
    self.missed = set()
    # The primary cache: fully shaped CacheKey -> function.
    self.primary = collections.OrderedDict()
    # Shape-free CacheKey -> flat list of relaxed shapes, one per argument.
    # Non-Tensor arguments are represented by a `None` entry.
    self.arg_relaxed_shapes = collections.OrderedDict()
    # The secondary cache: shape-free CacheKey -> function traced with
    # relaxed argument shapes.
    self.arg_relaxed = collections.OrderedDict()
    # The OrderedDicts above require manual garbage collection; register a
    # collector for each one.
    self._garbage_collectors = [
        _FunctionGarbageCollector(self.primary),
        _FunctionGarbageCollector(self.arg_relaxed),
        _FunctionGarbageCollector(self.arg_relaxed_shapes)]

  def all_values(self):
    """A set of all `ConcreteFunction` instances held by this cache."""
    return set(self.primary.values()).union(self.arg_relaxed.values())
| |
| |
class Function(object):
  """Wrapper class for the graph functions defined for a Python function.

  See the documentation for `defun` for more information on the semantics of
  defined functions.

  `Function` class is thread-compatible meaning that minimal usage of defuns
  (defining and calling) is thread-safe, but if users call other methods or
  invoke the base `python_function` themselves, external synchronization is
  necessary.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               attributes=None,
               autograph=True,
               autograph_options=None,
               experimental_relax_shapes=False,
               capture_by_value=None):
    """Initializes a `Function`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: a possibly nested sequence of `TensorSpec` objects
        specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
      attributes: dict, extra keyword arguments that will be added as attribute
        of the function.
      autograph: whether to use autograph to compile
        `python_function`. See https://www.tensorflow.org/guide/autograph for
        more information.
      autograph_options: Experimental knobs to control behavior
        `when autograph=True`. See https://www.tensorflow.org/guide/autograph
        for more information.
      experimental_relax_shapes: When true, argument shapes may be relaxed to
        avoid unnecessary retracing.
      capture_by_value: Experimental. Whether to capture resource variables by
        value or reference. If None, will inherit from a parent context or
        default to False.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
    """
    self._python_function = python_function
    self._function_spec = FunctionSpec.from_function_and_signature(
        python_function, input_signature)
    self._name = name
    self._autograph = autograph
    self._autograph_options = autograph_options
    self._experimental_relax_shapes = experimental_relax_shapes
    self._function_cache = FunctionCache()
    self._function_attributes = attributes or {}
    self._capture_by_value = capture_by_value

    # Guards tracing and cache population in `_maybe_define_function`.
    self._lock = threading.Lock()
    # _descriptor_cache maps instances of a class to an instance-specific
    # `Function`, used to make sure defun-decorated methods create different
    # functions for each instance.
    self._descriptor_cache = weakref.WeakKeyDictionary()

  def __call__(self, *args, **kwargs):
    """Calls a graph function specialized to the inputs."""
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access

  @property
  def python_function(self):
    """Returns the wrapped Python function."""
    return self._python_function  # pylint: disable=protected-access

  @property
  def function_spec(self):
    """Returns the `FunctionSpec` built from the wrapped Python function."""
    return self._function_spec

  @property
  def input_signature(self):
    """Returns the input signature."""
    return self._function_spec.input_signature

  @property
  def flat_input_signature(self):
    """Returns the flattened input signature."""
    return self._function_spec.flat_input_signature

  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
    """Returns a concrete function which cleans up its graph function."""
    # When an input signature is set, the trace is fully determined by the
    # signature, so any call-site args are ignored.
    if self.input_signature:
      args, kwargs = None, None
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
    return graph_function

  def _get_concrete_function_internal(self, *args, **kwargs):
    """Bypasses error checking when getting a graph function."""
    graph_function = self._get_concrete_function_internal_garbage_collected(
        *args, **kwargs)
    # We're returning this concrete function to someone, and they may keep a
    # reference to the FuncGraph without keeping a reference to the
    # ConcreteFunction object. So we won't clean up the reference cycles
    # manually and instead will leave them to Python's garbage collector.
    graph_function._garbage_collector.release()  # pylint: disable=protected-access
    return graph_function

  def get_concrete_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.
    """
    if self.input_signature:
      # Validate any explicitly supplied args against the declared signature
      # before discarding them (the trace uses the signature itself).
      if kwargs:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")
      if args:
        # If args are provided, they must match the input signature.
        if not is_same_structure(self.input_signature, args):
          raise ValueError("Structure of Python function inputs does not match "
                           "input_signature.")
        flat_inputs = nest.flatten(args, expand_composites=True)
        if any(not isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
               for arg in flat_inputs):
          raise ValueError("When input_signature is provided, all inputs to "
                           "the Python function must be Tensors or "
                           "tf.TensorSpec objects.")
        if any(not spec.is_compatible_with(other)
               for spec, other in zip(self.flat_input_signature, flat_inputs)):
          raise ValueError("Python inputs incompatible with input_signature: "
                           "inputs (%s), input_signature (%s)" %
                           (str(args), str(self.input_signature)))
      args, kwargs = None, None
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
    if self.input_signature:
      args = self.input_signature
      kwargs = {}
    # Assign a unique keyword name to each non-captured graph input, derived
    # from the user-specified argument name; collisions get numeric suffixes.
    seen_names = set()
    captured = object_identity.ObjectIdentitySet(
        graph_function.graph.internal_captures)
    # pylint: disable=protected-access
    graph_function._arg_keywords = []
    prefix_counts = {}
    # pylint: enable=protected-access
    num_positional = 0
    for arg in graph_function.graph.inputs:
      if arg in captured:
        # Captured inputs follow user inputs, so stop at the first capture.
        break
      num_positional += 1
      user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
      proposal = user_arg_name
      while proposal in seen_names:
        index = prefix_counts.get(user_arg_name, 1)
        proposal = "{}_{}".format(user_arg_name, index)
        prefix_counts[user_arg_name] = index + 1
      seen_names.add(proposal)
      graph_function._arg_keywords.append(proposal)  # pylint: disable=protected-access
    # Anything can be a positional argument, in the same order as .inputs
    graph_function._num_positional_args = num_positional  # pylint: disable=protected-access
    return graph_function

  def __get__(self, instance, owner):
    """Makes it possible to defun instance methods."""
    del owner
    # `instance` here is the instance that this `Function` was accessed through
    # e.g., for
    #
    # class Foo(object):
    #
    #   @function.defun
    #   def bar(self):
    #     ...
    #
    # foo = Foo()
    # foo.bar()  # `foo.bar` is a `Function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`). We create a
    # new instance of `Function` here to allow different instances each
    # to create variables once, thereby allowing methods to be decorated with
    # defun. Keeps a cache to avoid retracing the function every time the
    # descriptor is accessed.
    if instance not in self._descriptor_cache:
      if instance is None:
        return self
      # If there is no instance-specific `Function` in the cache, we construct
      # an instance-specific `Function` that uses a weak reference to the
      # instance (so that the instance will be correctly gc'd).

      # And finally add the wrapped function to the description cache
      self._descriptor_cache[instance] = class_method_to_instance_method(
          self, instance)

    # Return the cached `Function` for the instance
    return self._descriptor_cache[instance]

  def _cache_key(self, args, kwargs, include_tensor_ranks_only=False):
    """Computes the cache key given inputs and execution context."""
    if self.input_signature is None:
      inputs = (args, kwargs) if kwargs else args
      input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(
          inputs, include_tensor_ranks_only)
    else:
      # With an explicit signature the encoded inputs are irrelevant; the
      # flat signature itself identifies the trace.
      del args, kwargs
      assert not include_tensor_ranks_only
      input_signature = self.flat_input_signature

    ctx = context.context()

    # Don't need to open an init_scope if the _cache_key call is in eager mode
    # already.
    executing_eagerly = ctx.executing_eagerly()
    parent_graph = None
    if not executing_eagerly:
      with ops.init_scope():
        # The graph, or whether we're executing eagerly, should be a part of the
        # cache key so we don't improperly capture tensors such as variables.
        executing_eagerly = ctx.executing_eagerly()
        parent_graph = None if executing_eagerly else ops.get_default_graph()

    # pylint: disable=protected-access
    default_graph = ops.get_default_graph()
    # TODO(b/117617952): The current distribution strategy will affect graph
    # building (e.g. accessing different variables from different devices) and
    # so requires retracing for each device.
    strategy_stack = default_graph._distribution_strategy_stack
    uses_distribution_strategy = (
        strategy_stack and
        strategy_stack[-1].strategy.extended._retrace_functions_for_each_device
    )
    if executing_eagerly:
      colocation_stack = ()
      if uses_distribution_strategy:
        device_functions = (pydev.merge_device(ctx.device_name),)
      else:
        device_functions = ()
    else:
      colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
      if (uses_distribution_strategy
          or func_graph_module.device_stack_has_callable(
              default_graph._device_function_stack)):
        # Putting the device in the cache key ensures that call-site device
        # annotations are respected.
        device_functions = tuple(default_graph._device_functions_outer_to_inner)
      else:
        device_functions = ()

    in_cross_replica_context = False
    try:
      in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
    except (AttributeError, IndexError):
      # No strategy on the stack, or it lacks a replica_context attribute.
      pass

    return CacheKey(input_signature, parent_graph, device_functions,
                    colocation_stack, in_cross_replica_context)

  def _create_graph_function(self, args, kwargs, override_flat_arg_shapes=None):
    """Create a `ConcreteFunction` from `args` and `kwargs`."""
    if self.input_signature is None:
      arglen = len(args)
    else:
      arglen = len(self.input_signature)
    base_arg_names = self._function_spec.arg_names[:arglen]
    num_missing_args = arglen - len(self._function_spec.arg_names)
    missing_arg_names = [self._function_spec.vararg_name] * num_missing_args
    # Produce a list of missing args of the form ["arg_0", "arg_1", ...],
    # where arg is based on the self._function_spec.vararg_name.
    missing_arg_names = [
        "%s_%d" % (arg, i) for i, arg in enumerate(missing_arg_names)
    ]
    arg_names = base_arg_names + missing_arg_names
    graph_function = ConcreteFunction(
        func_graph_module.func_graph_from_py_func(
            self._name,
            self._python_function,
            args,
            kwargs,
            self.input_signature,
            autograph=self._autograph,
            autograph_options=self._autograph_options,
            arg_names=arg_names,
            override_flat_arg_shapes=override_flat_arg_shapes,
            capture_by_value=self._capture_by_value),
        self._function_attributes,
        # Tell the ConcreteFunction to clean up its graph once it goes out of
        # scope. This is not the default behavior since it gets used in some
        # places (like Keras) where the FuncGraph lives longer than the
        # ConcreteFunction.
        shared_func_graph=False)
    return graph_function

  def _define_function_with_shape_relaxation(self, args, kwargs):
    """Define a function, relaxing arg shapes to avoid unnecessary retracing."""

    rank_only_cache_key = self._cache_key(
        args, kwargs, include_tensor_ranks_only=True)

    arg_shapes = _flat_shape_list(args, kwargs)
    relaxed_arg_shapes = self._function_cache.arg_relaxed_shapes.get(
        rank_only_cache_key, None)
    relaxed_arg_function = self._function_cache.arg_relaxed.get(
        rank_only_cache_key, None)

    # Reuse the existing relaxed-shape function if the current argument
    # shapes are compatible with the previously relaxed shapes.
    if (relaxed_arg_function is not None
        and _compatible_shapes(flat_relaxed=relaxed_arg_shapes,
                               flat_to_check=arg_shapes)):
      return relaxed_arg_function, args, kwargs

    if relaxed_arg_shapes is None:
      relaxed_arg_shapes = arg_shapes
    else:
      if len(arg_shapes) != len(relaxed_arg_shapes):
        raise RuntimeError("Expected arg_shapes len to match "
                           "relaxed_arg_shapes len: %d vs. %d"
                           % (len(arg_shapes), len(relaxed_arg_shapes)))
      # Widen the stored shapes elementwise to cover the new call's shapes.
      relaxed_arg_shapes = [
          common_shape(x, y) for (x, y) in zip(
              arg_shapes, relaxed_arg_shapes)]
    self._function_cache.arg_relaxed_shapes[rank_only_cache_key] = (
        relaxed_arg_shapes)
    graph_function = self._create_graph_function(
        args, kwargs, override_flat_arg_shapes=relaxed_arg_shapes)
    self._function_cache.arg_relaxed[rank_only_cache_key] = graph_function

    return graph_function, args, kwargs

  def _maybe_define_function(self, args, kwargs):
    """Gets a function for these inputs, defining it if necessary.

    `args` and `kwargs` can be None if this `Function` was created with an
    `input_signature`.

    Args:
      args: The varargs for the Python function.
      kwargs: The keyword args for the Python function.

    Returns:
      A graph function corresponding to the input signature implied by args and
      kwargs, as well as the inputs that the object should be called with.

    Raises:
      ValueError: If inputs are incompatible with the input signature.
      TypeError: If the function inputs include non-hashable objects
      RuntimeError: If there's an internal bug (inconsistency) in handling
        shape relaxation retracing.
    """
    if self.input_signature is None or args is not None or kwargs is not None:
      args, kwargs = self._function_spec.canonicalize_function_inputs(
          *args, **kwargs)

    cache_key = self._cache_key(args, kwargs)

    try:
      hash(cache_key)
    except TypeError as e:
      raise TypeError(
          "Arguments supplied to `defun`-generated functions must be"
          " hashable. Original error: %s" % e)

    with self._lock:
      graph_function = self._function_cache.primary.get(cache_key, None)
      if graph_function is not None:
        return graph_function, args, kwargs

      logging.vlog(1,
                   "Creating new FuncGraph for Python function %r (key: %r)",
                   self._python_function, cache_key)
      logging.vlog(2,
                   "Python function signature [args: %s] [kwargs: %s]",
                   args,
                   kwargs)

      call_context_key = cache_key.replace(input_signature=None)

      ag_status = (
          ag_ctx.Status.ENABLED if self._autograph else ag_ctx.Status.DISABLED)
      with ag_ctx.ControlStatusCtx(
          status=ag_status, options=self._autograph_options):

        # Build a function with shape relaxation retracing if:
        # 1. shape relaxation is explicitly enabled
        # and 2. there's no provided input signature
        # and 3. there's been a cache miss for this calling context
        if (self._experimental_relax_shapes
            and self.input_signature is None
            and call_context_key in self._function_cache.missed):
          return self._define_function_with_shape_relaxation(args, kwargs)

        self._function_cache.missed.add(call_context_key)
        graph_function = self._function_cache.primary.get(cache_key, None)
        if graph_function is None:
          graph_function = self._create_graph_function(args, kwargs)
          self._function_cache.primary[cache_key] = graph_function
        return graph_function, args, kwargs
| |
| |
def register(func, *args, **kwargs):
  """Registers a specialization of a `Function` into the graph.

  The function is traced for the given inputs but not executed; only its
  definition is placed into the graph. Registering the same `Function` with
  different input parameters results in multiple versions of the function
  being registered in the graph.

  Args:
    func: the `Function` instance that generated by a @defun
    *args: input arguments for the Python function.
    **kwargs: input keyword arguments for the Python function.

  Returns:
    a `ConcreteFunction` object specialized to inputs and execution context.

  Raises:
    ValueError: When the input function is not a defun wrapped python function.
  """
  if not isinstance(func, Function):
    raise ValueError("Only defun function is allowed to be registered. "
                     "Got type: %s" % type(func))
  specialized = func.get_concrete_function(*args, **kwargs)
  specialized.add_to_graph()
  specialized.add_gradient_functions_to_graph()
  return specialized
| |
| |
def validate_signature(signature):
  """Checks that every leaf of `signature` is a `TensorSpec`; raises if not."""
  flat_signature = nest.flatten(signature, expand_composites=True)
  if not all(isinstance(entry, tensor_spec.TensorSpec)
             for entry in flat_signature):
    raise TypeError("Invalid input_signature {}; input_signature must be "
                    "a possibly nested sequence of TensorSpec objects."
                    .format(signature))
| |
| |
| def defun(func=None, |
| input_signature=None, |
| autograph=True, |
| experimental_autograph_options=None, |
| experimental_relax_shapes=False): |
| """Compiles a Python function into a callable TensorFlow graph. |
| |
| `defun` (short for "define function") compiles a Python function |
| composed of TensorFlow operations into a callable that executes a `tf.Graph` |
| containing those operations. The callable produced by `defun` contains only |
| the subgraph of TensorFlow operations that were executed when the Python |
| function was called with a particular input signature, defined as a list |
| of the shapes and dtypes of the Python function's Tensor-valued arguments and |
| the values of its non-Tensor Python objects. |
| |
| When eager execution is enabled, the ability to create graphs from Python |
| functions makes it possible to incrementally trade off debugability and |
| interactivity for performance. Functions compiled with `defun` cannot be |
| inspected with `pdb`; however, executing a graph |
| generated by `defun` sometimes takes less time and memory than eagerly |
| executing the corresponding Python function, since specifying computations as |
| graphs allows for optimizations like automatic buffer reuse and |
| parallelization among ops. Note that executing a `defun`-compiled function |
| incurs a small constant overhead, so eagerly executing sufficiently small |
| Python functions might take less time than executing their corresponding |
| `defun`-generated graphs. |
| |
| For a Python function to be compatible with `defun`, all of its arguments must |
| be hashable Python objects or lists thereof. The function itself may not |
| modify the list/map structure of its arguments. Additionally, it must return |
| zero or more `tf.Tensor` objects. If the Python function returns |
| a `tf.Variable`, its compiled version will return the value of that variable |
| as a `tf.Tensor`. |
| |
| Executing a graph generated by `defun` respects device annotations (i.e., |
| all `with tf.device` directives present in a Python function will also be |
| present in its corresponding graph), but it is not yet possible to execute the |
| generated graphs across multiple machines. |
| |
| _Example Usage_ |
| |
| ```python |
| import tensorflow as tf |
| |
| tf.compat.v1.enable_eager_execution() |
| |
| # A simple example. |
| def f(x, y): |
| return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) |
| |
| g = tf.contrib.eager.defun(f) |
| |
| x = tf.constant([[2.0, 3.0]]) |
| y = tf.constant([[3.0, -2.0]]) |
| |
| # `f` and `g` will return the same value, but `g` will be executed as a |
| # TensorFlow graph. |
| assert f(x, y).numpy() == g(x, y).numpy() |
| |
| # `defun` is capable of compiling Python functions that close over Python |
| # objects, including Tensors and Variables. |
| @tf.contrib.eager.defun |
| def h(): |
| return f(x, y) |
| |
| assert (h().numpy() == f(x, y).numpy()).all() |
| |
| # `defun` automatically lifts variables out of the graphs it creates, |
| # allowing you to compile the `call` methods of `tf.keras.layers.Layer` and |
| # `tf.keras.Model` objects. |
| class MyModel(tf.keras.Model): |
| |
| def __init__(self, keep_probability=0.2): |
| super(MyModel, self).__init__() |
| self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) |
| self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) |
| self.keep_probability = keep_probability |
| |
| @tf.contrib.eager.defun |
| def call(self, inputs, training=True): |
| x = self.dense2(self.dense1(inputs)) |
| if training: |
| return tf.nn.dropout(x, self.keep_probability) |
| else: |
| return x |
| |
| model = MyModel() |
| model(x, training=True) # executes a graph, with dropout |
| model(x, training=False) # executes a graph, without dropout |
| |
| # `defun`-compiled functions are differentiable. |
| optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01) |
| with tf.GradientTape() as tape: |
| outputs = model(x) |
| gradient = tape.gradient(outputs, model.trainable_variables) |
| optimizer.apply_gradients((grad, var) for grad, var in zip(gradient, |
| model.trainable_variables)) |
| ``` |
| |
| When using `defun`, there are subtleties regarding inputs, Python control |
| flow, and variable creation that one should be aware of. For concreteness, let |
| `f` be a Python function that returns zero or more `tf.Tensor` objects and |
| let `F = defun(f)`. `F` builds a graph for each unique input signature it |
| sees, Python control flow is baked into graphs, and operations related to |
| variable initialization are automatically lifted out of the graphs that `F` |
| generates and placed in the eager context if executing eagerly or into an |
| outer graph otherwise. |
| |
| _Input Signatures_ |
| |
| By default, `F = tf.contrib.eager.defun(f)` instantiates a separate graph |
| for every unique sequence of the shapes and dtypes of Tensor arguments and |
| the values of Python objects it is invoked with. For example, calling |
| `F(tf.random.uniform([2])` will execute a different graph than |
| `F(tf.random.uniform([3])` because the two inputs have different shapes. |
| The first time that `F(*args, **kwargs)` is called with a particular sequence |
| of Tensor shapes and dtypes and Python values, it constructs a graph by |
| tracing the execution of `f(*args, **kwargs)`; this graph is bound to an |
| input signature inferred from `(*args, **kwargs)` and cached for future reuse. |
| |
| NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects |
| before being passed to `f`, and are treated as Tensors for caching. This |
| allows a function to be called multiple times with NumPy arrays having |
| different values but the same shape and dtype without re-tracing each time. |
| |
| `tf.contrib.eager.defun` caches graphs for your convenience, letting you |
| define TensorFlow functions without explicitly specifying their signatures. |
| However, this policy is conservative and potentially expensive; for example, |
| when different invocations of your function have differently-shaped Tensor |
| inputs, this policy might generate more graph functions than necessary. To |
| eliminate such costs, `tf.contrib.eager.defun` allows you to supply an |
| optional `input_signature` argument specifying the shapes and dtypes of the |
| inputs. In particular, the shapes may be partially unspecified, with `None`s |
| in the unknown dimensions. When an input signature is provided, |
| `tf.contrib.eager.defun` will only instantiate a single graph for the |
| decorated Python function. The following is an example: |
| |
| ```python |
| import tensorflow as tf |
| |
| # The first `TensorSpec` below describes the shape and dtype of `words`, |
| # and the second describes the shape and dtype of `another_tensor`. Note that |
| # the last dimension of the `words` `TensorSpec` is left unspecified. |
| @tf.contrib.eager.defun(input_signature=[ |
| tf.contrib.eager.TensorSpec(shape=[50, 300, None], dtype=tf.float32), |
| tf.contrib.eager.TensorSpec(shape=[300, 100], dtype=tf.float32) |
| ]) |
| def my_sequence_model(words, another_tensor): |
| ... |
| |
| # Note how the third dimension of the first input can vary freely. |
| words = tf.random.uniform(([50, 300, 10]) |
| second_input = tf.random.uniform([300, 100]) |
| my_sequence_model(words, second_input) |
| |
| words = tf.random.uniform(([50, 300, 20]) |
| my_sequence_model(words, second_input) |
| |
| # Passing an input with an incompatible shape will raise an error. |
| words = tf.random.uniform(([50, 100, 20]) |
| my_sequence_model(words, second_input) # <---- This will raise an error. |
| |
| ``` |
| |
| Python functions that are compiled with an `input_signature` must only accept |
| Tensors as arguments and must not take unnamed keyword arguments (**kwargs). |
| |
| _Tracing_ |
| |
| Be aware that because `F` only logs TensorFlow operations, all the other |
| Python code that `f` executes will only shape the _construction_ of the graphs |
| that `F` executes: the Python code won't be executed when the graphs |
| themselves are executed, though it will be executed every time the Python |
| function is traced (and a given Python function might be traced multiple |
| times, once for each input signature it is invoked with). For example, whereas |
| the Python function |
| |
| ```python |
| import tensorflow as tf |
| import numpy as np |
| |
| tf.compat.v1.enable_eager_execution() |
| |
| def add_noise(): |
| return tf.eye(5) + np.random.randn(5, 5) |
| ``` |
| |
| will return a different output everytime it is invoked, the compiled function |
| `compiled = tf.contrib.eager.defun(add_noise)` will return the same value |
| every time it is called, since a particular random offset generated by NumPy |
| will be inserted into the graph as a TensorFlow constant. The solution is to |
| replace the call to `np.random.randn` with `tf.random.normal((5, 5))`. |
| |
| _Python Side-Effects_ |
| |
| A corollary of the previous discussion on tracing is the following: If a |
| Python function `f` has Python side-effects, then executing `f` multiple times |
| will not necessarily be semantically equivalent to executing `F = |
| tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact |
| that `defun` only captures the subgraph of TensorFlow operations that is |
| constructed when `f` is called in a graph-building context. |
| |
| _Python Control Flow_ |
| |
| The structure of many machine learning computations depend upon whether one is |
| training or validating, and it is common to nest specialized logic under `if |
| training:` blocks. By mapping each input signature to a unique graph, `defun` |
| lets users transparently compile such code, as the following code snippet |
| demonstrates: |
| |
| ```python |
| import tensorflow as tf |
| |
| tf.compat.v1.enable_eager_execution() |
| |
| @tf.contrib.eager.defun |
| def lossy_matmul(W, x, training=True): |
| outputs = tf.matmul(W, x) |
| if training: |
        outputs = tf.nn.dropout(outputs, rate=0.2)
| return outputs |
| |
| W = tf.random.normal((3, 5)) |
| x = tf.random.normal((5, 1)) |
| |
| # Executes a graph that applies dropout. |
| lossy_outputs = lossy_matmul(W, x, training=True) |
| |
| # Executes a graph that does not apply dropout. |
| exact_outputs = lossy_matmul(W, x, training=False) |
| ``` |
| |
| _TensorFlow Control Flow_ |
| |
| When `autograph` is `True`, data-dependent control flow is allowed as well. |
| Control flow statements that depend on `Tensor` values are staged into |
| corresponding TensorFlow ops. For example, the following code will work as |
| expected: |
| |
| ```python |
| @tf.contrib.eager.defun |
| def dynamic_rnn_loop(cell, seq): |
| state, output = cell.zero_state() |
| for input in seq: |
| state, output = cell(input, state) |
| return output |
| ``` |
| |
| For more information see `tf.autograph`. |
| |
| _Variables_ |
| |
| TensorFlow operations related to variable creation and initialization are |
| automatically lifted out of the graphs generated by `defun`. In practice, this |
| implies that variable creation and initialization only happen the first time |
| `F` is called, and that variables are reused every time thereafter. Many |
| TensorFlow APIs, like `tf.keras.layers.Layer` objects, create variables the |
| first time they are called and reuse them thereafter. Automatic variable |
| lifting makes it possible to compile these APIs without extra effort, at the |
| cost of introducing a discrepancy between the semantics of executing Python |
| functions and their corresponding compiled functions. For example: |
| |
| ```python |
| import tensorflow as tf |
| |
| tf.compat.v1.enable_eager_execution() |
| |
| def fn(): |
| x = tf.Variable(0.0) |
| x.assign_add(1.0) |
| return x.read_value() |
| |
| # `fn` is a Python function, so x is created, initialized, and destroyed upon |
| # every invocation |
| assert fn().numpy() == fn().numpy() == 1.0 |
| |
| compiled = tf.contrib.eager.defun(fn) |
| |
| # Compiling `fn` with `defun` hoists all variables outside of the generated |
| # graph, so initialization happens exactly once. |
| assert compiled().numpy() == 1.0 |
| assert compiled().numpy() == 2.0 |
| ``` |
| |
| Finally, because each input signature is bound to a unique graph, if your |
| Python function constructs `tf.Variable` objects, then each graph constructed |
| for that Python function will reference a unique set of variables. To |
| circumvent this problem, we recommend against compiling Python functions that |
| create `tf.Variable` objects. Instead, Python functions should either |
| lexically close over `tf.Variable` objects or accept them as arguments, |
| preferably encapsulated in an object-oriented container. If you must create |
| variables inside your Python function and you want each graph generated for it |
| to reference the same set of variables, add logic to your Python function that |
| ensures that variables are only created the first time it is called and are |
| reused for every subsequent invocation; note that this is precisely what |
| `tf.keras.layers.Layer` objects do, so we recommend using them to represent |
| variable-bearing computations whenever possible. |
| |
| Args: |
| func: function to be compiled. If `func` is None, returns a |
| decorator that can be invoked with a single argument - `func`. The |
| end result is equivalent to providing all the arguments up front. |
| In other words, defun(input_signature=...)(func) is equivalent to |
| defun(func, input_signature=...). The former allows |
| the following use case: |
| @tf.contrib.eager.defun(input_signature=...) |
| def foo(...): |
| ... |
| |
| input_signature: A possibly nested sequence of |
| `tf.contrib.eager.TensorSpec` objects specifying the shapes and dtypes of |
| the Tensors that will be supplied to this function. If `None`, a separate |
| function is instantiated for each inferred input signature. If a |
| signature is specified, every input to `func` must be a `Tensor`, and |
| `func` cannot accept `**kwargs`. |
| autograph: Whether `func` should be compiled before |
| constructing the graph. See https://www.tensorflow.org/guide/autograph |
| for more information. |
| experimental_autograph_options: Experimental knobs (in the form of a tuple |
| of tensorflow.autograph.Feature values) to control behavior when |
| autograph=True. |
| experimental_relax_shapes: When true, argument shapes may be relaxed to |
      avoid unnecessary retracing.
| |
| Returns: |
| If `func` is not None, returns a callable that will execute the compiled |
| function (and return zero or more `tf.Tensor` objects). |
| If `func` is None, returns a decorator that, when invoked with a single |
| `func` argument, returns a callable equivalent to the case above. |
| |
| Raises: |
| TypeError: If `input_signature` is neither `None` nor a sequence of |
| `tf.contrib.eager.TensorSpec` objects. |
| """ |
| return defun_with_attributes( |
| func=func, |
| input_signature=input_signature, |
| autograph=autograph, |
| experimental_autograph_options=experimental_autograph_options, |
| experimental_relax_shapes=experimental_relax_shapes) |
| |
| |
def defun_with_attributes(func=None,
                          input_signature=None,
                          attributes=None,
                          autograph=True,
                          experimental_autograph_options=None,
                          experimental_relax_shapes=False):
  """Compiles a Python function into a callable TensorFlow graph.

  Behaves like defun() but additionally supports attaching attributes to the
  generated function definition; see defun() for the full documentation. This
  entry point is deliberately kept out of the public API: users are not
  expected to set attributes directly, and an attribute is meaningless by
  itself. That assumption may be revisited in the future.

  Args:
    func: function to be compiled. If `None`, returns a decorator.
    input_signature: same as defun()'s input_signature.
    attributes: A dictionary of arguments which will be added to the function
      def as attributes. Only primitive types are supported as values, and
      only whitelisted attribute names are allowed; anything else results in
      a ValueError. The whitelisted key `func_name` (a Python string) sets
      the name of this `ConcreteFunction` in the graph.
    autograph: same as defun()'s autograph.
    experimental_autograph_options: same as defun()'s
      experimental_autograph_options.
    experimental_relax_shapes: same as defun()'s experimental_relax_shapes.

  Returns:
    Same as the return value of defun, with attributes added to the function
    in graph.
  """
  if input_signature is not None:
    validate_signature(input_signature)

  # TODO(apassos): deal with captured global state. Deal with control flow.
  def decorated(inner_function):
    # Objects without a `__name__` (e.g. functools.partial) fall back to a
    # generic name. Note that `__name__` is looked up before consulting
    # `attributes`, mirroring the eager evaluation of a pop() default.
    try:
      fn_name = inner_function.__name__
    except AttributeError:
      fn_name = "function"
    else:
      if attributes:
        # `func_name` is consumed here so it is not forwarded as a regular
        # function-def attribute.
        fn_name = attributes.pop("func_name", fn_name)
    return tf_decorator.make_decorator(
        inner_function,
        Function(
            inner_function,
            fn_name,
            input_signature=input_signature,
            attributes=attributes,
            autograph=autograph,
            autograph_options=experimental_autograph_options,
            experimental_relax_shapes=experimental_relax_shapes))

  # Direct-call form: `foo = tfe.defun(foo, ...)`.
  if func is not None:
    return decorated(func)

  # Decorator form:
  #
  #   @tfe.defun(...)
  #   def foo(...):
  #     ...
  #
  # which is equivalent to `foo = tfe.defun(...)(foo)`.
  return decorated
| |
| |
| # When a method is bound to objects of this type, it allows AutoGraph to |
| # recover a weak reference the original method's self pointer, so that it can |
| # execute it consistent with class_method_to_instance_method's |
| # bound_method_wrapper. |
| # TODO(b/119246461): This is not pretty. Use a descriptor instead? |
class TfMethodTarget(object):
  """Binding target for methods replaced by function and defun."""

  def __init__(self, target, original_python_function):
    # `target` is expected to already be a weak reference to the instance;
    # it is stored as-is, while the original function is weakly referenced
    # here so this object does not keep it alive.
    self.weakrefself_target__ = target
    self.weakrefself_func__ = weakref.ref(original_python_function)

  @property
  def target(self):
    # Dereference the stored weak reference to recover the bound instance
    # (None once the instance has been garbage collected).
    return self.weakrefself_target__()

  def call(self, args, kwargs):
    # Recover the original function, unbind it if necessary, and invoke it
    # with the dereferenced instance prepended to the positional arguments.
    original_fn = self.weakrefself_func__()
    if tf_inspect.ismethod(original_fn):
      original_fn = six.get_unbound_function(original_fn)
    instance = self.weakrefself_target__()
    return original_fn(instance, *args, **kwargs)
| |
| |
def class_method_to_instance_method(original_function, instance):
  """Constructs a new `Function` with `self` bound.

  Args:
    original_function: a `Function`-like object (from function.py or
      def_function.py) wrapping an unbound method; must expose `_name`,
      `_autograph`, `_function_spec` and `python_function`.
    instance: the object to bind as `self`. Only a weak reference to it is
      retained, so this binding does not keep `instance` alive.

  Returns:
    A `tf_decorator`-wrapped `Function` of the same type as
    `original_function` whose wrapped method receives `instance` as its
    first argument.
  """
  weak_instance = weakref.ref(instance)

  # Note: while we could bind to a weakref proxy instead, that causes the
  # bound method to be unhashable.
  bound_method = types_lib.MethodType(
      original_function.python_function,
      TfMethodTarget(weak_instance, original_function.python_function))

  # original_function is expected to be of one of the two `Function` types
  # (defined either in function.py or def_function.py).
  assert hasattr(original_function, "_name")
  assert hasattr(original_function, "_autograph")
  assert hasattr(original_function, "_function_spec")
  assert hasattr(original_function, "python_function")

  # Forward-declared, then assigned after the def below; presumably a weak
  # reference so the closure does not form a reference cycle with the
  # wrapper function it closes over — TODO confirm.
  weak_bound_method_wrapper = None
  def bound_method_wrapper(*args, **kwargs):
    """Wraps either a dummy MethodType or a converted AutoGraph function."""
    # __wrapped__ allows AutoGraph to swap in a converted function.
    strong_bound_method_wrapper = weak_bound_method_wrapper()
    wrapped_fn = strong_bound_method_wrapper.__wrapped__

    if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__:
      # If __wrapped__ was not replaced, then call original_function.
      # TODO(mdan): For better consistency, use the wrapper's call().
      wrapped_fn = original_function.python_function
      if tf_inspect.ismethod(wrapped_fn):
        wrapped_fn = six.get_unbound_function(wrapped_fn)
      return wrapped_fn(weak_instance(), *args, **kwargs)

    # If __wrapped__ was replaced, then it is always an unbound function.
    # However, the replacer is still responsible for attaching self properly.
    # TODO(mdan): Is it possible to do it here instead?
    return wrapped_fn(*args, **kwargs)
  weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)

  # pylint: disable=protected-access
  # We make a dummy MethodType object to generate the correct bound method
  # signature. The actual call is to a function with a weak reference to
  # `instance`.
  instance_func = type(original_function)(
      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
      name=original_function._name,
      autograph=original_function._autograph,
      input_signature=original_function.input_signature)
  # pylint: enable=protected-access

  # And we wrap the function with tf_decorator so inspection works correctly
  wrapped_instance_func = tf_decorator.make_decorator(
      original_function.python_function, instance_func)
  return wrapped_instance_func
| |
| |
| class _FunctionGarbageCollector(object): |
| """Cleans up cycles when a defun goes out of scope.""" |
| |
| def __init__(self, cache): |
| self._cache = cache |
| |
| def __del__(self): |
| if func_graph_module is None or memory is None: |
| return |
| try: |
| while self._cache: |
| self._cache.popitem() |
| memory.dismantle_ordered_dict(self._cache) |
| except: # pylint: disable=bare-except |
| pass |
| |
| |
class ConcreteFunctionGarbageCollector(object):
  """Cleans up reference cycles when a `ConcreteFunction` goes out of scope."""

  def __init__(self, func_graph):
    # The FuncGraph to dismantle on deletion, unless release() is called.
    self._func_graph = func_graph

  def release(self):
    """Call off the FuncGraph deletion."""
    self._func_graph = None

  def __del__(self):
    # Skip cleanup when release() was called, or when module globals have
    # already been torn down during interpreter shutdown.
    graph = self._func_graph
    if graph is None or func_graph_module is None or memory is None:
      return
    try:
      func_graph_module.dismantle_func_graph(graph)
    except:  # pylint: disable=bare-except
      pass