blob: b559d56281154768314b69fc5904b110aad7a25e [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training related logic for Keras model in TF 2.0 context.
Note that all the code in this module is under active development. Please DO
NOT use it unless you are really sure what you are doing.
"""
# pylint: disable=protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import errors
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_v2_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
# Data adapters that support `validation_split`. Only numpy arrays and data
# tensors support `validation_split` for now.
_ADAPTER_FOR_VALIDATION_SPLIT = [data_adapter.TensorLikeDataAdapter]
# Data adapters whose inputs can be passed through
# model._standardize_user_data(). Currently keras.Sequence and Python generator
# inputs cause an error when given to model._standardize_user_data(). This
# should be updated in a future change, e.g. so that dataset/generator/sequence
# inputs are peeked and processed by model._standardize_user_data().
_ADAPTER_FOR_STANDARDIZE_USER_DATA = [
    data_adapter.TensorLikeDataAdapter, data_adapter.DatasetAdapter
]
def run_one_epoch(model,
                  iterator,
                  execution_function,
                  dataset_size=None,
                  batch_size=None,
                  strategy=None,
                  steps_per_epoch=None,
                  num_samples=None,
                  mode=ModeKeys.TRAIN,
                  training_context=None,
                  total_epochs=None):
  """Run the execution function with the data from iterator.

  Given the dataset iterator and execution function, get the data from iterator
  and call it with the execution function to get the result (metric/loss).
  It will run for steps_per_epoch or until the iterator is fully consumed.

  Args:
    model: The keras model to run.
    iterator: the dataset iterator to fetch the data.
    execution_function: a tf.function that can be called with data.
    dataset_size: the size of iterator, None when unknown.
    batch_size: the nominal batch size; used together with `num_samples` to
      derive the sample range covered by each batch.
    strategy: the distribution strategy instance from the model.
    steps_per_epoch: the number of steps to run for the epoch. When falsy, the
      loop runs until the iterator is exhausted.
    num_samples: the number of samples for the whole epoch if known. This can be
      used to calculate the final partial batch, and scale the loss.
    mode: the mode for the current epoch.
    training_context: the context that contains callbacks and progress bar.
    total_epochs: the total number of epochs that will be run.
      Only used to build the warning that is logged when the iterator
      unexpectedly reaches its end.

  Returns:
    The aggregated results (`aggregator.results`) for the epoch, or None when
    `stop_training` was already set on the callbacks' model on entry.
  """
  # Only use the sample to count if there is a partial batch at the end.
  use_steps = num_samples is None

  if mode == ModeKeys.PREDICT:
    aggregator = training_utils.OutputsAggregator(
        use_steps=use_steps,
        steps=steps_per_epoch,
        num_samples=num_samples,
        batch_size=batch_size)
  else:
    aggregator = training_utils.MetricsAggregator(
        use_steps=use_steps, steps=steps_per_epoch, num_samples=num_samples)

  callbacks = training_context.callbacks
  progbar = training_context.progbar

  if callbacks.model.stop_training:
    return None

  # With no step limit, loop until the iterator signals exhaustion.
  target_steps = steps_per_epoch or np.inf
  step = 0

  while step < target_steps:
    if use_steps:
      current_batch_size = 1
    elif step < target_steps - 1:
      current_batch_size = batch_size
    else:
      # Last step: whatever is left, which may be a partial batch.
      current_batch_size = num_samples - step * batch_size

    with training_context.on_batch(
        step=step, mode=mode, size=current_batch_size) as batch_logs:
      try:
        batch_outs = execution_function(iterator)
      except (StopIteration, errors.OutOfRangeError):
        # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?
        # Are there any other C++ errors tf function should recapture?
        # The only acceptable case here is that the input has a unknown
        # length, and configured to fully consume it.
        if (dataset_size is None
            and steps_per_epoch is None
            and step > 0):
          # The input passed by the user ran out of batches.
          # Now we know the cardinality of the input(dataset or generator).
          steps_per_epoch = step
          aggregator.steps = steps_per_epoch
          progbar.params['steps'] = steps_per_epoch
          progbar.progbar.target = steps_per_epoch
        else:
          callbacks.model.stop_training = True
          # Either value can be None here (for instance an input of unknown
          # size that produced no batch at all), and multiplying by None would
          # raise a TypeError while merely trying to emit the warning.
          if steps_per_epoch is None or total_epochs is None:
            expected_batches = 'an unknown number of'
          else:
            expected_batches = total_epochs * steps_per_epoch
          logging.warning(
              'Your input ran out of data; interrupting training. '
              'Make sure that your dataset or generator can generate at '
              'least `steps_per_epoch * epochs` batches (in this case, '
              '{} batches). You may need to use the repeat() function '
              'when building your dataset.'.format(expected_batches))
        # In either case, break out the loop for training batch.
        # Also note the training_context that data inputs are exhausted, so all
        # the post batch hooks can be skipped.
        batch_logs['data_exhausted'] = True
        break

      if mode != ModeKeys.PREDICT:
        data_batch_size = batch_outs['batch_size']
        batch_outs = (batch_outs['total_loss'] + batch_outs['output_losses']
                      + batch_outs['metrics'])
        # Trust the batch size reported by the execution function over the
        # one predicted from step/num_samples.
        if current_batch_size != data_batch_size:
          batch_logs['size'] = data_batch_size
          current_batch_size = data_batch_size
      else:
        batch_outs = _aggregate_predict_results(strategy, batch_outs, model)

      if step == 0:
        aggregator.create(batch_outs)

      if use_steps:
        aggregator.aggregate(batch_outs)
      else:
        aggregator.aggregate(
            batch_outs,
            batch_start=step * batch_size,
            batch_end=step * batch_size + current_batch_size)
      cbks.make_logs(model, batch_logs, batch_outs, mode)
      step += 1

    # Checked after the batch-end hooks have run, since a callback may have
    # requested the stop.
    if callbacks.model.stop_training:
      break

  # End of an epoch.
  aggregator.finalize()
  return aggregator.results
class Loop(training_utils.TrainingLoop):
  """The training loop for the TF 2.0.

  This class has some existing assumption for runtime, eg eager by default,
  have distribution strategy, etc.
  """

  def fit(
      self, model, x=None, y=None, batch_size=None, epochs=1, verbose=1,
      callbacks=None, validation_split=0., validation_data=None, shuffle=True,
      class_weight=None, sample_weight=None, initial_epoch=0,
      steps_per_epoch=None, validation_steps=None, validation_freq=1, **kwargs):
    """Train `model` for up to `epochs` epochs, optionally validating.

    Args:
      model: the Keras model to train.
      x: training input data, handed to the data adapter.
      y: training target data, handed to the data adapter.
      batch_size: batch size; validated or inferred through the model and the
        distribution utilities before use.
      epochs: index of the final epoch (exclusive upper bound of the loop).
      verbose: verbosity passed to the training progress bar.
      callbacks: user callbacks, wrapped by `cbks.configure_callbacks`.
      validation_split: fraction of the training data held out for validation.
      validation_data: explicit validation data.
      shuffle: whether the training adapter should shuffle.
      class_weight: class weights forwarded to input processing.
      sample_weight: sample weights forwarded to input processing.
      initial_epoch: epoch to start from; may be overridden by
        `model._maybe_load_initial_epoch_from_ckpt`.
      steps_per_epoch: steps per training epoch; falls back to the training
        adapter's size when falsy.
      validation_steps: steps per validation run; falls back to the validation
        adapter's size when falsy.
      validation_freq: forwarded to `training_utils.should_run_validation`.
      **kwargs: accepted but not read by this method.

    Returns:
      `model.history`.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps_per_epoch, x)

    strategy = _get_distribution_strategy(model)
    batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
        strategy, x, batch_size, steps_per_epoch, ModeKeys.TRAIN)
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    # Enter tf.distribute.Strategy scope.
    with strategy.scope():
      training_data_adapter, validation_adapter = _process_training_inputs(
          model,
          x,
          y,
          batch_size=batch_size,
          sample_weights=sample_weight,
          class_weights=class_weight,
          validation_split=validation_split,
          steps_per_epoch=steps_per_epoch,
          shuffle=shuffle,
          validation_data=validation_data,
          validation_steps=validation_steps,
          distribution_strategy=strategy)

      total_samples = _get_total_number_of_samples(training_data_adapter)
      use_sample = total_samples is not None
      do_validation = (validation_adapter is not None)

      if not steps_per_epoch:
        steps_per_epoch = training_data_adapter.get_size()

      training_context = TrainingContext()

      initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
          initial_epoch, ModeKeys.TRAIN)

      training_dataset = training_data_adapter.get_dataset()
      # Raise an error if steps_per_epoch isn't specified but the dataset
      # is infinite.
      # TODO(scottzhu): This check should probably happen in the adapter
      training_utils.infer_steps_for_dataset(
          model,
          training_dataset,
          steps_per_epoch,
          steps_name='steps_per_epoch',
          epochs=0)

      training_dataset = strategy.experimental_distribute_dataset(
          training_dataset)

      training_function = training_v2_utils._get_or_make_execution_function(
          model, ModeKeys.TRAIN)

      training_data_iter = None
      # Only recreate iterator when the data has a fixed length, which will be
      # fully consumed every epoch, or has a unknown length (dataset, generator)
      # and will be fully consumed (steps_per_epoch is None)
      recreate_training_iterator = (training_data_adapter.get_size() is not None
                                    or steps_per_epoch is None)

      if do_validation:
        if not validation_steps:
          validation_steps = validation_adapter.get_size()
        eval_function = training_v2_utils._get_or_make_execution_function(
            model, ModeKeys.TEST)
        eval_data_iter = None

        validation_dataset = validation_adapter.get_dataset()
        # Raise an error if validation_steps isn't specified but the validation
        # dataset is infinite.
        # TODO(scottzhu): This check should probably happen in the adapter
        training_utils.infer_steps_for_dataset(
            model,
            validation_dataset,
            validation_steps,
            steps_name='validation_steps',
            epochs=0)
        validation_dataset = strategy.experimental_distribute_dataset(
            validation_dataset)

      callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=do_validation,
          batch_size=batch_size,
          epochs=epochs,
          steps_per_epoch=steps_per_epoch,
          samples=total_samples,
          count_mode='samples' if use_sample else 'steps',
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=ModeKeys.TRAIN)

      with training_context.on_start(
          model, callbacks, use_sample, verbose, ModeKeys.TRAIN):
        # TODO(scottzhu): Handle TPUStrategy training loop
        for epoch in range(initial_epoch, epochs):
          if training_context.callbacks.model.stop_training:
            break

          # Training
          with training_context.on_epoch(epoch, ModeKeys.TRAIN) as epoch_logs:
            model.reset_metrics()
            if training_data_iter is None or recreate_training_iterator:
              if (training_data_iter is not None and
                  distribution_strategy_context.has_strategy()):
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                ## compositetensor (unless this is more efficient)
                # NOTE(review): this bare attribute access is presumably relied
                # on for a re-initialization side effect -- confirm.
                training_data_iter._initializer  # pylint: disable=pointless-statement
              else:
                training_data_iter = iter(training_dataset)

            training_result = run_one_epoch(
                model,
                training_data_iter,
                training_function,
                dataset_size=training_data_adapter.get_size(),
                batch_size=training_data_adapter.batch_size(),
                strategy=strategy,
                steps_per_epoch=steps_per_epoch,
                num_samples=total_samples,
                mode=ModeKeys.TRAIN,
                training_context=training_context,
                total_epochs=epochs)
            cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)

            # Evaluation. Runs inside the training epoch scope so that the
            # `val_` entries land in `epoch_logs` before the epoch-end hooks.
            if (do_validation and
                training_utils.should_run_validation(validation_freq, epoch) and
                not callbacks.model.stop_training):
              if (eval_data_iter is not None and
                  distribution_strategy_context.has_strategy()):
                # TODO(kaftan): remove this when MultiDeviceIterator is a
                ## compositetensor (unless this is more efficient)
                eval_data_iter._initializer  # pylint: disable=pointless-statement
              else:
                eval_data_iter = iter(validation_dataset)

              val_total_samples = _get_total_number_of_samples(
                  validation_adapter)
              eval_context = TrainingContext()
              with eval_context.on_start(
                  model, callbacks, use_sample, verbose=0, mode=ModeKeys.TEST):
                with eval_context.on_epoch(epoch, ModeKeys.TEST):
                  model.reset_metrics()
                  eval_result = run_one_epoch(
                      model,
                      eval_data_iter,
                      eval_function,
                      dataset_size=validation_adapter.get_size(),
                      batch_size=validation_adapter.batch_size(),
                      strategy=strategy,
                      steps_per_epoch=validation_steps,
                      num_samples=val_total_samples,
                      mode=ModeKeys.TEST,
                      training_context=eval_context,
                      total_epochs=1)
                  cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
                                 prefix='val_')

    return model.history

  def _model_iteration(
      self, model, mode, x=None, y=None, batch_size=None, verbose=1,
      sample_weight=None, steps=None, callbacks=None, **kwargs):
    """Run a single pass over the data in `mode` (shared by evaluate/predict).

    Args:
      model: the Keras model to run.
      mode: the `ModeKeys` value selecting the execution function.
      x: input data, handed to the data adapter.
      y: target data, handed to the data adapter.
      batch_size: batch size; validated or inferred before use.
      verbose: verbosity passed to the progress bar.
      sample_weight: sample weights forwarded to input processing.
      steps: number of steps to run; falls back to the adapter's size when
        falsy.
      callbacks: user callbacks, wrapped by `cbks.configure_callbacks`.
      **kwargs: accepted but not read by this method.

    Returns:
      The result of `run_one_epoch`; unwrapped to its only element when it has
      length 1.
    """
    batch_size = model._validate_or_infer_batch_size(
        batch_size, steps, x)
    strategy = _get_distribution_strategy(model)
    batch_size, steps = dist_utils.process_batch_and_step_size(
        strategy, x, batch_size, steps, mode)
    dist_utils.validate_callbacks(input_callbacks=callbacks,
                                  optimizer=model.optimizer)
    # Enter tf.distribute.Strategy scope.
    with strategy.scope():
      adapter = _process_inputs(
          model,
          x,
          y,
          batch_size=batch_size,
          sample_weights=sample_weight,
          steps=steps,
          distribution_strategy=strategy)
      total_samples = _get_total_number_of_samples(adapter)
      use_sample = total_samples is not None

      if not steps:
        steps = adapter.get_size()

      training_context = TrainingContext()

      dataset = adapter.get_dataset()
      # Raise an error if `steps` isn't specified but the dataset
      # is infinite.
      # TODO(scottzhu): This check should probably happen in the adapter
      training_utils.infer_steps_for_dataset(
          model, dataset, steps, steps_name='steps', epochs=0)
      dataset = strategy.experimental_distribute_dataset(dataset)

      execution_function = training_v2_utils._get_or_make_execution_function(
          model, mode)

      data_iterator = iter(dataset)

      callbacks = cbks.configure_callbacks(
          callbacks,
          model,
          do_validation=False,
          batch_size=batch_size,
          epochs=1,
          steps_per_epoch=steps,
          # Pass the sample count itself (as `fit` does), not the boolean
          # `use_sample`; otherwise the callback params record `True` as the
          # number of samples.
          samples=total_samples,
          count_mode='samples' if use_sample else 'steps',
          verbose=0,  # Handle ProgBarLogger separately in this loop.
          mode=mode)

      with training_context.on_start(
          model, callbacks, use_sample, verbose, mode):
        # TODO(scottzhu): Handle TPUStrategy training loop
        with training_context.on_epoch(0, mode) as epoch_logs:
          model.reset_metrics()
          result = run_one_epoch(
              model,
              data_iterator,
              execution_function,
              dataset_size=adapter.get_size(),
              batch_size=adapter.batch_size(),
              strategy=strategy,
              steps_per_epoch=steps,
              num_samples=total_samples,
              mode=mode,
              training_context=training_context,
              total_epochs=1)
          cbks.make_logs(model, epoch_logs, result, mode)

    if len(result) == 1:
      result = result[0]
    return result

  def evaluate(
      self, model, x=None, y=None, batch_size=None, verbose=1,
      sample_weight=None, steps=None, callbacks=None, **kwargs):
    """Evaluate `model` by running `_model_iteration` in `ModeKeys.TEST`."""
    return self._model_iteration(
        model, ModeKeys.TEST, x=x, y=y, batch_size=batch_size, verbose=verbose,
        sample_weight=sample_weight, steps=steps, callbacks=callbacks, **kwargs)

  def predict(self, model, x, batch_size=None, verbose=0, steps=None,
              callbacks=None, **kwargs):
    """Predict with `model` by running `_model_iteration` in `ModeKeys.PREDICT`."""
    return self._model_iteration(
        model, ModeKeys.PREDICT, x=x, batch_size=batch_size, verbose=verbose,
        steps=steps, callbacks=callbacks, **kwargs)
def _get_distribution_strategy(model):
  """Return the strategy to run `model` under.

  Prefers the strategy recorded on the model at compile time. When that is
  unset (falsy), e.g. the model was never compiled but is now predicting, the
  currently active strategy is used instead.
  """
  compile_time_strategy = model._compile_time_distribution_strategy
  if not compile_time_strategy:
    return distribution_strategy_context.get_strategy()
  # Re-read the attribute, as the original code does, rather than returning
  # the local used for the truthiness check.
  return model._compile_time_distribution_strategy
def _process_training_inputs(model, x, y, batch_size=None,
                             sample_weights=None, class_weights=None,
                             steps_per_epoch=None, validation_split=0.,
                             validation_data=None, validation_steps=None,
                             shuffle=True, distribution_strategy=None):
  """Build the training and validation data adapters used by fit().

  A `validation_split` strictly between 0 and 1 carves the validation set out
  of the training data; otherwise `validation_data`, when given, is wrapped in
  its own adapter.

  Returns:
    A `(train_adapter, val_adapter)` tuple; `val_adapter` is None when no
    validation input applies.

  Raises:
    ValueError: if both an effective `validation_split` and `validation_data`
      are given, if the selected adapter does not support splitting, or if
      `validation_steps` is given without `validation_data`.
  """
  # Only a fraction strictly inside (0, 1) counts as a split request.
  split_requested = bool(validation_split and 0. < validation_split < 1.)

  if split_requested and validation_data:
    raise ValueError('validation_data and validation_split cannot be used '
                     'at same time.')

  adapter_cls = data_adapter.select_data_adapter(x, y)

  if split_requested:
    # The data has to be split before it is handed to the adapter, so only
    # adapters that can take the split pieces are accepted.
    if adapter_cls not in _ADAPTER_FOR_VALIDATION_SPLIT:
      raise ValueError(
          '`validation_split` argument is not supported when '
          'data adapter is {}. Received: x={}, validation_split={}'.format(
              adapter_cls, x, validation_split))

    x, y, sample_weights = model._standardize_user_data(
        x, y,
        sample_weight=sample_weights,
        class_weight=class_weights,
        batch_size=batch_size,
        check_steps=True,
        steps=steps_per_epoch)

    split_pieces = training_utils.split_training_and_validation_data(
        x, y, sample_weights, validation_split)
    (x, y, sample_weights,
     val_x, val_y, val_sample_weights) = split_pieces

    train_adapter = adapter_cls(
        x, y,
        batch_size=batch_size,
        sample_weights=sample_weights,
        shuffle=shuffle,
        distribution_strategy=distribution_strategy)
    val_adapter = adapter_cls(
        val_x, val_y,
        sample_weights=val_sample_weights,
        batch_size=batch_size,
        distribution_strategy=distribution_strategy)
    return train_adapter, val_adapter

  train_adapter = _process_inputs(
      model, x, y,
      sample_weights=sample_weights,
      batch_size=batch_size,
      class_weights=class_weights,
      shuffle=shuffle,
      steps=steps_per_epoch,
      distribution_strategy=distribution_strategy)

  if not validation_data:
    if validation_steps:
      raise ValueError('`validation_steps` should not be specified if '
                       '`validation_data` is None.')
    return train_adapter, None

  unpacked = training_utils.unpack_validation_data(validation_data)
  val_x, val_y, val_sample_weights = unpacked
  # If no batch size was given, evaluate with the training adapter's batch
  # size. This covers generator/sequence training input paired with numpy
  # validation input.
  if not batch_size:
    batch_size = train_adapter.batch_size()
  val_adapter = _process_inputs(
      model, val_x, val_y,
      sample_weights=val_sample_weights,
      batch_size=batch_size,
      class_weights=class_weights,
      steps=validation_steps,
      distribution_strategy=distribution_strategy)
  return train_adapter, val_adapter
def _process_inputs(model, x, y, batch_size=None, sample_weights=None,
                    class_weights=None, shuffle=False, steps=None,
                    distribution_strategy=None):
  """Wrap the fit/eval/predict inputs in the matching data adapter.

  Inputs whose adapter is listed in `_ADAPTER_FOR_STANDARDIZE_USER_DATA` are
  first passed through `model._standardize_user_data`. For any other adapter
  the model is instead prepared from the adapter's dataset after the adapter
  has been built.

  Returns:
    The constructed data adapter instance.
  """
  adapter_cls = data_adapter.select_data_adapter(x, y)
  standardizable = adapter_cls in _ADAPTER_FOR_STANDARDIZE_USER_DATA

  if standardizable:
    standardized = model._standardize_user_data(
        x,
        y,
        sample_weight=sample_weights,
        class_weight=class_weights,
        batch_size=batch_size,
        check_steps=True,
        steps=steps)
    x, y, sample_weights = standardized

  adapter = adapter_cls(
      x, y,
      batch_size=batch_size,
      steps=steps,
      sample_weights=sample_weights,
      shuffle=shuffle,
      distribution_strategy=distribution_strategy)

  if not standardizable:
    # Fallback for input types that do not work with _standardize_user_data.
    adapter_dataset = adapter.get_dataset()
    training_v2_utils._prepare_model_with_inputs(model, adapter_dataset)
  return adapter
def _get_total_number_of_samples(adapter):
  """Return the adapter's total sample count, or None when it is unknown.

  The count is the adapter's size multiplied by its batch size, reduced by the
  shortfall of the partial batch when the adapter reports one. None is returned
  when either the size or the batch size is falsy.
  """
  size = adapter.get_size()
  if not size:
    return None
  per_batch = adapter.batch_size()
  if not per_batch:
    return None

  if not adapter.has_partial_batch():
    return size * per_batch
  # The partial batch was counted as a full one above; remove the shortfall.
  shortfall = per_batch - adapter.partial_batch_size()
  return size * per_batch - shortfall
def _aggregate_predict_results(strategy, batch_outs, model):
  """Merge the per-replica prediction outputs into one entry per model output.

  `batch_outs` is treated as a flat list in which each model output occupies
  `strategy.num_replicas_in_sync` consecutive entries. Each such group is
  flattened and concatenated along the batch dimension.

  Returns:
    A list with one concatenated result per entry of `model.outputs`.
  """
  if not isinstance(batch_outs, list):
    batch_outs = [batch_outs]

  replicas = strategy.num_replicas_in_sync
  merged_outs = []
  for output_index in range(len(model.outputs)):
    first = output_index * replicas
    replica_outs = batch_outs[first:first + replicas]
    flat_outs = nest.flatten(replica_outs)
    merged_outs.append(dist_utils.concat_along_batch_dimension(flat_outs))
  return merged_outs
class TrainingContext(object):
  """Scoped hooks that drive the callbacks and the progress bar together.

  `on_start` must be entered first: it stores the callback list and progress
  bar on the instance, and `on_epoch` / `on_batch` read them from there.
  """

  @tf_contextlib.contextmanager
  def on_start(self, model, callbacks=None, use_samples=False, verbose=0,
               mode=ModeKeys.TRAIN):
    """Scope covering a whole fit/evaluate/predict run.

    Args:
      model: the model the progress bar is created for.
      callbacks: the configured callback list. Its attributes are read
        immediately, so despite the default it must not be None.
      use_samples: count progress in samples when true, otherwise in steps.
      verbose: verbosity written into the progress bar params.
      mode: the mode whose begin/end hooks are invoked.
    """
    # TODO(omalleyt): Handle ProgBar as part of Callbacks once hooks are ready.
    count_mode = 'samples' if use_samples else 'steps'
    progress_bar = training_utils.get_progbar(model, count_mode)
    progress_bar.params = callbacks.params
    progress_bar.params['verbose'] = verbose

    callbacks.model.stop_training = False
    callbacks._call_begin_hook(mode)
    progress_bar.on_train_begin()

    # Keep both around so that on_epoch/on_batch can reach them.
    self.callbacks = callbacks
    self.progbar = progress_bar

    try:
      yield
    finally:
      # The whole run is over, whether it finished or raised.
      self.callbacks._call_end_hook(mode)

  @tf_contextlib.contextmanager
  def on_epoch(self, epoch=0, mode=ModeKeys.TRAIN):
    """Scope covering one epoch; yields the epoch's logs dict."""
    logs = {}
    # Callback epoch hooks only apply to `fit`; the progress bar is always
    # notified.
    is_fit = mode == ModeKeys.TRAIN

    if is_fit:
      self.callbacks.on_epoch_begin(epoch, logs)
    self.progbar.on_epoch_begin(epoch, logs)

    try:
      yield logs
    finally:
      if is_fit:
        self.callbacks.on_epoch_end(epoch, logs)
      self.progbar.on_epoch_end(epoch, logs)

  @tf_contextlib.contextmanager
  def on_batch(self, step=0, mode=ModeKeys.TRAIN, size=1):
    """Scope covering one batch; yields the batch's logs dict.

    The caller may set `'data_exhausted'` in the yielded dict to signal that
    no batch was actually run; the key is removed on exit and, when it was
    truthy, the batch-end hooks are skipped.
    """
    logs = {'batch': step, 'size': size}
    self.callbacks._call_batch_hook(mode, 'begin', step, logs)
    self.progbar.on_batch_begin(step, logs)

    try:
      yield logs
    finally:
      data_exhausted = logs.pop('data_exhausted', False)
      if not data_exhausted:
        self.callbacks._call_batch_hook(mode, 'end', step, logs)
        self.progbar.on_batch_end(step, logs)