# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training related logic for Keras model in TF 2.0 context.
Note that all the code under this module is under active development, please DO
NOT use it unless you are really sure what you are doing.
"""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework.ops import composite_tensor
from tensorflow.python.keras import backend
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


def _get_or_make_execution_function(model, mode):
  """Makes or reuses a function to run one step of distributed model execution."""
  model._init_distributed_function_cache_if_not_compiled()

  # Use a key with 'v2' to distinguish from fall-back execution functions.
  key = (mode, 'v2')
  distributed_function = dist_utils.get_distributed_function(model, key)
  if distributed_function:
    return distributed_function

  distributed_function = _make_execution_function(model, mode)
  dist_utils.set_distributed_function(model, key, distributed_function)
  return distributed_function
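

# For example, the first call in a given mode traces an execution function and
# caches it under `(mode, 'v2')`; later steps in the same mode reuse the cached
# function instead of rebuilding and re-tracing it.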


def _make_execution_function(model, mode):
  """Creates a function to run one step of distributed model execution."""
  per_replica_function = _make_replica_execution_function(model, mode)

  def distributed_function(input_iterator):
    """A single step of the distributed execution across replicas."""
    x, y, sample_weights = _prepare_feed_values(model, input_iterator, mode)
    # Call `Model.{train,test,predict}_on_batch` on every replica, passing
    # PerReplicas as arguments. On every replica inside this call, each
    # PerReplica object will return the value for that replica. The outputs
    # are PerReplicas too.
    strategy = distribution_strategy_context.get_strategy()
    outputs = strategy.experimental_run_v2(
        per_replica_function, args=(x, y, sample_weights))
    # Reduce the PerReplica outputs, or pick per-replica values to return
    # (an illustrative reduce pattern is sketched after this function).
    all_outputs = dist_utils.unwrap_output_dict(strategy, outputs, mode)
    return all_outputs

  if not model.run_eagerly:
    distributed_function = def_function.function(
        distributed_function, autograph=False)

  def execution_function(input_fn):
    # `numpy` translates Tensors to values in Eager mode.
    return nest.map_structure(_non_none_constant_value,
                              distributed_function(input_fn))

  return execution_function
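

# Illustrative sketch (not used by Keras itself): the same replicate-then-reduce
# pattern with a generic `step_fn`. The helper name and arguments are
# hypothetical; it assumes `step_fn` returns a per-replica loss tensor.
def _example_distributed_step(strategy, step_fn, iterator):
  """Illustrative only: runs `step_fn` on each replica and sums the losses."""
  # Local import to keep this sketch self-contained.
  from tensorflow.python.distribute import reduce_util
  per_replica_losses = strategy.experimental_run_v2(
      step_fn, args=(next(iterator),))
  # Collapse the PerReplica loss values into a single scalar across replicas.
  return strategy.reduce(
      reduce_util.ReduceOp.SUM, per_replica_losses, axis=None)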


def _non_none_constant_value(v):
  constant_value = tensor_util.constant_value(v)
  return constant_value if constant_value is not None else v


def _prepare_feed_values(model, inputs, mode):
  """Prepares feed values for the model execution function.

  Arguments:
    model: Model to prepare feed values for.
    inputs: An iterator of model inputs, targets, and sample_weights. Model
      inputs may be lists, single values, or dicts mapping input feed names to
      values.
    mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.

  Returns:
    Feed values for the model in the given mode. This is a tuple of the
    structure (inputs, targets, sample_weights), where each of (inputs,
    targets, sample_weights) may be a python list. Single values for inputs
    will always be wrapped in lists.
  """
  inputs, targets, sample_weights = _get_input_from_iterator(inputs)

  # When the inputs are a dict, flatten it in the same order as the model's
  # input layers, so that the data is fed into those layers in the correct
  # order (see the example after this function).
  if isinstance(inputs, dict):
    inputs = [inputs[key] for key in model._feed_input_names]
  else:
    inputs = training_utils.ModelInputs(inputs).as_list()

  if mode == ModeKeys.PREDICT:
    sample_weights = []
    targets = []

  ins = [inputs, targets, sample_weights]
  return tuple(ins)
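

# For example, if `model._feed_input_names == ['a', 'b']` and the iterator
# yields `({'b': b_val, 'a': a_val}, y)` in TRAIN mode, the returned feed
# values are `([a_val, b_val], y, None)`, matching the input-layer order.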


def _get_input_from_iterator(iterator):
  """Get elements from the iterator and verify the input shape and type."""
  next_element = next(iterator)

  if (tensor_util.is_tensor(next_element) or
      isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
    next_element = [next_element]
  if len(next_element) == 1:
    x, = next_element
    y = None
    sample_weights = None
  elif len(next_element) == 2:
    x, y = next_element
    sample_weights = None
  else:
    x, y, sample_weights = next_element

  # Validate that all the elements in x and y are of the same type and shape.
  dist_utils.validate_distributed_dataset_inputs(
      distribution_strategy_context.get_strategy(), x, y, sample_weights)
  return x, y, sample_weights
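

# For example, a dataset built with `tf.data.Dataset.from_tensor_slices((x, y))`
# yields 2-tuples, so `sample_weights` is returned as None; a features-only
# dataset yields single elements, so both `y` and `sample_weights` are None.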


def _make_replica_execution_function(model, mode):
  """Creates a function that runs a single execution step on a replica."""
  if mode == ModeKeys.TRAIN:
    func = functools.partial(train_on_batch, model)
  elif mode == ModeKeys.TEST:
    func = functools.partial(test_on_batch, model)
  else:

    def _predict_on_batch(x, y=None, sample_weights=None):
      del y, sample_weights
      return predict_on_batch(model, x)

    func = _predict_on_batch

  if mode != ModeKeys.PREDICT:
    # `reset_metrics` is set to False to maintain stateful metrics across
    # batch-level calls.
    func = functools.partial(func, reset_metrics=False)
  return func
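

# For example, in TRAIN mode the returned `func(x, y, sample_weights)` is
# equivalent to `train_on_batch(model, x, y, sample_weights,
# reset_metrics=False)`, so metrics accumulate across batch-level calls.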


def _prepare_model_with_inputs(model, dataset):
  """Uses the data from the adapter to configure the model.

  The model needs to be properly configured before training, e.g. built with
  inputs, or compiled with inputs for a subclassed model.

  Args:
    model: a Keras model object.
    dataset: an eager dataset instance from which the data will be extracted.
  """
  if not model.inputs:
    inputs, target, _ = model._build_model_with_inputs(dataset, targets=None)
  else:
    inputs, target, _ = _get_input_from_iterator(iter(dataset))

  if not model._is_compiled and model.optimizer:
    model._compile_from_inputs(inputs, target, dataset, None)

  if target is not None:
    training_utils.prepare_sample_weight_modes(model._training_endpoints,
                                               model.sample_weight_mode)
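

# For example (illustrative): a subclassed model whose layers have not yet
# been built has no `model.inputs`, so the branch above draws a batch from
# `dataset` to build the network and, if an optimizer is present, finish the
# deferred compilation.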


def train_on_batch(model,
                   x,
                   y=None,
                   sample_weight=None,
                   class_weight=None,
                   reset_metrics=True):
  """Runs a single gradient update on a single batch of data.

  Arguments:
    model: The model to train.
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays
        (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
        (in case the model has multiple inputs).
      - A dict mapping input names to the corresponding arrays/tensors,
        if the model has named inputs.
      - A `tf.data` dataset.
    y: Target data. Like the input data `x`, it could be either Numpy
      array(s) or TensorFlow tensor(s). It should be consistent with `x`
      (you cannot have Numpy inputs and tensor targets, or inversely). If
      `x` is a dataset, `y` should not be specified
      (since targets will be obtained from the iterator).
    sample_weight: Optional array of the same length as `x`, containing
      weights to apply to the model's loss for each sample. In the case of
      temporal data, you can pass a 2D array with shape (samples,
      sequence_length) to apply a different weight to every timestep of
      every sample. In this case you should make sure to specify
      `sample_weight_mode="temporal"` in `compile()`. This argument is not
      supported when `x` is a dataset.
    class_weight: Optional dictionary mapping class indices (integers) to a
      weight (float) to apply to the model's loss for the samples from this
      class during training. This can be useful to tell the model to "pay
      more attention" to samples from an under-represented class.
    reset_metrics: If `True`, the metrics returned will be only for this
      batch. If `False`, the metrics will be statefully accumulated across
      batches.

  Returns:
    A dict of training loss and metric values for this batch, plus a
    `'batch_size'` entry recording the number of examples processed.

  Raises:
    ValueError: In case of invalid user-provided arguments.
  """
  model._assert_compile_was_called()

  # TODO(scottzhu): Standardization should happen in the data handlers,
  # not on a per-batch basis in the *_on_batch methods.
  # Validate and standardize user data.
  x, y, sample_weights = model._standardize_user_data(
      x, y, sample_weight=sample_weight, class_weight=class_weight,
      extract_tensors_from_dataset=True)

  batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
  # If `model._distribution_strategy` is set, we are already in a replica
  # context at this point: `train_on_batch` is being run for each replica by
  # `model._distribution_strategy`, and the same code path as eager execution
  # is taken.
  outputs = training_eager.train_on_batch(
      model,
      x,
      y,
      sample_weights=sample_weights,
      output_loss_metrics=model._output_loss_metrics)

  if reset_metrics:
    model.reset_metrics()

  outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64)
  return outputs
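

# Hypothetical usage sketch (the helper name and dataset are illustrative, not
# part of this module): a batch-level training loop over an eager dataset that
# yields `(x, y)` pairs, accumulating metrics across batches.
def _example_train_loop(model, dataset):
  """Illustrative only: trains `model` one batch at a time."""
  model.reset_metrics()
  outputs = None
  for batch_x, batch_y in dataset:
    # `reset_metrics=False` keeps stateful metrics accumulating across calls.
    outputs = train_on_batch(model, batch_x, batch_y, reset_metrics=False)
  # `outputs` is a dict of loss/metric values plus the 'batch_size' entry
  # appended by `train_on_batch`.
  return outputs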


def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
  """Tests the model on a single batch of samples.

  Arguments:
    model: The model to test.
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays
        (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
        (in case the model has multiple inputs).
      - A dict mapping input names to the corresponding arrays/tensors,
        if the model has named inputs.
      - A `tf.data` dataset.
    y: Target data. Like the input data `x`, it could be either Numpy
      array(s) or TensorFlow tensor(s). It should be consistent with `x`
      (you cannot have Numpy inputs and tensor targets, or inversely). If
      `x` is a dataset, `y` should not be specified
      (since targets will be obtained from the iterator).
    sample_weight: Optional array of the same length as `x`, containing
      weights to apply to the model's loss for each sample. In the case of
      temporal data, you can pass a 2D array with shape (samples,
      sequence_length) to apply a different weight to every timestep of
      every sample. In this case you should make sure to specify
      `sample_weight_mode="temporal"` in `compile()`. This argument is not
      supported when `x` is a dataset.
    reset_metrics: If `True`, the metrics returned will be only for this
      batch. If `False`, the metrics will be statefully accumulated across
      batches.

  Returns:
    A dict of test loss and metric values for this batch, plus a
    `'batch_size'` entry recording the number of examples processed.

  Raises:
    ValueError: In case of invalid user-provided arguments.
  """
  model._assert_compile_was_called()

  # TODO(scottzhu): Standardization should happen in the data handlers,
  # not on a per-batch basis in the *_on_batch methods.
  # Validate and standardize user data.
  x, y, sample_weights = model._standardize_user_data(
      x, y, sample_weight=sample_weight, extract_tensors_from_dataset=True)

  batch_size = array_ops.shape(nest.flatten(x, expand_composites=True)[0])[0]
  outputs = training_eager.test_on_batch(
      model,
      x,
      y,
      sample_weights=sample_weights,
      output_loss_metrics=model._output_loss_metrics)

  if reset_metrics:
    model.reset_metrics()

  outputs['batch_size'] = math_ops.cast(batch_size, dtypes.int64)
  return outputs
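

# For example, calling `model.reset_metrics()` once before a loop of
# `test_on_batch(model, x, y, reset_metrics=False)` calls yields metric values
# aggregated over the whole evaluation pass rather than per batch.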


def predict_on_batch(model, x):
  """Returns predictions for a single batch of samples.

  Arguments:
    model: The model to predict with.
    x: Input data. It could be:
      - A Numpy array (or array-like), or a list of arrays
        (in case the model has multiple inputs).
      - A TensorFlow tensor, or a list of tensors
        (in case the model has multiple inputs).
      - A `tf.data` dataset.

  Returns:
    Tensor(s) of predictions.

  Raises:
    ValueError: In case of mismatch between given number of inputs and
      expectations of the model.
  """
  # TODO(scottzhu): Standardization should happen in the data handlers,
  # not on a per-batch basis in the *_on_batch methods.
  # Validate and standardize user data.
  inputs, _, _ = model._standardize_user_data(
      x, extract_tensors_from_dataset=True)

  # If `model._distribution_strategy` is set, we are in a replica context at
  # this point.
  inputs = training_utils.cast_if_floating_dtype(inputs)
  if isinstance(inputs, collections.Sequence):
    # Unwrap lists with only one input, as we do when training on batch.
    if len(inputs) == 1:
      inputs = inputs[0]

  # Run the model call with the learning phase fixed to 0 (inference).
  with backend.eager_learning_phase_scope(0):
    return model(inputs)  # pylint: disable=not-callable
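

# Hypothetical usage sketch (shapes are illustrative; assumes
# `import numpy as np`): for a model with a single 10-feature input,
#   preds = predict_on_batch(model, np.ones((8, 10), dtype='float32'))
# runs a forward pass on the 8 samples with the learning phase set to 0, so
# layers like dropout behave in inference mode.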