blob: 27662d72c9f03543987454e67e229c56e68dce6e [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Scan dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export
class _ScanDataset(dataset_ops.UnaryDataset):
  """A dataset that scans a function across its input.

  `scan_func` is applied to `(state, input_element)` and must return a
  `(new_state, output_element)` pair. The output elements form this dataset's
  elements; the state starts at `initial_state` (see `scan()`).
  """

  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details.

    Args:
      input_dataset: The dataset whose elements are passed to `scan_func`.
      initial_state: A nested structure of tensors giving the initial state.
      scan_func: A function mapping `(old_state, input_element)` to
        `(new_state, output_element)`.

    Raises:
      TypeError: If `scan_func` does not return a pair, or if the new state it
        returns differs from the initial state in structure, element classes,
        or element types.
    """
    self._input_dataset = input_dataset
    self._initial_state = structure.normalize_element(initial_state)

    # Compute initial values for the state classes, shapes and types based on
    # the initial state. The shapes may be refined by running `scan_func` one
    # or more times below.
    self._state_structure = structure.type_spec_from_value(self._initial_state)

    # Iteratively rerun the scan function until the state shapes reach a fixed
    # point. Each pass replaces the state shapes with the most specific shapes
    # compatible with both the current state and the state `scan_func`
    # returned, so shapes only ever get weaker and the loop terminates.
    need_to_rerun = True
    while need_to_rerun:
      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          scan_func,
          self._transformation_name(),
          input_structure=(self._state_structure,
                           input_dataset.element_spec),
          add_to_graph=False)
      if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
              and len(wrapped_func.output_types) == 2):
        raise TypeError("The scan function must return a pair comprising the "
                        "new state and the output value.")

      # Extract and validate class information from the returned values.
      new_state_classes, output_classes = wrapped_func.output_classes
      # NOTE(review): this private attribute was assigned here previously; it
      # is not read within this file but is kept in case other code relies on
      # it.
      self._output_classes = output_classes
      old_state_classes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
          self._state_structure)

      # The per-component checks below `zip` the flattened structures, which
      # would silently ignore extra or missing state components. Reject a new
      # state whose structure differs from the initial state up front.
      try:
        nest.assert_same_structure(
            old_state_classes, new_state_classes, check_types=False)
      except (ValueError, TypeError):
        raise TypeError(
            "The element structure for the new state must match the structure "
            "of the initial state. Expected %s; got %s." %
            (old_state_classes, new_state_classes))

      for new_state_class, old_state_class in zip(
          nest.flatten(new_state_classes),
          nest.flatten(old_state_classes)):
        if not issubclass(new_state_class, old_state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_classes, new_state_classes))

      # Extract and validate type information from the returned values.
      new_state_types, output_types = wrapped_func.output_types
      old_state_types = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
          self._state_structure)
      for new_state_type, old_state_type in zip(
          nest.flatten(new_state_types), nest.flatten(old_state_types)):
        if new_state_type != old_state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (old_state_types, new_state_types))

      # Extract shape information from the returned values.
      new_state_shapes, output_shapes = wrapped_func.output_shapes
      old_state_shapes = nest.map_structure(
          lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
          self._state_structure)
      # Recomputed every pass; the value from the final pass is the one kept.
      self._element_spec = structure.convert_legacy_structure(
          output_types, output_shapes, output_classes)

      flat_state_shapes = nest.flatten(old_state_shapes)
      flat_new_state_shapes = nest.flatten(new_state_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      # Another pass is needed only if some state shape with known rank was
      # relaxed; a shape of unknown rank cannot be weakened any further.
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # TODO(b/110122868): Support a "most specific compatible structure"
        # method for combining structures, to avoid using legacy structures
        # in this method.
        self._state_structure = structure.convert_legacy_structure(
            old_state_types,
            nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
            old_state_classes)

    # Only the function traced on the final (fixed-point) pass is added to the
    # graph; earlier passes were built with `add_to_graph=False`.
    self._scan_func = wrapped_func
    self._scan_func.function.add_to_graph(ops.get_default_graph())
    # pylint: disable=protected-access
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **self._flat_structure)
    super(_ScanDataset, self).__init__(input_dataset, variant_tensor)

  def _functions(self):
    """Returns the wrapped user functions used by this dataset."""
    return [self._scan_func]

  @property
  def element_spec(self):
    """The structure of the output elements produced by `scan_func`."""
    return self._element_spec

  def _transformation_name(self):
    """Returns the public API name used when tracing `scan_func`."""
    return "tf.data.experimental.scan()"
@tf_export("data.experimental.scan")
def scan(initial_state, scan_func):
  """A transformation that scans a function across an input dataset.

  This transformation is a stateful relative of `tf.data.Dataset.map`.
  In addition to mapping `scan_func` across the elements of the input dataset,
  `scan()` accumulates one or more state tensors, whose initial values are
  `initial_state`.

  For example, computing a running sum:

  ```python
  dataset = tf.data.Dataset.range(5)
  initial_state = tf.constant(0, dtype=tf.int64)
  scan_func = lambda state, i: (state + i, state + i)
  dataset = dataset.apply(tf.data.experimental.scan(initial_state, scan_func))
  # ==> [0, 1, 3, 6, 10]
  ```

  Args:
    initial_state: A nested structure of tensors, representing the initial state
      of the accumulator.
    scan_func: A function that maps `(old_state, input_element)` to
      `(new_state, output_element)`. It must take two arguments and return a
      pair of nested structures of tensors. The `new_state` must match the
      structure, element classes, and element types of `initial_state`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`. `scan_func` is traced, and its return value
    validated, only when that function is applied to a dataset.
  """

  def _apply_fn(dataset):
    return _ScanDataset(dataset, initial_state, scan_func)

  return _apply_fn