| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Utility functions for the graph_editor. |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import re |
| from six import iteritems |
| from tensorflow.python.framework import ops as tf_ops |
| from tensorflow.python.ops import array_ops as tf_array_ops |
| |
| __all__ = [ |
| "make_list_of_op", |
| "get_tensors", |
| "make_list_of_t", |
| "get_generating_ops", |
| "get_consuming_ops", |
| "ControlOutputs", |
| "placeholder_name", |
| "make_placeholder_from_tensor", |
| "make_placeholder_from_dtype_and_shape", |
| ] |
| |
| |
# The graph editor sometimes needs to create placeholders. They are named
# "geph_*" ("geph" stands for Graph-Editor PlaceHolder).
| _DEFAULT_PLACEHOLDER_PREFIX = "geph" |
| |
| |
| def concatenate_unique(la, lb): |
| """Add all the elements of `lb` to `la` if they are not there already. |
| |
| The elements added to `la` maintain ordering with respect to `lb`. |
| |
| Args: |
| la: List of Python objects. |
| lb: List of Python objects. |
| Returns: |
    `la`: The list `la` with the missing elements of `lb` appended.
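
  Example:

    la = [1, 2]
    concatenate_unique(la, [2, 3, 1, 4])  # Returns la, now [1, 2, 3, 4].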
| """ |
| la_set = set(la) |
| for l in lb: |
| if l not in la_set: |
| la.append(l) |
| la_set.add(l) |
| return la |
| |
| |
# TODO(fkp): very generic code, it should be moved to a more generic place.
| class ListView(object): |
| """Immutable list wrapper. |
| |
| This class is strongly inspired by the one in tf.Operation. |
| """ |
| |
| def __init__(self, list_): |
| if not isinstance(list_, list): |
| raise TypeError("Expected a list, got: {}.".format(type(list_))) |
| self._list = list_ |
| |
| def __iter__(self): |
| return iter(self._list) |
| |
| def __len__(self): |
| return len(self._list) |
| |
| def __bool__(self): |
| return bool(self._list) |
| |
| # Python 3 wants __bool__, Python 2.7 wants __nonzero__ |
| __nonzero__ = __bool__ |
| |
| def __getitem__(self, i): |
| return self._list[i] |
| |
| def __add__(self, other): |
| if not isinstance(other, list): |
| other = list(other) |
| return list(self) + other |
| |
| |
# TODO(fkp): very generic code, it should be moved to a more generic place.
| def is_iterable(obj): |
| """Return true if the object is iterable.""" |
| if isinstance(obj, tf_ops.Tensor): |
| return False |
| try: |
| _ = iter(obj) |
| except Exception: # pylint: disable=broad-except |
| return False |
| return True |
| |
| |
| def flatten_tree(tree, leaves=None): |
| """Flatten a tree into a list. |
| |
| Args: |
    tree: iterable or not. If iterable, its elements (children) can also be
      iterable or not.
| leaves: list to which the tree leaves are appended (None by default). |
| Returns: |
| A list of all the leaves in the tree. |
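
  Example:

    flatten_tree([[1, 2], {"a": 3}, (4,)])  # -> [1, 2, 3, 4]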
| """ |
| if leaves is None: |
| leaves = [] |
| if isinstance(tree, dict): |
| for _, child in iteritems(tree): |
| flatten_tree(child, leaves) |
| elif is_iterable(tree): |
| for child in tree: |
| flatten_tree(child, leaves) |
| else: |
| leaves.append(tree) |
| return leaves |
| |
| |
| def transform_tree(tree, fn, iterable_type=tuple): |
| """Transform all the nodes of a tree. |
| |
| Args: |
    tree: iterable or not. If iterable, its elements (children) can also be
      iterable or not.
    fn: function to apply to each leaf.
    iterable_type: type used to construct the resulting tree for generic
      (unrecognized) iterables, typically `list` or `tuple`.
  Returns:
    A tree whose leaves have been transformed by `fn`.
    The hierarchy of the output tree mimics the one of the input tree.
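
  Example:

    transform_tree([(1, 2), {"a": 3}], lambda x: x * 2)
    # -> [(2, 4), {"a": 6}]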
| """ |
| if is_iterable(tree): |
| if isinstance(tree, dict): |
| res = tree.__new__(type(tree)) |
| res.__init__( |
| (k, transform_tree(child, fn)) for k, child in iteritems(tree)) |
| return res |
| elif isinstance(tree, tuple): |
| # NamedTuple? |
| if hasattr(tree, "_asdict"): |
| res = tree.__new__(type(tree), **transform_tree(tree._asdict(), fn)) |
| else: |
| res = tree.__new__(type(tree), |
| (transform_tree(child, fn) for child in tree)) |
| return res |
| elif isinstance(tree, collections.Sequence): |
| res = tree.__new__(type(tree)) |
| res.__init__(transform_tree(child, fn) for child in tree) |
| return res |
| else: |
| return iterable_type(transform_tree(child, fn) for child in tree) |
| else: |
| return fn(tree) |
| |
| |
| def check_graphs(*args): |
| """Check that all the element in args belong to the same graph. |
| |
| Args: |
| *args: a list of object with a obj.graph property. |
| Raises: |
| ValueError: if all the elements do not belong to the same graph. |
| """ |
| graph = None |
| for i, sgv in enumerate(args): |
| if graph is None and sgv.graph is not None: |
| graph = sgv.graph |
| elif sgv.graph is not None and sgv.graph is not graph: |
| raise ValueError("Argument[{}]: Wrong graph!".format(i)) |
| |
| |
| def get_unique_graph(tops, check_types=None, none_if_empty=False): |
| """Return the unique graph used by the all the elements in tops. |
| |
| Args: |
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor), or a tf.Graph.
    check_types: check that the elements in tops are of given type(s). If
      None, the types (tf.Operation, tf.Tensor) are used.
| none_if_empty: don't raise an error if tops is an empty list, just return |
| None. |
| Returns: |
| The unique graph used by all the tops. |
| Raises: |
    TypeError: if tops is not an iterable of tf.Operation.
| ValueError: if the graph is not unique. |
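
  Example (a sketch; `a` and `b` are assumed to be tensors created in the
  same default graph):

    g = get_unique_graph([a, b])  # The common tf.Graph.
    get_unique_graph([a, b.op])   # Mixing tensors and ops also works.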
| """ |
| if isinstance(tops, tf_ops.Graph): |
| return tops |
| if not is_iterable(tops): |
| raise TypeError("{} is not iterable".format(type(tops))) |
| if check_types is None: |
| check_types = (tf_ops.Operation, tf_ops.Tensor) |
| elif not is_iterable(check_types): |
| check_types = (check_types,) |
| g = None |
| for op in tops: |
| if not isinstance(op, check_types): |
| raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str( |
| t) for t in check_types]), type(op))) |
| if g is None: |
| g = op.graph |
| elif g is not op.graph: |
| raise ValueError("Operation {} does not belong to given graph".format(op)) |
| if g is None and not none_if_empty: |
| raise ValueError("Can't find the unique graph of an empty list") |
| return g |
| |
| |
| def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False): |
| """Convert ops to a list of `tf.Operation`. |
| |
| Args: |
| ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single |
| operation. |
| check_graph: if `True` check if all the operations belong to the same graph. |
| allow_graph: if `False` a `tf.Graph` cannot be converted. |
| ignore_ts: if True, silently ignore `tf.Tensor`. |
| Returns: |
| A newly created list of `tf.Operation`. |
| Raises: |
| TypeError: if ops cannot be converted to a list of `tf.Operation` or, |
| if `check_graph` is `True`, if all the ops do not belong to the |
| same graph. |
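
  Example (a sketch; `op0` and `op1` are assumed to be `tf.Operation`s of the
  same graph, and `t` one of their output tensors):

    make_list_of_op(op0)                       # -> [op0]
    make_list_of_op([op0, op1])                # -> [op0, op1]
    make_list_of_op([op0, t], ignore_ts=True)  # -> [op0]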
| """ |
| if isinstance(ops, tf_ops.Graph): |
| if allow_graph: |
| return ops.get_operations() |
| else: |
| raise TypeError("allow_graph is False: cannot convert a tf.Graph.") |
| else: |
| if not is_iterable(ops): |
| ops = [ops] |
| if not ops: |
| return [] |
| if check_graph: |
| check_types = None if ignore_ts else tf_ops.Operation |
| get_unique_graph(ops, check_types=check_types) |
| return [op for op in ops if isinstance(op, tf_ops.Operation)] |
| |
| |
| # TODO(fkp): move this function in tf.Graph? |
| def get_tensors(graph): |
| """get all the tensors which are input or output of an op in the graph. |
| |
| Args: |
| graph: a `tf.Graph`. |
| Returns: |
| A list of `tf.Tensor`. |
| Raises: |
| TypeError: if graph is not a `tf.Graph`. |
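
  Example (a sketch; assuming `a` and `b` are constants created, in that
  order, in graph `g`, and `c = tf.add(a, b)`):

    get_tensors(g)  # -> [a, b, c], every op output in creation order.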
| """ |
| if not isinstance(graph, tf_ops.Graph): |
| raise TypeError("Expected a graph, got: {}".format(type(graph))) |
| ts = [] |
| for op in graph.get_operations(): |
| ts += op.outputs |
| return ts |
| |
| |
| def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): |
| """Convert ts to a list of `tf.Tensor`. |
| |
| Args: |
| ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor. |
| check_graph: if `True` check if all the tensors belong to the same graph. |
| allow_graph: if `False` a `tf.Graph` cannot be converted. |
| ignore_ops: if `True`, silently ignore `tf.Operation`. |
| Returns: |
| A newly created list of `tf.Tensor`. |
| Raises: |
| TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or, |
      if `check_graph` is `True`, if all the tensors do not belong to the
      same graph.
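
  Example (a sketch; `t0` and `t1` are assumed to be `tf.Tensor`s of the
  same graph):

    make_list_of_t(t0)                            # -> [t0]
    make_list_of_t([t0, t1])                      # -> [t0, t1]
    make_list_of_t([t0, t1.op], ignore_ops=True)  # -> [t0]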
| """ |
| if isinstance(ts, tf_ops.Graph): |
| if allow_graph: |
| return get_tensors(ts) |
| else: |
| raise TypeError("allow_graph is False: cannot convert a tf.Graph.") |
| else: |
| if not is_iterable(ts): |
| ts = [ts] |
| if not ts: |
| return [] |
| if check_graph: |
| check_types = None if ignore_ops else tf_ops.Tensor |
| get_unique_graph(ts, check_types=check_types) |
| return [t for t in ts if isinstance(t, tf_ops.Tensor)] |
| |
| |
| def get_generating_ops(ts): |
| """Return all the generating ops of the tensors in `ts`. |
| |
| Args: |
| ts: a list of `tf.Tensor` |
| Returns: |
| A list of all the generating `tf.Operation` of the tensors in `ts`. |
| Raises: |
| TypeError: if `ts` cannot be converted to a list of `tf.Tensor`. |
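
  Example:
    If `c = tf.add(a, b)`, `get_generating_ops([c])` returns `[c.op]` (the
    `Add` operation that produces `c`).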
| """ |
| ts = make_list_of_t(ts, allow_graph=False) |
| return [t.op for t in ts] |
| |
| |
| def get_consuming_ops(ts): |
| """Return all the consuming ops of the tensors in ts. |
| |
| Args: |
| ts: a list of `tf.Tensor` |
| Returns: |
| A list of all the consuming `tf.Operation` of the tensors in `ts`. |
| Raises: |
| TypeError: if ts cannot be converted to a list of `tf.Tensor`. |
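
  Example:
    If `c = tf.add(a, b)`, `get_consuming_ops([a])` contains `c.op`; each
    consuming op appears only once in the result.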
| """ |
| ts = make_list_of_t(ts, allow_graph=False) |
| ops = [] |
| for t in ts: |
| for op in t.consumers(): |
| if op not in ops: |
| ops.append(op) |
| return ops |
| |
| |
| class ControlOutputs(object): |
| """The control outputs topology.""" |
| |
| def __init__(self, graph): |
| """Create a dictionary of control-output dependencies. |
| |
| Args: |
| graph: a `tf.Graph`. |
| Returns: |
| A dictionary where a key is a `tf.Operation` instance and the |
| corresponding value is a list of all the ops which have the key |
| as one of their control-input dependencies. |
| Raises: |
| TypeError: graph is not a `tf.Graph`. |
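
    Example (a sketch; assuming op `b` of graph `g` has a control dependency
    on op `a`):

      co = ControlOutputs(g)
      co.get(a)  # -> [b]
      co.get(b)  # -> ()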
| """ |
| if not isinstance(graph, tf_ops.Graph): |
| raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) |
| self._control_outputs = {} |
| self._graph = graph |
| self._version = None |
| self._build() |
| |
| def update(self): |
| """Update the control outputs if the graph has changed.""" |
| if self._version != self._graph.version: |
| self._build() |
| return self |
| |
| def _build(self): |
| """Build the control outputs dictionary.""" |
| self._control_outputs.clear() |
| ops = self._graph.get_operations() |
| for op in ops: |
| for control_input in op.control_inputs: |
| if control_input not in self._control_outputs: |
| self._control_outputs[control_input] = [] |
| if op not in self._control_outputs[control_input]: |
| self._control_outputs[control_input].append(op) |
| self._version = self._graph.version |
| |
| def get_all(self): |
| return self._control_outputs |
| |
| def get(self, op): |
| """return the control outputs of op.""" |
| if op in self._control_outputs: |
| return self._control_outputs[op] |
| else: |
| return () |
| |
| @property |
| def graph(self): |
| return self._graph |
| |
| |
def scope_finalize(scope):
  """Add a trailing "/" to `scope` if it is non-empty and lacks one.

  For example, `scope_finalize("foo")` returns `"foo/"`.
  """
  if scope and scope[-1] != "/":
| scope += "/" |
| return scope |
| |
| |
def scope_dirname(scope):
  """Return the scope up to and including the last "/", or "" if none.

  For example, `scope_dirname("foo/bar")` returns `"foo/"`.
  """
  slash = scope.rfind("/")
| if slash == -1: |
| return "" |
| return scope[:slash + 1] |
| |
| |
def scope_basename(scope):
  """Return the part of `scope` after the last "/" (all of it if none).

  For example, `scope_basename("foo/bar")` returns `"bar"`.
  """
  slash = scope.rfind("/")
| if slash == -1: |
| return scope |
| return scope[slash + 1:] |
| |
| |
| def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): |
| """Create placeholder name for the graph editor. |
| |
| Args: |
    t: optional tensor on which the placeholder operation's name will be
      based.
| scope: absolute scope with which to prefix the placeholder's name. None |
| means that the scope of t is preserved. "" means the root scope. |
| prefix: placeholder name prefix. |
| Returns: |
    A new placeholder name prefixed by "geph". Note that "geph" stands for
    Graph Editor PlaceHolder. This convention makes it easy to identify the
    placeholders generated by the Graph Editor.
| Raises: |
    TypeError: if `t` is neither `None` nor a `tf.Tensor`.
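
  Example (a sketch; `t` is assumed to be a tensor named "foo/bar:0"):

    placeholder_name(t)            # -> "foo/geph__bar_0"
    placeholder_name(t, scope="")  # -> "geph__bar_0"
    placeholder_name()             # -> "geph"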
| """ |
| if scope is not None: |
| scope = scope_finalize(scope) |
| if t is not None: |
| if not isinstance(t, tf_ops.Tensor): |
| raise TypeError("Expected a tf.Tenfor, got: {}".format(type(t))) |
| op_dirname = scope_dirname(t.op.name) |
| op_basename = scope_basename(t.op.name) |
| if scope is None: |
| scope = op_dirname |
| |
| if op_basename.startswith("{}__".format(prefix)): |
| ph_name = op_basename |
| else: |
| ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) |
| |
| return scope + ph_name |
| else: |
| if scope is None: |
| scope = "" |
| return "{}{}".format(scope, prefix) |
| |
| |
| def make_placeholder_from_tensor(t, scope=None, |
| prefix=_DEFAULT_PLACEHOLDER_PREFIX): |
| """Create a `tf.compat.v1.placeholder` for the Graph Editor. |
| |
| Note that the correct graph scope must be set by the calling function. |
| |
| Args: |
| t: a `tf.Tensor` whose name will be used to create the placeholder (see |
| function placeholder_name). |
| scope: absolute scope within which to create the placeholder. None means |
| that the scope of `t` is preserved. `""` means the root scope. |
| prefix: placeholder name prefix. |
| |
| Returns: |
| A newly created `tf.compat.v1.placeholder`. |
| Raises: |
    TypeError: if `t` is neither `None` nor a `tf.Tensor`.
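
  Example (a sketch; `t` is assumed to be a tensor named "foo/bar:0"):

    ph = make_placeholder_from_tensor(t)
    # A placeholder named "foo/geph__bar_0" with the dtype and shape of `t`.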
| """ |
| return tf_array_ops.placeholder( |
| dtype=t.dtype, shape=t.get_shape(), |
| name=placeholder_name(t, scope=scope, prefix=prefix)) |
| |
| |
| def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, |
| prefix=_DEFAULT_PLACEHOLDER_PREFIX): |
| """Create a tf.compat.v1.placeholder for the Graph Editor. |
| |
| Note that the correct graph scope must be set by the calling function. |
| The placeholder is named using the function placeholder_name (with no |
| tensor argument). |
| |
| Args: |
| dtype: the tensor type. |
| shape: the tensor shape (optional). |
    scope: absolute scope within which to create the placeholder. None is
      equivalent to `""` (the root scope).
| prefix: placeholder name prefix. |
| |
| Returns: |
    A newly created `tf.compat.v1.placeholder`.
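
  Example (a sketch):

    ph = make_placeholder_from_dtype_and_shape(tf.float32, shape=[None, 3])
    # A placeholder named "geph" (uniquified by the graph if already taken).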
| """ |
| return tf_array_ops.placeholder( |
| dtype=dtype, shape=shape, |
| name=placeholder_name(scope=scope, prefix=prefix)) |
| |
| |
| _INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") |
| |
| |
| def get_predefined_collection_names(): |
| """Return all the predefined collection names.""" |
| return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys) |
| if not _INTERNAL_VARIABLE_RE.match(key)] |
| |
| |
| def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""): |
| """Find corresponding op/tensor in a different graph. |
| |
| Args: |
| target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph. |
| dst_graph: The graph in which the corresponding graph element must be found. |
| dst_scope: A scope which is prepended to the name to look for. |
| src_scope: A scope which is removed from the original of `target` name. |
| |
| Returns: |
    The corresponding `tf.Tensor` or `tf.Operation`.
| |
| Raises: |
    ValueError: if the name of `target` does not start with `src_scope`.
| TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation` |
| KeyError: If the corresponding graph element cannot be found. |
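
  Example (a sketch; `t` is assumed to be a tensor named "foo/bar:0" in
  graph `g0`, and `g1` a copy of `g0`):

    t1 = find_corresponding_elem(t, g1)
    # The tensor named "foo/bar:0" in `g1`.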
| """ |
| src_name = target.name |
| if src_scope: |
| src_scope = scope_finalize(src_scope) |
    if not src_name.startswith(src_scope):
| raise ValueError("{} does not start with {}".format(src_name, src_scope)) |
| src_name = src_name[len(src_scope):] |
| |
| dst_name = src_name |
| if dst_scope: |
| dst_scope = scope_finalize(dst_scope) |
| dst_name = dst_scope + dst_name |
| |
| if isinstance(target, tf_ops.Tensor): |
| return dst_graph.get_tensor_by_name(dst_name) |
| if isinstance(target, tf_ops.Operation): |
| return dst_graph.get_operation_by_name(dst_name) |
| raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target)) |
| |
| |
| def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""): |
| """Find corresponding ops/tensors in a different graph. |
| |
  `targets` is a Python tree, that is, a nested structure of iterables
  (list, tuple, dictionary) whose leaves are instances of `tf.Tensor` or
  `tf.Operation`.
| |
| Args: |
| targets: A Python tree containing `tf.Tensor` or `tf.Operation` |
| belonging to the original graph. |
| dst_graph: The graph in which the corresponding graph element must be found. |
| dst_scope: A scope which is prepended to the name to look for. |
    src_scope: A scope which is removed from the original name of each
      element of `targets`.
| |
| Returns: |
    A Python tree containing the corresponding `tf.Tensor`s or
    `tf.Operation`s.
| |
| Raises: |
    ValueError: if the name of a leaf does not start with `src_scope`.
    TypeError: if a leaf is not a `tf.Tensor` or a `tf.Operation`.
| KeyError: If the corresponding graph element cannot be found. |
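
  Example (a sketch; `t0` and `t1` are tensors of graph `g0`, and `g1` is a
  copy of `g0`):

    find_corresponding({"x": t0, "y": (t1,)}, g1)
    # -> {"x": <t0 in g1>, "y": (<t1 in g1>,)}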
| """ |
| def func(top): |
| return find_corresponding_elem(top, dst_graph, dst_scope, src_scope) |
| return transform_tree(targets, func) |