blob: 4b53d182f342100aeaafdc6a25de49c15ddac700 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for the graph_editor.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from six import iteritems
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops as tf_array_ops
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "make_list_of_op",
    "get_tensors",
    "make_list_of_t",
    "get_generating_ops",
    "get_consuming_ops",
    "ControlOutputs",
    "placeholder_name",
    "make_placeholder_from_tensor",
    "make_placeholder_from_dtype_and_shape",
]
# The graph editor sometimes needs to create placeholders; their names start
# with this prefix (e.g. "geph__<op basename>_<output index>", see
# placeholder_name). "geph" stands for Graph-Editor PlaceHolder.
_DEFAULT_PLACEHOLDER_PREFIX = "geph"
def concatenate_unique(la, lb):
    """Append to `la` every element of `lb` that it does not already contain.

    Elements are appended in the order in which they appear in `lb`, and an
    element repeated inside `lb` is appended only once. `la` is modified in
    place.

    Args:
      la: List of hashable Python objects; extended in place.
      lb: Iterable of hashable Python objects.

    Returns:
      The same list object `la`, extended with the missing elements of `lb`.
    """
    seen = set(la)
    for item in lb:
        if item in seen:
            continue
        seen.add(item)
        la.append(item)
    return la
# TODO(fkp): very generic code, it should be moved in a more generic place.
class ListView(object):
    """Read-only wrapper around an existing Python list.

    Supports iteration, `len`, truthiness, indexing and `+` (which returns a
    new plain list). The wrapped list is not copied, so changes made to it
    through other references are visible through the view.
    This class is strongly inspired by the one in tf.Operation.
    """

    def __init__(self, list_):
        if isinstance(list_, list):
            self._list = list_
            return
        raise TypeError("Expected a list, got: {}.".format(type(list_)))

    def __iter__(self):
        return iter(self._list)

    def __len__(self):
        return len(self._list)

    def __bool__(self):
        return len(self._list) > 0

    # Python 3 looks up __bool__, Python 2.7 looks up __nonzero__.
    __nonzero__ = __bool__

    def __getitem__(self, i):
        return self._list[i]

    def __add__(self, other):
        rhs = other if isinstance(other, list) else list(other)
        return list(self._list) + rhs
# TODO(fkp): very generic code, it should be moved in a more generic place.
def is_iterable(obj):
    """Return True if `iter(obj)` succeeds, except for `tf.Tensor`.

    A `tf.Tensor` is always reported as non-iterable, whatever `iter()` would
    do with it. Any exception raised by `iter()` means "not iterable".
    """
    if isinstance(obj, tf_ops.Tensor):
        return False
    try:
        iter(obj)
    except Exception:  # pylint: disable=broad-except
        return False
    else:
        return True
def flatten_tree(tree, leaves=None):
    """Flatten a tree into a list of its leaves.

    Dictionaries contribute their values (keys are ignored). Other iterables,
    as decided by `is_iterable`, contribute their elements. Strings and
    non-iterables are leaves. Leaves are collected depth-first, in iteration
    order.

    Args:
      tree: iterable or not. If iterable, its elements (children) can also be
        iterable or not.
      leaves: list to which the tree leaves are appended (None by default, in
        which case a new list is created).

    Returns:
      The list of all the leaves in the tree (`leaves` itself if given).
    """
    if leaves is None:
        leaves = []
    if isinstance(tree, dict):
        for _, child in iteritems(tree):
            flatten_tree(child, leaves)
    elif isinstance(tree, str):
        # A string is iterable and each of its characters is again a
        # one-character string, so recursing into it never terminates.
        leaves.append(tree)
    elif is_iterable(tree):
        for child in tree:
            flatten_tree(child, leaves)
    else:
        leaves.append(tree)
    return leaves
def transform_tree(tree, fn, iterable_type=tuple):
    """Transform all the leaves of a tree, preserving its structure.

    Containers are rebuilt as follows:
    - dictionaries (and dict subclasses) keep their type and keys;
    - named tuples and tuples keep their type;
    - other `collections.abc.Sequence` types are re-created through their
      type;
    - any other iterable (as decided by `is_iterable`) is rebuilt as
      `iterable_type`.

    Strings and non-iterables are leaves and are replaced by `fn(leaf)`.

    Args:
      tree: iterable or not. If iterable, its elements (children) can also be
        iterable or not.
      fn: function to apply to each leaf.
      iterable_type: type used to construct the resulting tree for unknown
        iterables, typically `list` or `tuple`. Applied at every depth.

    Returns:
      A tree whose leaves have been transformed by `fn`.
      The hierarchy of the output tree mimics the one of the input tree.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in
    # `collections.abc` since Python 3.3. Fall back for Python 2.
    try:
        from collections.abc import Sequence  # pylint: disable=g-import-not-at-top
    except ImportError:
        from collections import Sequence  # pylint: disable=g-import-not-at-top

    def _recurse(child):
        # Propagate `iterable_type` so nested unknown iterables honor it too.
        return transform_tree(child, fn, iterable_type)

    # A string is a Sequence, but rebuilding it via `str.__new__(str)` yields
    # "" and drops its content, so it is treated as a leaf.
    if isinstance(tree, str) or not is_iterable(tree):
        return fn(tree)
    if isinstance(tree, dict):
        res = tree.__new__(type(tree))
        res.__init__((k, _recurse(child)) for k, child in iteritems(tree))
        return res
    if isinstance(tree, tuple):
        if hasattr(tree, "_asdict"):  # NamedTuple.
            return tree.__new__(type(tree), **_recurse(tree._asdict()))
        return tree.__new__(type(tree), (_recurse(child) for child in tree))
    if isinstance(tree, Sequence):
        res = tree.__new__(type(tree))
        res.__init__(_recurse(child) for child in tree)
        return res
    return iterable_type(_recurse(child) for child in tree)
def check_graphs(*args):
    """Check that all the non-None graphs of `args` are the same object.

    Arguments whose `graph` attribute is None are ignored.

    Args:
      *args: objects exposing a `graph` attribute.

    Raises:
      ValueError: if two arguments have different (non-None) graphs; the
        message gives the index of the first mismatching argument.
    """
    reference = None
    for index, item in enumerate(args):
        item_graph = item.graph
        if item_graph is None:
            continue
        if reference is None:
            reference = item_graph
        elif item_graph is not reference:
            raise ValueError("Argument[{}]: Wrong graph!".format(index))
def get_unique_graph(tops, check_types=None, none_if_empty=False):
    """Return the unique graph used by all the elements in `tops`.

    Args:
      tops: iterable of elements to check (usually `tf.Operation` and/or
        `tf.Tensor`), or a `tf.Graph`, which is returned as is.
      check_types: a type, or an iterable of types, that every element of
        `tops` must be an instance of. If None, the types
        (tf.Operation, tf.Tensor) are used.
      none_if_empty: if True, return None instead of raising when `tops` is
        empty.

    Returns:
      The unique graph used by all the tops, or None if `tops` is empty and
      `none_if_empty` is True.

    Raises:
      TypeError: if `tops` is not iterable, or one of its elements is not an
        instance of `check_types`.
      ValueError: if the graph is not unique, or if `tops` is empty and
        `none_if_empty` is False.
    """
    if isinstance(tops, tf_ops.Graph):
        return tops
    if not is_iterable(tops):
        raise TypeError("{} is not iterable".format(type(tops)))
    if check_types is None:
        check_types = (tf_ops.Operation, tf_ops.Tensor)
    elif not is_iterable(check_types):
        check_types = (check_types,)
    else:
        # isinstance() only accepts a type or a *tuple* of types, so lists,
        # sets or generators of types must be normalized to a tuple.
        check_types = tuple(check_types)
    g = None
    for op in tops:
        if not isinstance(op, check_types):
            raise TypeError("Expected a type in ({}), got: {}".format(
                ", ".join([str(t) for t in check_types]), type(op)))
        if g is None:
            g = op.graph
        elif g is not op.graph:
            raise ValueError(
                "Operation {} does not belong to given graph".format(op))
    if g is None and not none_if_empty:
        raise ValueError("Can't find the unique graph of an empty list")
    return g
def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False):
    """Convert ops to a list of `tf.Operation`.

    Args:
      ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
        operation.
      check_graph: if `True` check if all the operations belong to the same
        graph.
      allow_graph: if `False` a `tf.Graph` cannot be converted.
      ignore_ts: if True, silently ignore `tf.Tensor`.

    Returns:
      A newly created list of `tf.Operation` (for a `tf.Graph`, the result of
      its `get_operations()`). Elements that are not `tf.Operation` are
      dropped from the result.

    Raises:
      TypeError: if `ops` is a `tf.Graph` and `allow_graph` is False, or, when
        `check_graph` is True, if an element is not a `tf.Operation` (or a
        `tf.Tensor` when `ignore_ts` is True).
      ValueError: if `check_graph` is `True` and the ops do not all belong to
        the same graph.
    """
    if isinstance(ops, tf_ops.Graph):
        if allow_graph:
            return ops.get_operations()
        raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    # Materialize once: a one-shot iterable (e.g. a generator) would otherwise
    # be exhausted by the graph check below, leaving nothing to return.
    ops = list(ops) if is_iterable(ops) else [ops]
    if not ops:
        return []
    if check_graph:
        check_types = None if ignore_ts else tf_ops.Operation
        get_unique_graph(ops, check_types=check_types)
    return [op for op in ops if isinstance(op, tf_ops.Operation)]
# TODO(fkp): move this function in tf.Graph?
def get_tensors(graph):
    """Return every tensor produced by an operation of `graph`.

    The result concatenates the `outputs` of each op of
    `graph.get_operations()`, in operation order.

    Args:
      graph: a `tf.Graph`.

    Returns:
      A list of `tf.Tensor`.

    Raises:
      TypeError: if graph is not a `tf.Graph`.
    """
    if not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a graph, got: {}".format(type(graph)))
    return [tensor for op in graph.get_operations() for tensor in op.outputs]
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
    """Convert ts to a list of `tf.Tensor`.

    Args:
      ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
      check_graph: if `True` check if all the tensors belong to the same graph.
      allow_graph: if `False` a `tf.Graph` cannot be converted.
      ignore_ops: if `True`, silently ignore `tf.Operation`.

    Returns:
      A newly created list of `tf.Tensor` (for a `tf.Graph`, all the tensors
      returned by `get_tensors`). Elements that are not `tf.Tensor` are
      dropped from the result.

    Raises:
      TypeError: if `ts` is a `tf.Graph` and `allow_graph` is False, or, when
        `check_graph` is True, if an element is not a `tf.Tensor` (or a
        `tf.Operation` when `ignore_ops` is True).
      ValueError: if `check_graph` is `True` and the tensors do not all belong
        to the same graph.
    """
    if isinstance(ts, tf_ops.Graph):
        if allow_graph:
            return get_tensors(ts)
        raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    # Materialize once: a one-shot iterable (e.g. a generator) would otherwise
    # be exhausted by the graph check below, leaving nothing to return.
    ts = list(ts) if is_iterable(ts) else [ts]
    if not ts:
        return []
    if check_graph:
        check_types = None if ignore_ops else tf_ops.Tensor
        get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, tf_ops.Tensor)]
def get_generating_ops(ts):
    """Return the op that produces each tensor of `ts`.

    The result has one entry per tensor, in the same order, so an op is
    repeated if it produces several of the tensors.

    Args:
      ts: a list of `tf.Tensor`

    Returns:
      A list of all the generating `tf.Operation` of the tensors in `ts`.

    Raises:
      TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
    """
    tensors = make_list_of_t(ts, allow_graph=False)
    return [tensor.op for tensor in tensors]
def get_consuming_ops(ts):
    """Return all the consuming ops of the tensors in ts.

    Each op appears once, in the order it is first seen when walking the
    consumers of each tensor in turn.

    Args:
      ts: a list of `tf.Tensor`

    Returns:
      A list of all the consuming `tf.Operation` of the tensors in `ts`.

    Raises:
      TypeError: if ts cannot be converted to a list of `tf.Tensor`.
    """
    ts = make_list_of_t(ts, allow_graph=False)
    ops = []
    # A set makes the duplicate check O(1) instead of scanning `ops` each
    # time; the list keeps the first-seen order.
    seen = set()
    for t in ts:
        for op in t.consumers():
            if op not in seen:
                seen.add(op)
                ops.append(op)
    return ops
class ControlOutputs(object):
    """The control outputs topology of a `tf.Graph`.

    Maps each operation to the list of operations that have it as a control
    input. The mapping is cached; `update()` rebuilds it when the graph's
    `version` differs from the one it was built from.
    """

    def __init__(self, graph):
        """Create a dictionary of control-output dependencies for `graph`.

        Args:
          graph: a `tf.Graph`.

        Raises:
          TypeError: graph is not a `tf.Graph`.
        """
        if not isinstance(graph, tf_ops.Graph):
            raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
        # Maps an op to the list of ops that have it as a control input.
        self._control_outputs = {}
        self._graph = graph
        # Graph version the mapping was built from; None until first build.
        self._version = None
        self._build()

    def update(self):
        """Rebuild the control outputs if the graph's version has changed.

        Returns:
          self, to allow chaining.
        """
        if self._version != self._graph.version:
            self._build()
        return self

    def _build(self):
        """(Re)build the control outputs dictionary from the graph's ops."""
        # Cleared in place rather than replaced: a dict previously returned by
        # get_all() therefore sees the rebuilt topology.
        self._control_outputs.clear()
        for op in self._graph.get_operations():
            for control_input in op.control_inputs:
                outputs = self._control_outputs.setdefault(control_input, [])
                if op not in outputs:
                    outputs.append(op)
        self._version = self._graph.version

    def get_all(self):
        """Return the whole {op: list of control-output ops} dictionary."""
        return self._control_outputs

    def get(self, op):
        """Return the control outputs of `op`, or `()` if it has none."""
        return self._control_outputs.get(op, ())

    @property
    def graph(self):
        """The `tf.Graph` this topology was built from."""
        return self._graph
def scope_finalize(scope):
    """Return `scope` with a trailing "/" if it is non-empty and lacks one.

    Falsy values (e.g. "" or None) are returned unchanged.
    """
    if not scope or scope.endswith("/"):
        return scope
    return scope + "/"
def scope_dirname(scope):
    """Return the part of `scope` up to and including its last "/".

    Returns "" when `scope` contains no "/".
    """
    head, sep, _ = scope.rpartition("/")
    return head + sep
def scope_basename(scope):
    """Return the part of `scope` after its last "/" (all of it if none)."""
    return scope.rpartition("/")[2]
def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX):
    """Create placeholder name for the graph editor.

    Args:
      t: optional tensor on which the placeholder operation's name will be
        based.
      scope: absolute scope with which to prefix the placeholder's name. None
        means that the scope of t is preserved (or no scope if `t` is None).
        "" means the root scope.
      prefix: placeholder name prefix.

    Returns:
      A new placeholder name starting with `prefix` (by default "geph", which
      stands for Graph Editor PlaceHolder; this convention allows to quickly
      identify the placeholders generated by the Graph Editor).
      With a tensor `t` the name is "<scope><prefix>__<op basename>_<output
      index>", unless the op's basename already starts with "<prefix>__", in
      which case that basename is reused. Without `t` the name is
      "<scope><prefix>".

    Raises:
      TypeError: if t is not None or a tf.Tensor.
    """
    if scope is not None:
        scope = scope_finalize(scope)
    if t is None:
        return "{}{}".format("" if scope is None else scope, prefix)
    if not isinstance(t, tf_ops.Tensor):
        raise TypeError("Expected a tf.Tensor, got: {}".format(type(t)))
    op_dirname = scope_dirname(t.op.name)
    op_basename = scope_basename(t.op.name)
    if scope is None:
        scope = op_dirname
    if op_basename.startswith("{}__".format(prefix)):
        # The op already carries a placeholder name with this prefix: reuse
        # it rather than stacking another prefix on top.
        ph_name = op_basename
    else:
        ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index)
    return scope + ph_name
def make_placeholder_from_tensor(t, scope=None,
                                 prefix=_DEFAULT_PLACEHOLDER_PREFIX):
    """Create a `tf.compat.v1.placeholder` with the dtype and shape of `t`.

    Note that the correct graph scope must be set by the calling function.

    Args:
      t: a `tf.Tensor` whose dtype and `get_shape()` are copied, and whose
        name is used to derive the placeholder name (see function
        placeholder_name).
      scope: absolute scope within which to create the placeholder. None means
        that the scope of `t` is preserved. `""` means the root scope.
      prefix: placeholder name prefix.

    Returns:
      A newly created `tf.compat.v1.placeholder`.

    Raises:
      TypeError: if `t` is not a `tf.Tensor` (raised by placeholder_name).
        An object without `dtype`/`get_shape` fails earlier with an
        AttributeError.
    """
    dtype = t.dtype
    shape = t.get_shape()
    name = placeholder_name(t, scope=scope, prefix=prefix)
    return tf_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None,
                                          prefix=_DEFAULT_PLACEHOLDER_PREFIX):
    """Create a `tf.compat.v1.placeholder` of the given dtype and shape.

    Note that the correct graph scope must be set by the calling function.
    The placeholder is named with the function placeholder_name called
    without a tensor, i.e. "<scope><prefix>".

    Args:
      dtype: the tensor type.
      shape: the tensor shape (optional).
      scope: absolute scope within which to create the placeholder. None is
        handled like `""`: no scope is prepended to the name.
      prefix: placeholder name prefix.

    Returns:
      A newly created `tf.compat.v1.placeholder`.
    """
    name = placeholder_name(scope=scope, prefix=prefix)
    return tf_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
# Matches "dunder" attribute names such as "__doc__" or "__module__".
_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")


def get_predefined_collection_names():
    """Return all the predefined collection names.

    These are the values of every attribute of `tf_ops.GraphKeys` whose name
    is not a dunder name, in `dir()` order.
    """
    names = []
    for attr in dir(tf_ops.GraphKeys):
        if _INTERNAL_VARIABLE_RE.match(attr) is None:
            names.append(getattr(tf_ops.GraphKeys, attr))
    return names
def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
    """Find corresponding op/tensor in a different graph.

    The name of `target` is stripped of `src_scope` (if given), prefixed with
    `dst_scope` (if given), and looked up in `dst_graph`.

    Args:
      target: A `tf.Tensor` or a `tf.Operation` belonging to the original
        graph.
      dst_graph: The graph in which the corresponding graph element must be
        found.
      dst_scope: A scope which is prepended to the name to look for.
      src_scope: A scope which is removed from the name of `target`.

    Returns:
      The corresponding `tf.Tensor` or `tf.Operation`.

    Raises:
      TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`.
      ValueError: if the name of `target` does not start with `src_scope`.
      KeyError: If the corresponding graph element cannot be found.
    """
    # Check the type first so unsupported targets fail with the documented
    # TypeError rather than e.g. an AttributeError on `.name`.
    if isinstance(target, tf_ops.Tensor):
        lookup = dst_graph.get_tensor_by_name
    elif isinstance(target, tf_ops.Operation):
        lookup = dst_graph.get_operation_by_name
    else:
        raise TypeError(
            "Expected tf.Tensor or tf.Operation, got: {}".format(type(target)))

    name = target.name
    if src_scope:
        src_scope = scope_finalize(src_scope)
        if not name.startswith(src_scope):
            raise ValueError("{} does not start with {}".format(name, src_scope))
        name = name[len(src_scope):]
    if dst_scope:
        name = scope_finalize(dst_scope) + name
    return lookup(name)
def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
    """Find corresponding ops/tensors in a different graph.

    `targets` is a Python tree, that is, a nested structure of iterables
    (list, tuple, dictionary) whose leaves are instances of `tf.Tensor` or
    `tf.Operation`. Each leaf is mapped with `find_corresponding_elem`, and
    the tree is rebuilt with `transform_tree`.

    Args:
      targets: A Python tree containing `tf.Tensor` or `tf.Operation`
        belonging to the original graph.
      dst_graph: The graph in which the corresponding graph elements must be
        found.
      dst_scope: A scope which is prepended to the names to look for.
      src_scope: A scope which is removed from the names of the leaves.

    Returns:
      A Python tree containing the corresponding `tf.Tensor` or
      `tf.Operation`.

    Raises:
      ValueError: if the name of a leaf does not start with `src_scope`.
      TypeError: if a leaf is not a `tf.Tensor` or a `tf.Operation`.
      KeyError: If a corresponding graph element cannot be found.
    """
    return transform_tree(
        targets,
        lambda top: find_corresponding_elem(top, dst_graph, dst_scope,
                                            src_scope))