| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """FuncGraph and related functionality.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections as py_collections |
| import itertools |
| import weakref |
| |
| from tensorflow.core.framework import attr_value_pb2 |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import execute |
| from tensorflow.python.eager import tape |
| from tensorflow.python.eager.graph_only_ops import graph_placeholder |
| from tensorflow.python.framework import composite_tensor |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.framework.auto_control_deps import AutomaticControlDependencies |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import tensor_array_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import memory |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_contextlib |
| from tensorflow.python.util import tf_decorator |
| from tensorflow.python.util.lazy_loader import LazyLoader |
| |
# This is to avoid a circular dependency:
#   function -> func_graph
# `function` depends on this module, so it is bound here through a LazyLoader
# instead of a regular top-level import.
# NOTE(review): `def_function` is presumably loaded lazily for the same
# reason -- confirm against its imports.
function = LazyLoader("function", globals(),
                      "tensorflow.python.eager.function")
def_function = LazyLoader(
    "def_function", globals(),
    "tensorflow.python.eager.def_function")
| |
# Collections that a FuncGraph shares with its outer graph when it is built
# without an explicit `collections` dict: for each key listed here,
# FuncGraph.__init__ stores the outer graph's `get_collection_ref(...)` result,
# whereas every other outer collection is stored via `get_collection(...)`.
WHITELIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]
| |
| |
class UnknownArgument(object):
  """Marker for an argument whose type is not currently handled."""
| |
| |
def convert_structure_to_signature(structure, arg_names=None):
  """Builds a signature mirroring a potentially nested argument structure.

  Args:
    structure: Structure to convert, where the top level collection is a list
      or a tuple.
    arg_names: Optional list with as many elements as `structure`. When given,
      each name replaces the leading (top-level) component of the flattened
      paths, which are used for naming the corresponding TensorSpecs.

  Returns:
    A structure identical to `structure` in which Tensors are replaced by
    TensorSpec objects, CompositeTensors by their type spec, supported plain
    values (int, float, bool, None, DType, TensorSpec) are kept unchanged and
    anything else is replaced by an UnknownArgument instance.

  Raises:
    ValueError: If `arg_names` is non-empty and its length differs from the
      length of `structure`.
  """
  passthrough_types = (
      int,
      float,
      bool,
      type(None),
      dtypes.DType,
      tensor_spec.TensorSpec,
  )

  def _spec_name(tensor, path):
    """Chooses the TensorSpec name for `tensor` found at `path`."""
    try:
      user_name = compat.as_str(tensor.op.get_attr("_user_specified_name"))
    except ValueError:
      user_name = None

    if path and user_name and user_name != path[0]:
      # The user has explicitly named the argument differently than the name
      # of the function argument.
      return user_name
    return "/".join(str(component) for component in path)

  def _encode(path, value):
    """Returns the signature representation of one flattened leaf."""
    if isinstance(value, ops.Tensor):
      return tensor_spec.TensorSpec(
          value.shape, value.dtype, _spec_name(value, path))
    if isinstance(value, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return value._type_spec  # pylint: disable=protected-access
    if isinstance(value, passthrough_types):
      return value
    return UnknownArgument()

  # The flattened paths name the TensorSpecs; an explicit name is needed for
  # them downstream.
  paths_and_values = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Swap every top-level index for its argument name, e.g. a path
    # "(2,'a',1)" becomes "(arg_names[2],'a',1)".
    paths_and_values = [
        ((arg_names[path[0]],) + path[1:], value)
        for path, value in paths_and_values
    ]

  encoded = [_encode(path, value) for path, value in paths_and_values]
  return nest.pack_sequence_as(structure, encoded)
| |
| |
class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well as
      captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are in
      this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function. Note
      that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be returned
      by this function. The Tensors in this structure are the same as those of
      self.outputs. Note that this structure might contain Python `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another FuncGraph
      or the global default Graph.
    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
      The entries are in the order they were captured.
    deferred_captures: Maps arbitrary key -> (closure, nest of placeholders),
      where at function call time the value of closure() will be used to feed
      the nest of placeholders.
    control_captures: Set of external ops on which this graph has a control
      dependency.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
  """

  def __init__(self, name, collections=None, capture_by_value=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not whitelisted, and both
        read and write to the outer graph's collections that are whitelisted.
        The current whitelisted collections are the global variables, the
        local variables, and the trainable variables.
        Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = weakref.WeakSet()
    self.outer_graph = ops.get_default_graph()
    self.captures = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    self.deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif isinstance(self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    # [for tf-data user migration from TF1.0 to 2.0] seed_used keep track of
    # any None op_seed for random_op in the function, in which case we end up
    # using function seed, which could be unintended behavior for the op.
    self._seed_used = False
    if context.executing_eagerly():
      self.seed = context.global_seed()
    else:
      self.seed = graph.seed
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      # Non-whitelisted collections are taken via get_collection(); the
      # whitelisted ones via get_collection_ref() so writes reach the outer
      # graph as well.
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in WHITELIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in WHITELIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

  def __str__(self):
    """Returns a short description containing the name and object id."""
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph.

    The variable is also marked on every enclosing FuncGraph reachable through
    `outer_graph`; the walk stops at the first graph that is not a FuncGraph.

    Args:
      v: The variable to record.
    """
    graph = self
    while isinstance(graph, FuncGraph):
      graph._watched_variables.add(v)  # pylint: disable=protected-access
      graph = graph.outer_graph

  def capture_call_time_value(self, closure, spec, key=None):
    """Creates a placeholder which at call time has the value closure().

    Useful, for example, to respect TensorFlow context managers, which are often
    dynamically scoped.

    Args:
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to capture_call_time_value
        with the same key in the same graph will return the same placeholder,
        and the first closure will be used at function call time.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      TypeError: if a leaf of `spec` (after expanding composites) is not a
        `TensorSpec`.
      ValueError: at function call time, if the return value of closure() is
        not compatible with `spec`.
    """
    if key is None:
      # A fresh object never collides with an existing key, so a new
      # placeholder nest is always created.
      key = object()
    if key not in self.deferred_captures:

      def convert_to_placeholder(s):
        if not isinstance(s, tensor_spec.TensorSpec):
          raise TypeError(
              "Expected a nest of `TypeSpec` objects, found %s of type %s." %
              (s, type(s)))
        return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

      placeholder = nest.map_structure(
          convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():
        ret_nest = closure()
        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(lambda s, r: s._to_components(r), spec, ret_nest,
                               expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      self.deferred_captures[key] = (wrapped_closure, placeholder)
    return self.deferred_captures[key][1]

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which
        must be executed or computed before running the operations
        defined in the context. Can also be `None` to clear the control
        dependencies.

    Returns:
     A context manager that specifies control dependencies for all
     operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable
      if (isinstance(c, ops.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      # Elements belonging to another graph are recorded as control captures
      # instead of being forwarded to the base implementation.
      if graph_element is not None and getattr(
          graph_element, "graph", None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)

  def as_default(self):
    """Returns a context manager that makes this graph the default graph.

    On entry, this graph takes on the distribution strategy stack, variable
    creator stack, graph key and auto-cast variable read dtype of the graph
    that is the default at that moment (and, in the cases handled below, a
    device function stack). The previous values of these attributes are
    restored when the context exits.
    """
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)
      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's placement,
      # so that the same trace of the Python function may be placed on several
      # different devices and saved functions may be placed on new devices when
      # restored.
      old_device_stack = self._device_function_stack
      if context.executing_eagerly():
        if self._distribution_strategy_stack:
          self._device_function_stack = self._device_function_stack.copy()
          self._add_device_to_stack(context.context().device_name)
      elif (self._distribution_strategy_stack
            or device_stack_has_callable(graph._device_function_stack)):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # Inherit the auto_cast_variable_read_dtype, since this should not change
      # inside a function.
      old_auto_cast_var_read_dtype = self._auto_cast_variable_read_dtype
      self._auto_cast_variable_read_dtype = graph._auto_cast_variable_read_dtype
      # pylint: enable=protected-access

      with outer_cm as g:
        try:
          yield g
        finally:
          self._distribution_strategy_stack = old_strategy_stack
          self._device_function_stack = old_device_stack
          self._variable_creator_stack = old_creator_stack
          self._graph_key = old_graph_key
          self._auto_cast_variable_read_dtype = old_auto_cast_var_read_dtype
    return inner_cm()

  @property
  def output_types(self):
    """The dtype of each tensor in `self.outputs`, in order."""
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    """The shape of each tensor in `self.outputs`, in order."""
    return [t.shape for t in self.outputs]

  @property
  def variables(self):
    """Variables accessed by this FuncGraph.

    This property is a generator: each access returns a fresh iterator over the
    stored variables rather than a list.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Yields:
      Strong references to variables accessed by this FuncGraph.

    Raises:
      AssertionError: while iterating, if a referenced variable has already
        been deleted.
    """
    for weak_v in self._weak_variables:
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      yield v

  @variables.setter
  def variables(self, var_list):
    # Only weak references are kept, so this graph does not keep the variables
    # alive.
    self._weak_variables = [weakref.ref(v) for v in var_list]

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Runs a read op outside this graph and captures its result.

    Inputs that are internal capture placeholders are first mapped back to the
    external tensors they stand for. The op is then executed (eagerly) or
    created (in the default graph) inside `ops.init_scope()`, and its first
    output is captured into this graph.

    Args:
      op_type: The type of the op to run.
      inputs: Input tensors of the op, possibly capture placeholders.
      dtypes: Output dtypes, forwarded when the op is created in a graph.
      input_types: Forwarded when the op is created in a graph.
      name: Forwarded when the op is created in a graph.
      attrs: Op attributes; `attrs["dtype"]` is read when executing eagerly.
      op_def: Forwarded when the op is created in a graph.
      compute_device: Forwarded when the op is created in a graph.

    Returns:
      The op that produces the captured value inside this graph.
    """
    # When capturing by value, do the read outside
    reverse_captures = {
        internal: external for external, internal in self.captures.items()}
    uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type,
            uncaptured_inputs,
            dtypes,
            input_types,
            name,
            attrs,
            op_def,
            compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def create_op(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_shapes=True,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function graphs
    if this is a nested function. See `capture` for more information.

    The `inputs` sequence passed by the caller is not modified; captured
    replacements are collected in a separate list.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of the
        tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
        computed).
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    del compute_shapes
    if self.capture_by_value and op_type in ("ReadVariableOp",
                                             "ResourceGather"):
      return self._capture_by_value(op_type, inputs, dtypes, input_types, name,
                                    attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse loop
    # validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create placeholders
    # to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # TPU Estimator defines a control flow context with no AddValue method.
    use_add_value = ctxt is not None and hasattr(ctxt, "AddValue")
    # Use a different list to avoid modifying the caller's `inputs`.
    captured_inputs = []
    for inp in inputs:
      if use_add_value:
        inp = ctxt.AddValue(inp)
      captured_inputs.append(self.capture(inp))
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs.  Multiple calls to this method with
    the same `tensor` argument will return the same placeholder. If `tensor` is
    from this graph, returns `tensor`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      ValueError: If `tensor` belongs to a FuncGraph nested inside this one,
        i.e. this graph is found while walking the `outer_graph` chain that
        starts at the tensor's graph.
    """
    # Note: _forward_func_graph is currently only set when building the gradient
    # graph of a defun call. If the backwards graph tries to capture
    # tensors those will be captured first in the forward graph. This
    # makes sure that any tensor needed by a custom_gradient is correctly
    # captured.

    # TODO(b/134097853): figure out a better way to check distributed variables
    if hasattr(tensor, "_distribute_strategy") and hasattr(tensor, "_values"):
      # This checks if the 'tensor' is a DistributedVariable. When it is a
      # DistributedVariable, we do not want to check its "graph" attr as the
      # following if branch does, because "graph" is not an attr for the
      # container DistributedVariable object, and the underlying components may
      # not have been initialized yet.
      # The reason we do not use isinstance() is due to cyclic dependency issue.
      if name is None:
        name = str("distributed_variable")
      return self._capture_helper(tensor, name)
    if (getattr(tensor, "graph", None) is not self and
        hasattr(self, "_forward_func_graph") and
        isinstance(self._forward_func_graph, FuncGraph)):
      tensor = self._forward_func_graph.capture(tensor)
    if isinstance(tensor, ops.EagerTensor):
      if name is None:
        name = str(ops.uid())
      return self._capture_helper(tensor, name)
    if tensor.graph is not self:
      if name is None:
        name = tensor.op.name
      # Walk outwards from the tensor's graph; meeting `self` on the way means
      # the tensor lives in a function nested inside this one.
      inner_graph = tensor.graph
      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
        if inner_graph is self:
          raise ValueError(
              "Trying to capture a tensor from an inner function. This can be "
              "caused by accessing a tensor defined inside a loop or "
              "conditional body, or a subfunction, from a calling function, "
              "without going through the proper return value mechanism. "
              "Consider using TensorFlow mechanisms such as TensorArrays "
              "to return tensors from inner functions or loop / conditional "
              "bodies. Tensor: %s; tensor graph: %s; this graph: %s"
              % (tensor, tensor.graph, self))
        inner_graph = inner_graph.outer_graph
      return self._capture_helper(tensor, name)
    return tensor

  def _capture_helper(self, tensor, name):
    """Returns the placeholder for `tensor`, creating it on first capture.

    A newly created placeholder is registered in `self.captures` and appended
    to `self.inputs`. A "captured_value" operation linking the placeholder to
    `tensor` is recorded on the tape on every call.

    Args:
      tensor: The external value to capture.
      name: Name used if a new placeholder is created.

    Returns:
      The placeholder in this graph standing in for `tensor`.
    """
    captured_tensor = self.captures.get(tensor)
    if captured_tensor is None:
      captured_tensor = _create_substitute_placeholder(tensor, name=name,
                                                       dtype=tensor.dtype)
      self.captures[tensor] = captured_tensor
      self.inputs.append(captured_tensor)
    tape.record_operation("captured_value", [captured_tensor], [tensor],
                          lambda x: [x])
    return captured_tensor

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return list(self.captures.keys())

  @property
  def internal_captures(self):
    """Placeholders in this function corresponding captured tensors."""
    return list(self.captures.values())
| |
| |
| def func_graph_from_py_func(name, |
| python_func, |
| args, |
| kwargs, |
| signature=None, |
| func_graph=None, |
| autograph=False, |
| autograph_options=None, |
| add_control_dependencies=True, |
| arg_names=None, |
| op_return_value=None, |
| collections=None, |
| capture_by_value=None, |
| override_flat_arg_shapes=None): |
| """Returns a `FuncGraph` generated from `python_func`. |
| |
| Args: |
| name: an identifier for the function. |
| python_func: the Python function to trace. |
| args: the positional args with which the Python function should be called; |
| ignored if a signature is provided. |
| kwargs: the keyword args with which the Python function should be called; |
| ignored if a signature is provided. |
| signature: a possibly nested sequence of `TensorSpecs` specifying the shapes |
| and dtypes of the arguments. When a signature is provided, `args` and |
| `kwargs` are ignored, and `python_func` is traced with Tensors conforming |
| to `signature`. If `None`, the shapes and dtypes are inferred from the |
| inputs. |
| func_graph: Optional. An instance of FuncGraph. If provided, we will use |
| this graph else a new one is built and returned. |
| autograph: whether to use autograph to compile `python_func`. |
| See https://www.tensorflow.org/guide/autograph for more information. |
| autograph_options: additional knobs to control when `autograph=True`. |
| See https://www.tensorflow.org/guide/autograph for more information. |
| add_control_dependencies: If True, automatically adds control dependencies |
| to ensure program order matches execution order and stateful ops always |
| execute. |
| arg_names: Optional list of argument names, used to give input placeholders |
| recognizable names. |
| op_return_value: Optional. A Tensor. If set and `python_func` returns |
| Operations, those return values will be replaced with this value. If not |
| set, returning an Operation triggers an error. |
| collections: a dictionary of collections this FuncGraph should start |
| with. If not specified (None), the FuncGraph will read (but not write to) |
| the outer graph's collections that are not whitelisted, and both |
| read and write to the outer graph's collections that are whitelisted. |
| The current whitelisted collections are the global variables, the |
| local variables, and the trainable variables. |
| Defaults to None. |
| capture_by_value: An optional boolean. If True, the func graph will capture |
| Variables by value instead of reference. By default inherit from outer |
| graphs, and failing that will default to False. |
| override_flat_arg_shapes: An optional list of instances that are either |
| `None` or `TensorShape`. The length must match that of |
| `nest.flatten((args, kwargs), expand_composites=True)`. The entries |
| containing value `None` must match entries in flattened arguments |
| containing non-tensors, while entries containing a `TensorShape` must |
| match entries in the flattened arguments containing tensors. |
| |
| Returns: |
| A FuncGraph. |
| |
| Raises: |
| TypeError: If any of `python_func`'s return values is neither `None` nor a |
| `Tensor`. |
| ValueError: If both `signature` and `override_flat_arg_shapes` are |
| passed in. |
| """ |
| if op_return_value is not None: |
| assert isinstance(op_return_value, ops.Tensor), op_return_value |
| if func_graph is None: |
| func_graph = FuncGraph(name, collections=collections, |
| capture_by_value=capture_by_value) |
| assert isinstance(func_graph, FuncGraph) |
| if add_control_dependencies: |
| control_manager = AutomaticControlDependencies() |
| else: |
| control_manager = ops.NullContextmanager() |
| with func_graph.as_default(), control_manager as a: |
| current_scope = variable_scope.get_variable_scope() |
| default_use_recource = current_scope.use_resource |
| current_scope.set_use_resource(True) |
| |
| if signature is not None and override_flat_arg_shapes is not None: |
| raise ValueError( |
| "Passed both signature and override_flat_arg_shapes: %s and %s." |
| % (signature, override_flat_arg_shapes)) |
| |
| if signature is not None: |
| args = signature |
| kwargs = {} |
| |
| # Creates and names placeholders for all arguments. |
| if override_flat_arg_shapes is not None: |
| flat_args = nest.flatten(args, expand_composites=True) |
| arg_shapes = override_flat_arg_shapes[:len(flat_args)] |
| kwarg_shapes = override_flat_arg_shapes[len(flat_args):] |
| else: |
| arg_shapes = None |
| kwarg_shapes = None |
| func_args = _get_defun_inputs_from_args( |
| args, arg_names, flat_shapes=arg_shapes) |
| func_kwargs = _get_defun_inputs_from_kwargs( |
| kwargs, flat_shapes=kwarg_shapes) |
| |
| # Convert all Tensors into TensorSpecs before saving the structured inputs. |
| # If storing pure concrete functions that are not called through polymorphic |
| # functions, we don't have access to FunctionSpec, so we need to call the |
| # TensorSpecs by their `arg_names` for later binding. |
| func_graph.structured_input_signature = ( |
| convert_structure_to_signature(func_args, arg_names), |
| convert_structure_to_signature(func_kwargs)) |
| |
| flat_func_args = nest.flatten(func_args, expand_composites=True) |
| flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True) |
| # Temporarily set inputs to allow graph building code to inspect |
| # them. Reassigned below. |
| func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs |
| if isinstance(arg, ops.Tensor)] |
| |
| # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. |
| # Variables to help check whether mutation happens in calling the function |
| # Copy the recursive list, tuple and map structure, but not base objects |
| func_args_before = nest.pack_sequence_as(func_args, flat_func_args, |
| expand_composites=True) |
| func_kwargs_before = nest.pack_sequence_as( |
| func_kwargs, flat_func_kwargs, expand_composites=True) |
| |
| def convert(x): |
| """Converts a function output to a Tensor.""" |
| if x is None: |
| return None |
| if op_return_value is not None and isinstance(x, ops.Operation): |
| # TODO(b/79881896): we currently can't capture external control deps, so |
| # this won't work if x needs to be captured (i.e. if python_func returns |
| # captured Operations). |
| with ops.control_dependencies([x]): |
| x = array_ops.identity(op_return_value) |
| elif not isinstance(x, tensor_array_ops.TensorArray): |
| try: |
| x = ops.convert_to_tensor_or_composite(x) |
| except (ValueError, TypeError): |
| raise TypeError( |
| "To be compatible with tf.contrib.eager.defun, Python functions " |
| "must return zero or more Tensors; in compilation of %s, found " |
| "return value of type %s, which is not a Tensor." % |
| (str(python_func), type(x))) |
| if add_control_dependencies: |
| x = a.mark_as_return(x) |
| return x |
| |
| try: |
| if autograph: |
| from tensorflow.python import autograph # pylint: disable=g-import-not-at-top |
| _, original_func = tf_decorator.unwrap(python_func) |
| |
| def wrapper(*args, **kwargs): |
| """Calls a converted version of original_func.""" |
| # TODO(mdan): Push this block higher in tf.function's call stack. |
| try: |
| return autograph.converted_call( |
| original_func, |
| autograph.ConversionOptions( |
| recursive=True, |
| optional_features=autograph_options, |
| force_conversion=True, |
| ), args, kwargs) |
| except Exception as e: # pylint:disable=broad-except |
| if hasattr(e, "ag_error_metadata"): |
| raise e.ag_error_metadata.to_exception(type(e)) |
| else: |
| raise |
| |
| # Wrapping around a decorator allows checks like tf_inspect.getargspec |
| # to be accurate. |
| converted_func = tf_decorator.make_decorator(original_func, wrapper) |
| python_func = tf_decorator.rewrap(python_func, original_func, |
| converted_func) |
| |
| func_outputs = python_func(*func_args, **func_kwargs) |
| |
| # invariant: `func_outputs` contains only Tensors, CompositeTensors, |
| # TensorArrays and `None`s. |
| func_outputs = nest.map_structure(convert, func_outputs, |
| expand_composites=True) |
| |
| check_mutation(func_args_before, func_args) |
| check_mutation(func_kwargs_before, func_kwargs) |
| finally: |
| current_scope.set_use_resource(default_use_recource) |
| |
| # Variables in `func_args`, `func_kwargs` should be explicit inputs |
| # to the function, not captured inputs. |
| graph_variables = list(func_graph._watched_variables) # pylint: disable=protected-access |
| arg_variables = set() |
| inputs = [] |
| for arg in (nest.flatten(func_args, expand_composites=True) + |
| nest.flatten(func_kwargs, expand_composites=True)): |
| if isinstance(arg, resource_variable_ops.BaseResourceVariable): |
| # Even if an argument variable was not used in the function, we've |
| # already manually captured the resource Tensor when creating argument |
| # placeholders. |
| resource_placeholder = func_graph.captures.pop(arg.handle, None) |
| if resource_placeholder is None: |
| continue |
| arg_variables.add(arg) |
| inputs.append(resource_placeholder) |
| elif isinstance(arg, ops.Tensor): |
| inputs.append(arg) |
| variables = [v for v in graph_variables if v not in arg_variables] |
| func_graph.inputs = ( |
| inputs + |
| list(func_graph.captures.values()) + |
| nest.flatten( |
| [x[1] for x in func_graph.deferred_captures.values()], |
| expand_composites=True)) |
| |
| func_graph.structured_outputs = func_outputs |
| # Returning a closed-over tensor does not trigger convert_to_tensor. |
| func_graph.outputs.extend( |
| func_graph.capture(x) |
| for x in flatten(func_graph.structured_outputs) |
| if x is not None) |
| |
| func_graph.variables = variables |
| |
| if add_control_dependencies: |
| func_graph.control_outputs.extend(control_manager.ops_which_must_run) |
| |
| return func_graph |
| |
| |
def maybe_captured(tensor):
  """If `tensor` is a captured value placeholder, returns the captured value.

  The lookup is applied recursively, so a placeholder whose captured value is
  itself a capture placeholder of another function graph is resolved through
  every level.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph. `tensor` itself is
    returned when it is an `EagerTensor`, when its graph is not building a
    function, when its op is not a "Placeholder", or when no entry of the
    graph's `captures` maps to it.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures.items():
      # Compare by identity: we are looking for this exact placeholder object,
      # and `==` on Tensors is not guaranteed to be a plain identity check.
      if tensor is placeholder_t:
        return maybe_captured(input_t)
  return tensor
| |
| |
def device_stack_has_callable(device_stack):
  """Returns True if any entry of `device_stack` holds a callable device spec.

  Args:
    device_stack: An object whose `peek_objs()` yields entries exposing a
      `_device_name_or_function` attribute.

  Returns:
    `True` if at least one entry's `_device_name_or_function` is callable,
    `False` otherwise (including for an empty stack).
  """
  for entry in device_stack.peek_objs():
    if callable(entry._device_name_or_function):  # pylint: disable=protected-access
      return True
  return False
| |
| |
def check_mutation(n1, n2):
  """Verifies that two nested argument structures are leaf-for-leaf identical.

  Args:
    n1: Nested structure of arguments (e.g. a snapshot taken before tracing).
    n2: Nested structure of arguments to compare against `n1`.

  Raises:
    ValueError: if `nest.assert_same_structure` reports a `ValueError` for the
      two structures, or if any pair of corresponding flattened leaves is not
      the very same object.
  """
  message = ("Function to be traced should not modify structure of input "
             "arguments. Check if your function has list and dictionary "
             "operations that alter input arguments, "
             "such as `list.pop`, `list.append`")
  try:
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    raise ValueError(message)

  leaves_before = nest.flatten(n1, expand_composites=True)
  leaves_after = nest.flatten(n2, expand_composites=True)
  # Leaves must be the same objects, not merely equal ones.
  if any(before is not after
         for before, after in zip(leaves_before, leaves_after)):
    raise ValueError(message)
| |
| |
| # TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this. |
def flatten(sequence):
  """Flattens `sequence`, replacing each TensorArray with its flow tensor.

  Behaves like `nest.flatten(sequence, expand_composites=True)`, except that
  any `TensorArray` leaf is substituted by its `flow` attribute.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
  """
  def _flow_or_self(leaf):
    if isinstance(leaf, tensor_array_ops.TensorArray):
      return leaf.flow
    return leaf

  leaves = nest.flatten(sequence, expand_composites=True)
  return [_flow_or_self(leaf) for leaf in leaves]
| |
| |
| # TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this. |
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Wherever the flattened `structure` holds a `TensorArray`, the matching entry
  of `flat_sequence` is treated as a flow and wrapped into a new TensorArray
  (via `build_ta_with_new_flow`) before packing.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors. It is copied into a list,
      so the caller's object is not modified.

  Returns:
    A nested structure.

  Raises:
    ValueError: if the number of elements in `flat_sequence` differs from the
      number of leaves in `structure` (flattened with composites expanded).
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i, template in enumerate(flattened_structure):
    if isinstance(template, tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=template, flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence, expand_composites=True)
| |
| |
def _create_substitute_placeholder(value, name=None, dtype=None):
  """Builds a placeholder standing in for `value`, copying its handle data.

  Args:
    value: The value to substitute; its `shape` (and `dtype`, unless
      overridden) are used for the placeholder.
    name: Optional name for the placeholder.
    dtype: Optional dtype override; when falsy, `value.dtype` is used.

  Returns:
    The newly created placeholder.
  """
  # Clearing control dependencies keeps capturing placeholders outside of any
  # control flow context.
  with ops.control_dependencies(None):
    substitute_dtype = dtype or value.dtype
    substitute = graph_placeholder(
        dtype=substitute_dtype, shape=value.shape, name=name)
  custom_gradient.copy_handle_data(value, substitute)
  return substitute
| |
| |
def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Builds graph-construction inputs for positional arguments.

  Args:
    args: The positional arguments.
    names: Names for `args`, or `None`.
    flat_shapes: Optional flat list of shape overrides, forwarded unchanged.

  Returns:
    The result of `_get_defun_inputs`, packed with the structure of `args`.
  """
  # Positional arguments double as their own packing structure.
  return _get_defun_inputs(
      args=args, names=names, structure=args, flat_shapes=flat_shapes)
| |
| |
def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Every `Tensor`/`TensorSpec` leaf is replaced by a fresh placeholder in the
  default graph; every resource variable leaf is captured (creating a
  placeholder for its handle) but kept as the variable itself in the result;
  all other leaves are passed through unchanged.

  Args:
    args: A list of user-specified arguments. Each element may itself be a
      nested structure; it is flattened with composites expanded.
    names: A list of strings with user-specified argument names, same length as
      `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or
      instances of `TensorShape`.  If provided, then length must match
      that of `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding `TensorShape`
      in `flat_shapes`.  May be `None`, in which case exact shapes are read
      directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
     `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None
     for an argument that is not a `Tensor`, `TensorSpec`,
     or `ResourceVariable`.
    ValueError: if, within a single argument, only some of the `TensorSpec`s
      carry a name.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d).  args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):
    flattened = nest.flatten(arg_value, expand_composites=True)
    tensor_specs = [
        arg for arg in flattened if isinstance(arg, tensor_spec.TensorSpec)
    ]
    specified_names = [arg.name for arg in tensor_specs if arg.name]
    if specified_names and len(specified_names) < len(tensor_specs):
      raise ValueError("If specifying TensorSpec names for nested structures, "
                       "either zero or all names have to be specified.")

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not.  For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        # A name set on the TensorSpec takes precedence over the arg name.
        if isinstance(arg, tensor_spec.TensorSpec) and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if name is not None:
          # Record the requested/user-specified name in case it's different than
          # the uniquified name, for validation when exporting signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Capture arg variables to create placeholders for them. These will be
        # removed as captures after the function is traced (since otherwise we'd
        # just add it back with a new placeholder when the variable was
        # referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        if name is not None:
          # `names` may be None (see Args), in which case there is no
          # user-specified name to record and `compat.as_bytes(None)` would
          # fail; mirror the guard used for Tensor arguments above.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'.  args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)
| |
| |
def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Builds graph-construction inputs for keyword arguments.

  Keyword arguments are processed in sorted key order.

  Args:
    kwargs: Dictionary of keyword arguments; may be empty or `None`.
    flat_shapes: Flat list of shape overrides (or `None`), forwarded
      unchanged.

  Returns:
    The result of `_get_defun_inputs`, packed with the structure of `kwargs`.
  """
  if not kwargs:
    names, args = [], []
  else:
    ordered_items = sorted(kwargs.items())
    names = tuple(key for key, _ in ordered_items)
    args = tuple(value for _, value in ordered_items)
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)
| |
| |
def dismantle_func_graph(func_graph):
  """Breaks the reference cycles held by a `FuncGraph`.

  Doing this up front means the garbage collector does not have to run once
  the FuncGraph goes out of scope, which matters e.g. for tests that use defun
  under
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
  # Entries are popped one at a time because clear() leaves some cycles
  # around.
  for capture_dict in (func_graph.captures, func_graph.deferred_captures):
    while capture_dict:
      capture_dict.popitem()
    memory.dismantle_ordered_dict(capture_dict)
  ops.dismantle_graph(func_graph)