# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
  """Runs `loop_fn` `iters` times and stacks the outputs.

  Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
  stacks the corresponding outputs of the different runs.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object
      representing the iteration number, and returns a possibly nested
      structure of tensor objects. The shape of these outputs should not
      depend on the input.
    loop_fn_dtypes: dtypes for the outputs of `loop_fn`.
    iters: Number of iterations for which to run `loop_fn`.
    parallel_iterations: The number of iterations that can be dispatched in
      parallel. This knob can be used to control the total memory usage.

  Returns:
    A nested structure of stacked output tensor objects with the same nested
    structure as the output of `loop_fn`.
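
  Example (a minimal sketch; the loop body, dtype and iteration count below
  are illustrative, not prescribed by this API):

  ```python
  def loop_body(i):
    # `i` is an int32 scalar Tensor holding the iteration number.
    return tf.cast(i, tf.float32) * 2.0

  # Stacks the four scalar outputs into a float32 Tensor of shape [4],
  # i.e. [0., 2., 4., 6.].
  doubled = for_loop(loop_body, tf.float32, 4)
  ```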
"""
flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
is_none_list = []
def while_body(i, *ta_list):
"""Body of while loop."""
fn_output = nest.flatten(loop_fn(i))
if len(fn_output) != len(flat_loop_fn_dtypes):
raise ValueError(
"Number of expected outputs, %d, does not match the number of "
"actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
len(fn_output)))
outputs = []
del is_none_list[:]
is_none_list.extend(x is None for x in fn_output)
for out, ta in zip(fn_output, ta_list):
# TODO(agarwal): support returning Operation objects from loop_fn.
if out is not None:
# out may be a ref tensor, wrap it in identity to get a non-ref tensor.
ta = ta.write(i, array_ops.expand_dims(out, 0))
outputs.append(ta)
return tuple([i + 1] + outputs)
if parallel_iterations is not None:
extra_args = {"parallel_iterations": parallel_iterations}
else:
extra_args = {}
ta_list = control_flow_ops.while_loop(
lambda i, *ta: i < iters,
while_body,
[0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
for dtype in flat_loop_fn_dtypes],
**extra_args)[1:]
# TODO(rachelim): enable this for sparse tensors
output = [None if is_none else ta.concat()
for ta, is_none in zip(ta_list, is_none_list)]
assert len(output) in (0, len(flat_loop_fn_dtypes))
if not output:
# This may happen for the case where iters == 0.
return None
else:
return nest.pack_sequence_as(loop_fn_dtypes, output)


def _flatten_first_two_dims(x):
  """Flattens the first two dimensions of x into a single dimension."""
  old_shape = array_ops.shape(x)
  new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]],
                               axis=0)
  return array_ops.reshape(x, new_shape)


PFOR_CONFIG_ARG = "pfor_config"


def _is_under_xla_context():
  """Check if we are currently inside an XLA compile context."""
  g = ops.get_default_graph()
  while g is not None:
    control_flow_context = g._get_control_flow_context()  # pylint: disable=protected-access
    while control_flow_context is not None:
      if control_flow_context.IsXLAContext():
        return True
      else:
        control_flow_context = control_flow_context.outer_context
    # If g is a FuncGraph, get its outer_graph.
    g = getattr(g, "outer_graph", None)
  return False


def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

  `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn`
  `iters` times, with input from 0 to `iters - 1`, and stacking the
  corresponding output of each iteration. However the implementation does not
  use a `tf.while_loop`. Instead it adds new operations to the graph that
  collectively compute the same value as what running `loop_fn` in a loop
  would compute.

  This is an experimental feature and currently has a lot of limitations:
  - There should be no data dependency between the different iterations. For
    example, a future iteration should not depend on a value or side-effect of
    a previous iteration.
  - Stateful kernels are mostly unsupported, since they often imply a data
    dependency or an ordering of the iterations. We do support a limited set
    of such stateful kernels though (like RandomFoo, Variable operations like
    reads, etc).
  - Conversion works only on a limited set of kernels for which a converter
    has been registered.
  - `loop_fn` has limited support for control flow operations. `tf.cond` in
    particular is not supported.
  - `loop_fn` should return a nested structure of Tensors or Operations.
    However, if an Operation is returned, it should have zero outputs.
  - The shape and dtype of `loop_fn` outputs should not depend on the input
    to `loop_fn`.

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object
      representing the iteration number, and optionally a keyword argument
      `pfor_config` set to a PForConfig object. It returns a possibly nested
      structure of Tensor or Operation objects. Note that if the
      `parallel_iterations` argument is set to something other than None,
      `loop_fn` may be called more than once during graph construction, so it
      may need to avoid mutating global state.
    iters: Number of iterations for which to run `loop_fn`.
    fallback_to_while_loop: If true, on failing to vectorize an operation,
      pfor falls back to using a `tf.while_loop` to dispatch the iterations.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
      vectorizing all the iterations. If `parallel_iterations` is smaller than
      `iters`, then chunks of at most that many iterations are dispatched in
      sequence. This knob can be used to control the total memory usage.

  Returns:
    A nested structure of stacked tensor objects with the same nested
    structure as the output of `loop_fn`.

  Raises:
    ValueError: If parallel_iterations is not None and not an integer > 1.
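
  Example (a minimal sketch; the tensor, shapes and loop body below are
  illustrative, not prescribed by this API):

  ```python
  x = tf.random.uniform([8, 3])

  def loop_body(i):
    # Each iteration reads one row of `x`; there is no dependency between
    # iterations, so the loop can be vectorized.
    return tf.gather(x, i) * 2.0

  # Equivalent to stacking the 8 per-iteration outputs; `y` has shape [8, 3].
  y = pfor(loop_body, 8)
  ```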
"""
def f():
return _pfor_impl(loop_fn,
iters,
fallback_to_while_loop=fallback_to_while_loop,
parallel_iterations=parallel_iterations)
# Note that we wrap into a tf.function if in eager execution mode or under
# XLA compilation. The latter is so that we don't compile operations like
# tf.placeholder that are created by the loop body.
functions_run_eagerly = None
if context.executing_eagerly() or _is_under_xla_context():
functions_run_eagerly = def_function.functions_run_eagerly()
if functions_run_eagerly:
logging.warning(
"It looks like tf.function behavior was disabled, perhaps using "
"tf.config.run_functions_eagerly. Vectorization "
"primitives (e.g. tf.vectorized_map) require tf.function to work. "
"These primitives will override the disable.")
def_function.run_functions_eagerly(False)
f = def_function.function(f)
outputs = f()
if functions_run_eagerly is not None:
def_function.run_functions_eagerly(functions_run_eagerly)
return outputs


def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument."""
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("loop_fn object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args


def _pfor_impl(loop_fn,
               iters,
               fallback_to_while_loop,
               parallel_iterations=None,
               pfor_config=None):
  """Implementation of pfor."""
  assert not context.executing_eagerly()
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  existing_ops = set(ops.get_default_graph().get_operations())
  # Run the loop body
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder_with_default(0, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      loop_fn_outputs = loop_fn(loop_var)

  # Convert outputs to Tensor if needed.
  tmp_loop_fn_outputs = []
  for loop_fn_output in nest.flatten(loop_fn_outputs):
    if (loop_fn_output is not None and not isinstance(
        loop_fn_output,
        (ops.Operation, ops.Tensor, sparse_tensor.SparseTensor))):
      if isinstance(loop_fn_output, indexed_slices.IndexedSlices):
        logging.warn(
            "Converting %s to a dense representation may make it slow."
            " Alternatively, output the indices and values of the"
            " IndexedSlices separately, and handle the vectorized"
            " outputs directly." % loop_fn_output)
      loop_fn_output = ops.convert_to_tensor(loop_fn_output)
    tmp_loop_fn_outputs.append(loop_fn_output)
  loop_fn_outputs = nest.pack_sequence_as(loop_fn_outputs, tmp_loop_fn_outputs)

  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
  if parallel_iterations is not None:
    if parallel_iterations < 1:
      raise ValueError(
          "parallel_iterations must be None or a positive integer")
    if parallel_iterations == 1:
      raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
    iters_value = tensor_util.constant_value(iters)
    if iters_value is not None and iters_value < parallel_iterations:
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      outputs = []
      for loop_fn_output in nest.flatten(loop_fn_outputs):
        outputs.append(converter.convert(loop_fn_output))
      return nest.pack_sequence_as(loop_fn_outputs, outputs)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError("Setting parallel_iterations currently unsupported if"
                       " reductions across iterations are performed.")
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
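    # For example, with iters=10 and parallel_iterations=4, the 2 remaining
    # iterations are dispatched untiled below, followed by 2 tiles of 4
    # iterations each.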
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       fallback_to_while_loop=fallback_to_while_loop,
                       pfor_config=pfor_config)
      remaining_outputs = []
      flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
      for loop_fn_output in flattened_loop_fn_outputs:
        remaining_outputs.append(converter.convert(loop_fn_output))

    with ops.name_scope("pfor_tiled"):
      loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
                        for x in flattened_loop_fn_outputs]

      def tiled_loop_body(j):
        offset = j * parallel_iterations + num_remaining_iterations

        def tiled_loop_fn(i, pfor_config=None):
          if loop_fn_has_config:
            return nest.flatten(loop_fn(i + offset, pfor_config=pfor_config))
          else:
            return nest.flatten(loop_fn(i + offset))

        return _pfor_impl(
            tiled_loop_fn,
            parallel_iterations,
            fallback_to_while_loop=fallback_to_while_loop,
            pfor_config=pfor_config)

      tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
                               num_tiled_iterations, parallel_iterations=1)
      tiled_outputs = [_flatten_first_two_dims(y) for y in tiled_outputs]

    with ops.name_scope("pfor"):
      iters_value = tensor_util.constant_value(iters)
      if iters_value is None or iters_value % parallel_iterations:
        outputs = control_flow_ops.cond(
            math_ops.equal(num_remaining_iterations, 0),
            lambda: tiled_outputs,
            lambda: [array_ops.concat([x, y], axis=0)
                     for x, y in zip(remaining_outputs, tiled_outputs)])
      else:
        outputs = tiled_outputs
      return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs))
@tf_export("vectorized_map")
def vectorized_map(fn, elems, fallback_to_while_loop=True):
"""Parallel map on the list of tensors unpacked from `elems` on dimension 0.
This method works similar to `tf.map_fn` but is optimized to run much faster,
possibly with a much larger memory footprint. The speedups are obtained by
vectorization (see https://arxiv.org/pdf/1903.04243.pdf). The idea behind
vectorization is to semantically launch all the invocations of `fn` in
parallel and fuse corresponding operations across all these invocations. This
fusion is done statically at graph generation time and the generated code is
often similar in performance to a manually fused version.
Because `tf.vectorized_map` fully parallelizes the batch, this method will
generally be significantly faster than using `tf.map_fn`, especially in eager
mode. However this is an experimental feature and currently has a lot of
limitations:
- There should be no data dependency between the different semantic
invocations of `fn`, i.e. it should be safe to map the elements of the
inputs in any order.
- Stateful kernels may mostly not be supported since these often imply a
data dependency. We do support a limited set of such stateful kernels
though (like RandomFoo, Variable operations like reads, etc).
- `fn` has limited support for control flow operations.
- `fn` should return nested structure of Tensors or Operations. However
if an Operation is returned, it should have zero outputs.
- The shape and dtype of any intermediate or output tensors in the
computation of `fn` should not depend on the input to `fn`.
Examples:
```python
def outer_product(a):
return tf.tensordot(a, a, 0)
batch_size = 100
a = tf.ones((batch_size, 32, 32))
c = tf.vectorized_map(outer_product, a)
assert c.shape == (batch_size, 32, 32, 32, 32)
```
```python
# Computing per-example gradients
batch_size = 10
num_features = 32
layer = tf.keras.layers.Dense(1)
def model_fn(arg):
with tf.GradientTape() as g:
inp, label = arg
inp = tf.expand_dims(inp, 0)
label = tf.expand_dims(label, 0)
prediction = layer(inp)
loss = tf.nn.l2_loss(label - prediction)
return g.gradient(loss, (layer.kernel, layer.bias))
inputs = tf.random.uniform([batch_size, num_features])
labels = tf.random.uniform([batch_size, 1])
per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels))
assert per_example_gradients[0].shape == (batch_size, num_features, 1)
assert per_example_gradients[1].shape == (batch_size, 1)
```
Args:
fn: The callable to be performed. It accepts one argument, which will have
the same (possibly nested) structure as `elems`, and returns a possibly
nested structure of Tensors and Operations, which may be different than
the structure of `elems`.
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be mapped over by `fn`.
fallback_to_while_loop: If true, on failing to vectorize an operation,
the unsupported op is wrapped in a tf.while_loop to execute the map
iterations. Note that this fallback only happens for unsupported ops and
other parts of `fn` are still vectorized. If false, on encountering an
unsupported op, a ValueError is thrown. Note that the fallbacks can result
in slowdowns since vectorization often yields speedup of one to two orders
of magnitude.
Returns:
A tensor or (possibly nested) sequence of tensors. Each tensor packs the
results of applying fn to tensors unpacked from elems along the first
dimension, from first to last.
Raises:
ValueError: If vectorization fails and fallback_to_while_loop is False.
"""
  def loop_fn(i):
    gathered_elems = nest.map_structure(lambda x: array_ops.gather(x, i), elems)
    return fn(gathered_elems)

  batch_size = None
  first_elem = ops.convert_to_tensor(nest.flatten(elems)[0])
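  # Prefer a statically known batch size; otherwise fall back to the dynamic
  # shape of the first flattened element.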
  if first_elem.shape.rank is not None:
    batch_size = first_elem.shape.as_list()[0]
  if batch_size is None:
    batch_size = array_ops.shape(first_elem)[0]

  return pfor(loop_fn, batch_size,
              fallback_to_while_loop=fallback_to_while_loop)