# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing stage."""
# pylint: disable=g-classes-have-attributes
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.util import nest
# Sequential methods should take precedence.
class PreprocessingStage(sequential.Sequential,
base_preprocessing_layer.PreprocessingLayer):
"""A sequential preprocessing stage.
This preprocessing stage wraps a list of preprocessing layers into a
Sequential-like object that enables you to `adapt()` the whole list via
a single `adapt()` call on the preprocessing stage.
Args:
layers: List of layers. Can include layers that aren't preprocessing layers.
name: String. Optional name for the preprocessing stage object.
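
  Example (illustrative; assumes the experimental `Normalization`
  preprocessing layer is available):

  >>> norm_layer = tf.keras.layers.experimental.preprocessing.Normalization()
  >>> stage = PreprocessingStage([norm_layer])
  >>> stage.adapt(np.array([[0.], [2.], [4.]], dtype='float32'))
  >>> out = stage(np.array([[2.]], dtype='float32'))  # normalized to ~0.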
"""
def adapt(self, data, reset_state=True):
"""Adapt the state of the layers of the preprocessing stage to the data.
Args:
data: A batched Dataset object, or a NumPy array, or an EagerTensor.
Data to be iterated over to adapt the state of the layers in this
preprocessing stage.
reset_state: Whether this call to `adapt` should reset the state of
the layers in this preprocessing stage.
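
    Example (illustrative; adapts a single `Normalization` layer from a
    batched Dataset):

    >>> stage = PreprocessingStage(
    ...     [tf.keras.layers.experimental.preprocessing.Normalization()])
    >>> ds = tf.data.Dataset.from_tensor_slices(
    ...     np.arange(10, dtype='float32').reshape(5, 2)).batch(2)
    >>> stage.adapt(ds)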
"""
if not isinstance(data,
(dataset_ops.DatasetV2, np.ndarray, ops.EagerTensor)):
raise ValueError(
'`adapt()` requires a batched Dataset, an EagerTensor, '
'or a Numpy array as input, '
'got {}'.format(type(data)))
if isinstance(data, dataset_ops.DatasetV2):
      # Validate the dataset to try to ensure we haven't been passed one with
# infinite size. That would cause an infinite loop here.
if tf_utils.dataset_is_infinite(data):
raise ValueError(
'The dataset passed to `adapt()` has an infinite number of '
'elements. Please use dataset.take(...) to make the number '
'of elements finite.')
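    # Adapt the layers one at a time. The `map_fn` defined below feeds each
    # layer the data as transformed by all of the layers that precede it.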
for current_layer_index in range(0, len(self.layers)):
if not hasattr(self.layers[current_layer_index], 'adapt'):
# Skip any layer that does not need adapting.
continue
def map_fn(x):
"""Maps `PreprocessingStage` inputs to inputs at `current_layer_index`.
Args:
x: Batch of inputs seen in entry of the `PreprocessingStage` instance.
Returns:
Batch of inputs to be processed by layer
`self.layers[current_layer_index]`
"""
if current_layer_index == 0: # pylint: disable=cell-var-from-loop
return x
for i in range(current_layer_index): # pylint: disable=cell-var-from-loop
x = self.layers[i](x)
return x
if isinstance(data, dataset_ops.DatasetV2):
current_layer_data = data.map(map_fn)
else:
current_layer_data = map_fn(data)
self.layers[current_layer_index].adapt(current_layer_data,
reset_state=reset_state)
# Functional methods should take precedence.
class FunctionalPreprocessingStage(functional.Functional,
base_preprocessing_layer.PreprocessingLayer):
"""A functional preprocessing stage.
This preprocessing stage wraps a graph of preprocessing layers into a
Functional-like object that enables you to `adapt()` the whole graph via
a single `adapt()` call on the preprocessing stage.
  A preprocessing stage is not a complete model, so you cannot call `fit()` on
  it. However, it is possible to add regular layers, which may be trainable, to
  a preprocessing stage.
  A functional preprocessing stage is created in the same way as a `Functional`
  model: by passing two arguments to `__init__`. The first argument is the
  `keras.Input` tensor (or tensors) representing the inputs to the stage, and
  the second argument is the output tensor (or tensors) representing the
  outputs of this stage. Both arguments can be nested structures of tensors.
Example:
>>> inputs = {'x2': tf.keras.Input(shape=(5,)),
... 'x1': tf.keras.Input(shape=(1,))}
>>> norm_layer = tf.keras.layers.experimental.preprocessing.Normalization()
>>> y = norm_layer(inputs['x2'])
  >>> z1, z2 = tf.keras.layers.Lambda(lambda x: (x, x))(inputs['x1'])
  >>> outputs = [y, [z1, z2]]
>>> stage = FunctionalPreprocessingStage(inputs, outputs)
Args:
inputs: An input tensor (must be created via `tf.keras.Input()`), or a list,
      a dict, or a nested structure of input tensors.
outputs: An output tensor, or a list, a dict or a nested structure of output
tensors.
name: String, optional. Name of the preprocessing stage.
"""
def fit(self, *args, **kwargs):
raise ValueError(
'Preprocessing stage is not a complete model, and hence should not be '
'`fit`. Instead, you may feed data to `adapt` the stage to set '
'appropriate states of the layers in the stage.')
def adapt(self, data, reset_state=True):
"""Adapt the state of the layers of the preprocessing stage to the data.
Args:
      data: A batched Dataset object, a NumPy array, an EagerTensor, or a
        list, dict or nested structure of NumPy arrays or EagerTensors. The
        elements of a Dataset object need to conform to the inputs of the
        stage. The first dimension of a NumPy array or EagerTensor is
        understood to be the batch dimension. Data to be iterated over to
        adapt the state of the layers in this preprocessing stage.
reset_state: Whether this call to `adapt` should reset the state of the
layers in this preprocessing stage.
Examples:
>>> # For a stage with dict input
>>> inputs = {'x2': tf.keras.Input(shape=(5,)),
... 'x1': tf.keras.Input(shape=(1,))}
>>> outputs = [inputs['x1'], inputs['x2']]
>>> stage = FunctionalPreprocessingStage(inputs, outputs)
    >>> ds = tf.data.Dataset.from_tensor_slices({'x1': tf.ones((4, 1)),
    ...                                          'x2': tf.ones((4, 5))})
    >>> sorted(ds.element_spec.items()) # Check element_spec
    [('x1', TensorSpec(shape=(1,), dtype=tf.float32, name=None)),
     ('x2', TensorSpec(shape=(5,), dtype=tf.float32, name=None))]
    >>> stage.adapt(ds)
    >>> data_np = {'x1': np.ones((4, 1)), 'x2': np.ones((4, 5))}
    >>> stage.adapt(data_np)
"""
if not isinstance(data, dataset_ops.Dataset):
data = self._flatten_to_reference_inputs(data)
if any(not isinstance(datum, (np.ndarray, ops.EagerTensor))
for datum in data):
raise ValueError(
'`adapt()` requires a batched Dataset, a list of EagerTensors '
'or Numpy arrays as input, got {}'.format(type(data)))
ds_input = [
dataset_ops.Dataset.from_tensor_slices(x).batch(1) for x in data
]
if isinstance(data, dataset_ops.Dataset):
      # Validate the dataset to try to ensure we haven't been passed one with
# infinite size. That would cause an infinite loop here.
if tf_utils.dataset_is_infinite(data):
raise ValueError(
'The dataset passed to `adapt()` has an infinite number of '
'elements. Please use dataset.take(...) to make the number '
'of elements finite.')
# Unzip dataset object to a list of single input dataset.
ds_input = _unzip_dataset(data)
# Dictionary mapping reference tensors to datasets
ds_dict = {}
tensor_usage_count = self._tensor_usage_count
for x, y in zip(self.inputs, ds_input):
x_id = str(id(x))
ds_dict[x_id] = [y] * tensor_usage_count[x_id]
nodes_by_depth = self._nodes_by_depth
depth_keys = sorted(nodes_by_depth.keys(), reverse=True)
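    # `build_map_fn` wraps `node.layer` in a `Dataset.map` function. The
    # `element_spec` check distinguishes a single-component dataset from a
    # zipped tuple of components, and the layer outputs are flattened so the
    # mapped dataset can be unzipped further below.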
def build_map_fn(node, args, kwargs):
if not isinstance(args.element_spec, tuple):
def map_fn(*x):
return nest.flatten(node.layer(*x, **kwargs))
else:
def map_fn(*x):
return nest.flatten(node.layer(x, **kwargs))
return map_fn
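    # Walk the graph from inputs toward outputs (deepest nodes first). Each
    # stateful layer with an `adapt` method is adapted on a dataset built from
    # the outputs of the layers feeding into it.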
for depth in depth_keys:
for node in nodes_by_depth[depth]:
# Input node
if node.is_input:
continue
# Node with input not computed yet
if any(t_id not in ds_dict for t_id in node.flat_input_ids):
continue
args, kwargs = node.map_arguments(ds_dict)
args = dataset_ops.Dataset.zip(nest.list_to_tuple(*args))
if node.layer.stateful and hasattr(node.layer, 'adapt'):
node.layer.adapt(args, reset_state=reset_state)
map_fn = build_map_fn(node, args, kwargs)
outputs = args.map(map_fn)
outputs = _unzip_dataset(outputs)
# Update ds_dict.
for x_id, y in zip(node.flat_output_ids, outputs):
ds_dict[x_id] = [y] * tensor_usage_count[x_id]
def _unzip_dataset(ds):
"""Unzip dataset into a list of single element datasets.
Args:
ds: A Dataset object.
Returns:
    A list of Dataset objects, each corresponding to one component of the
    `element_spec` of the input Dataset object.
Example:
>>> ds1 = tf.data.Dataset.from_tensor_slices([1, 2, 3])
>>> ds2 = tf.data.Dataset.from_tensor_slices([4, 5, 6])
>>> ds_zipped_tuple = tf.data.Dataset.zip((ds1, ds2))
>>> ds_unzipped_tuple = _unzip_dataset(ds_zipped_tuple)
>>> ds_zipped_dict = tf.data.Dataset.zip({'ds1': ds1, 'ds2': ds2})
>>> ds_unzipped_dict = _unzip_dataset(ds_zipped_dict)
  The two elements of both `ds_unzipped_tuple` and `ds_unzipped_dict` are then
  the same as `ds1` and `ds2`.
"""
element_count = len(nest.flatten(ds.element_spec))
ds_unzipped = []
for i in range(element_count):
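    # Bind the loop index through the keyword-only default `j` so that each
    # `map_fn` extracts a different component; a plain closure over `i` would
    # only see its final value.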
def map_fn(*x, j=i):
return nest.flatten(x)[j]
ds_unzipped.append(ds.map(map_fn))
return ds_unzipped