| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Part of the Keras training engine related to distributed training. |
| """ |
| # pylint: disable=protected-access |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| import numpy as np |
| from tensorflow.python.framework import constant_op |
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
| from tensorflow.python.keras import backend as K |
| from tensorflow.python.keras import callbacks as cbks |
| from tensorflow.python.keras import optimizers |
| from tensorflow.python.keras.engine import distributed_training_utils |
| from tensorflow.python.keras.utils.generic_utils import Progbar |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.training import distribute as distribute_lib |
| |
| |
| # TODO(priyag, sourabhbajaj): Refactor this file to address code duplication. |
| |
| |
| def fit_loop( |
| model, |
| iterator, |
| epochs=100, |
| verbose=1, |
| callbacks=None, |
| val_iterator=None, |
| initial_epoch=0, |
| steps_per_epoch=None, |
| validation_steps=None): |
| """Fit loop for training with DistributionStrategy. |
| |
| Arguments: |
| model: Keras Model instance. |
| iterator: Iterator for input data. |
    epochs: Number of times to iterate over the data.
    verbose: Integer. Verbosity mode: 0, 1, or 2.
    callbacks: List of callbacks to be called during training.
| val_iterator: Iterator for validation data. |
    initial_epoch: Epoch at which to start training
      (useful for resuming a previous training run).
| steps_per_epoch: Total number of steps (batches of samples) |
| before declaring one epoch finished and starting the |
| next epoch. Ignored with the default value of `None`. |
| validation_steps: Number of steps to run validation for |
| (only if doing validation from data tensors). |
| Ignored with the default value of `None`. |
| |
| Returns: |
| `History` object. |
| |
| Raises: |
| ValueError: in case of invalid arguments. |
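
  Example:
    A minimal sketch of how this loop is typically reached. The strategy
    constructor below is an illustrative assumption (its exact location
    varies across TF versions):

      strategy = tf.contrib.distribute.MirroredStrategy()
      model.compile(optimizer='sgd', loss='mse', distribute=strategy)
      # `model.fit(...)` dispatches to this loop when a
      # DistributionStrategy has been configured via `compile`.
      model.fit(dataset, epochs=2, steps_per_epoch=100)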
| """ |
| current_strategy = model._distribution_strategy |
| |
| # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. |
| if current_strategy.__class__.__name__ == 'TPUStrategy': |
| return _experimental_fit_loop( |
| model, iterator, epochs, verbose, callbacks, initial_epoch, |
| steps_per_epoch) |
| |
| if not model._grouped_model: |
| clone_model_on_towers(model, current_strategy, make_callback_model=True) |
| |
| def _per_device_train_function(model): |
| model._make_train_function() |
| return (model.train_function.inputs, |
| model.train_function.outputs, |
| model.train_function.updates_op, |
| model.train_function.session_kwargs) |
| |
| inputs, targets = _get_input_from_iterator(iterator, model) |
| with current_strategy.scope(): |
| # Create train ops on each of the devices when we call |
| # `_per_device_train_function`. |
| (grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) = current_strategy.call_for_each_tower( |
| _per_device_train_function, model._grouped_model) |
    # Unwrap all the per-device values returned from `call_for_each_tower`.
    # Unwrapping yields flat lists from which we can construct a single train
    # function composed of the update ops on all the devices over which the
    # model is distributed.
| (all_inputs, all_outputs, all_updates, |
| all_session_args) = distributed_training_utils.unwrap_values( |
| current_strategy, grouped_inputs, grouped_outputs, |
| grouped_updates, grouped_session_args, with_loss_tensor=True) |
| |
    # Dataset inputs and targets are also per-device values that need to be
    # unwrapped.
| dataset_inputs = distributed_training_utils.flatten_perdevice_values( |
| current_strategy, inputs) |
| dataset_targets = distributed_training_utils.flatten_perdevice_values( |
| current_strategy, targets) |
| |
| # Create a train function that is composed of all the parameters above. |
| distributed_train_function = K.Function( |
| all_inputs, all_outputs, |
| updates=all_updates, |
| name='distributed_train_function', |
| **all_session_args) |
| |
| # We need to set sample_weights to None since there are sample weight |
| # placeholders that are created with default values. |
| sample_weights = [None for _ in range(len(model.outputs) * |
| current_strategy.num_towers)] |
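  # If the learning phase is a placeholder rather than a fixed integer, feed
  # 1 so that layers such as dropout run in training mode.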
| if model.uses_learning_phase and not isinstance(K.learning_phase(), int): |
| ins = dataset_inputs + dataset_targets + sample_weights + [1] |
| else: |
| ins = dataset_inputs + dataset_targets |
| |
  do_validation = bool(validation_steps)
| |
| # Copy the weights from the original model to each of the replicated models. |
| orig_model_weights = model.get_weights() |
| with current_strategy.scope(): |
| distributed_model = current_strategy.unwrap(model._grouped_model)[0] |
| distributed_training_utils.set_weights( |
| current_strategy, distributed_model, orig_model_weights) |
| |
| callbacks = cbks.configure_callbacks( |
| callbacks, |
| model, |
| do_validation=do_validation, |
| val_inputs=None, |
| val_targets=None, |
| epochs=epochs, |
| steps_per_epoch=steps_per_epoch, |
| verbose=verbose) |
| out_labels = model.metrics_names or [] |
| callbacks.on_train_begin() |
| |
| assert steps_per_epoch is not None |
| |
| for epoch in range(initial_epoch, epochs): |
| callbacks.on_epoch_begin(epoch) |
| epoch_logs = {} |
| for step_index in range(steps_per_epoch): |
| batch_logs = {'batch': step_index, 'size': 1} |
| callbacks.on_batch_begin(step_index, batch_logs) |
| try: |
| outs = distributed_train_function(ins) |
| except errors.OutOfRangeError: |
| logging.warning('Your dataset iterator ran out of data; ' |
| 'interrupting training. Make sure that your dataset ' |
| 'can generate at least `steps_per_epoch * epochs` ' |
                        'batches (in this case, %d batches).' %
                        (steps_per_epoch * epochs))
| break |
| |
| if not isinstance(outs, list): |
| outs = [outs] |
| |
| outs = _aggregate_metrics_across_towers( |
| current_strategy.num_towers, out_labels, outs) |
| for l, o in zip(out_labels, outs): |
| batch_logs[l] = o |
| callbacks.on_batch_end(step_index, batch_logs) |
| if callbacks.model.stop_training: |
| break |
| if do_validation: |
| val_outs = test_loop( |
| model, |
| val_iterator, |
| steps=validation_steps, |
| verbose=0) |
| if not isinstance(val_outs, list): |
| val_outs = [val_outs] |
      # Validation outputs use the same metric labels as training.
| for l, o in zip(out_labels, val_outs): |
| epoch_logs['val_' + l] = o |
| |
| callbacks.on_epoch_end(epoch, epoch_logs) |
| if callbacks.model.stop_training: |
| break |
| callbacks.on_train_end() |
| |
| # Copy the weights back from the replicated model to the original model. |
| with current_strategy.scope(): |
| updated_weights = current_strategy.unwrap( |
| model._grouped_model)[0].get_weights() |
| model.set_weights(updated_weights) |
| return model.history |
| |
| |
| def _experimental_fit_loop( |
| model, |
| iterator, |
| epochs=100, |
| verbose=1, |
| callbacks=None, |
| initial_epoch=0, |
| steps_per_epoch=None): |
| """Fit loop for training with TPU DistributionStrategy. |
| |
| Arguments: |
| model: Keras Model instance. |
    iterator: Iterator that returns inputs and targets.
    epochs: Number of times to iterate over the data.
    verbose: Integer. Verbosity mode: 0, 1, or 2.
    callbacks: List of callbacks to be called during training.
    initial_epoch: Epoch at which to start training
      (useful for resuming a previous training run).
| steps_per_epoch: Total number of steps (batches of samples) |
| before declaring one epoch finished and starting the |
| next epoch. Ignored with the default value of `None`. |
| |
| Returns: |
    `History` object.
| |
| Raises: |
| ValueError: in case of invalid arguments. |
| """ |
| current_strategy = model._distribution_strategy |
| |
| # TODO(priyag): Add validation that shapes are fully defined for TPU case. |
| |
| K.get_session().run(current_strategy.initialize()) |
| |
| def _per_device_train_function(model): |
| model._make_train_function() |
| return (model.train_function.inputs, |
| model.train_function.outputs, |
| model.train_function.updates_op, |
| model.train_function.session_kwargs) |
| |
| # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. |
| K.set_learning_phase(1) |
| |
| def step_fn(ctx, inputs, targets): |
| """Clones the model and calls make_train_function.""" |
| # TODO(priyag, sourabhbajaj): The model gets cloned every time |
| # fit/test/predict is called. We should look into caching this keyed on |
| # input shapes. |
| clone_model_on_towers( |
| model, |
| current_strategy, |
| make_callback_model=True, |
| inputs=inputs, |
| targets=targets) |
| |
| (grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) = current_strategy.call_for_each_tower( |
| _per_device_train_function, model._grouped_model) |
| (all_inputs, all_outputs, all_updates, |
| all_session_args) = distributed_training_utils.unwrap_values( |
| current_strategy, grouped_inputs, grouped_outputs, |
| grouped_updates, grouped_session_args) |
| combined_fn = K.Function( |
| all_inputs, all_outputs, |
| updates=all_updates, |
| name='distributed_train_function', |
| **all_session_args) |
| |
| out_labels = model.metrics_names or [] |
| for label, output in zip(out_labels, combined_fn.outputs): |
| if label == 'loss': |
| aggregation = distribute_lib.get_loss_reduction() |
| else: |
        # We aggregate all other metrics using mean for now. This is a
        # temporary workaround until new metrics are in place.
| aggregation = variable_scope.VariableAggregation.MEAN |
| ctx.set_last_step_output(label, output, aggregation) |
| |
    # TODO(priyag, sourabhbajaj): feed_dict, session kwargs, run options, and
    # run_metadata from the combined_fn are ignored for now. These should be
    # handled appropriately.
| return combined_fn.updates_op |
| |
| # Add initial dummy values for loss and other metric tensors. |
| initial_loop_values = {} |
| initial_loop_values['loss'] = constant_op.constant(1e7) |
| for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors): |
| initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype) |
| |
| if steps_per_epoch is None: |
    raise ValueError('`steps_per_epoch` must be specified when calling '
                     '`fit` with a TPU DistributionStrategy.')
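  # `steps_per_run_var` controls how many steps run on the device per session
  # call; it is reloaded below before a final, shorter chunk of an epoch.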
| steps_per_run_var = K.variable( |
| value=min(steps_per_epoch, current_strategy.steps_per_run), |
| dtype='int32', |
| name='steps_per_run_var') |
| |
| with current_strategy.scope(): |
| ctx = current_strategy.run_steps_on_dataset( |
| step_fn, iterator, iterations=steps_per_run_var, |
| initial_loop_values=initial_loop_values) |
| |
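  # `ctx.run_op` executes `steps_per_run_var` steps per `session.run` call;
  # `ctx.last_step_outputs` holds the aggregated outputs from the last of
  # those steps.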
| train_op = ctx.run_op |
| output_tensors = ctx.last_step_outputs |
| |
| # Copy the weights from the original model to each of the replicated models. |
| orig_model_weights = model.get_weights() |
| with current_strategy.scope(): |
| distributed_model = current_strategy.unwrap(model._grouped_model)[0] |
| distributed_training_utils.set_weights( |
| current_strategy, distributed_model, orig_model_weights) |
| callbacks = cbks.configure_callbacks( |
| callbacks, |
| model, |
| do_validation=False, |
| val_inputs=None, |
| val_targets=None, |
| epochs=epochs, |
| steps_per_epoch=steps_per_epoch, |
| verbose=verbose) |
| # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback |
| # TODO(priyag, sourabhbajaj): Add validation. |
| |
  # Compute the chunk sizes to run on the device per session call: full
  # chunks of `steps_per_run`, plus a final shorter chunk if needed.
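  # For example, steps_per_epoch=10 with steps_per_run=4 yields
  # steps_to_run == [4, 4, 2].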
| steps_to_run = [current_strategy.steps_per_run] * ( |
| steps_per_epoch // current_strategy.steps_per_run) |
| if steps_per_epoch % current_strategy.steps_per_run: |
| steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run) |
| |
| callbacks.on_train_begin() |
| for epoch in range(initial_epoch, epochs): |
| callbacks.on_epoch_begin(epoch) |
| epoch_logs = {} |
| step_index = 0 |
| prev_step_count = None |
| for step_count in steps_to_run: |
| batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count} |
| callbacks.on_batch_begin(step_index, batch_logs) |
| if prev_step_count is None or step_count != prev_step_count: |
| steps_per_run_var.load(step_count, K.get_session()) |
| prev_step_count = step_count |
| try: |
| _, outputs = K.get_session().run([train_op, output_tensors]) |
| except errors.OutOfRangeError: |
| logging.warning('Your dataset iterator ran out of data; ' |
| 'interrupting training. Make sure that your dataset ' |
| 'can generate at least `steps_per_epoch * epochs` ' |
                      'batches (in this case, %d batches).' %
                      (steps_per_epoch * epochs))
| break |
| |
| batch_logs.update(outputs) |
| callbacks.on_batch_end(step_index, batch_logs) |
| step_index = step_index + step_count |
| if callbacks.model.stop_training: |
| break |
| |
| callbacks.on_epoch_end(epoch, epoch_logs) |
| if callbacks.model.stop_training: |
| break |
| callbacks.on_train_end() |
| |
| # Copy the weights back from the replicated model to the original model. |
| with current_strategy.scope(): |
| updated_weights = current_strategy.unwrap( |
| model._grouped_model)[0].get_weights() |
| model.set_weights(updated_weights) |
| |
| K.get_session().run(current_strategy.finalize()) |
| return model.history |
| |
| |
| def test_loop(model, iterator, verbose=0, steps=None): |
| """Test loop for evaluating with DistributionStrategy. |
| |
| Arguments: |
| model: Keras Model instance. |
| iterator: Iterator for input data. |
    verbose: Integer. Verbosity mode: 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring the evaluation round finished.
| Ignored with the default value of `None`. |
| |
| Returns: |
| Scalar loss (if the model has a single output and no metrics) |
| or list of scalars (if the model has multiple outputs |
| and/or metrics). The attribute `model.metrics_names` will give you |
| the display labels for the outputs. |
| """ |
| current_strategy = model._distribution_strategy |
| |
| # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. |
| if current_strategy.__class__.__name__ == 'TPUStrategy': |
| return _experimental_test_loop(model, iterator, verbose, steps) |
| |
| if not model._grouped_model: |
| clone_model_on_towers(model, current_strategy) |
| |
| def _per_device_test_function(model): |
| model._make_test_function() |
| return (model.test_function.inputs, |
| model.test_function.outputs, |
| model.test_function.updates_op, |
| model.test_function.session_kwargs) |
| |
| inputs, targets = _get_input_from_iterator(iterator, model) |
| with current_strategy.scope(): |
| (grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) = current_strategy.call_for_each_tower( |
| _per_device_test_function, model._grouped_model) |
| |
| (all_inputs, all_outputs, all_updates, |
| all_session_args) = distributed_training_utils.unwrap_values( |
| current_strategy, grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args, with_loss_tensor=True) |
| |
| dataset_inputs = distributed_training_utils.flatten_perdevice_values( |
| current_strategy, inputs) |
| dataset_targets = distributed_training_utils.flatten_perdevice_values( |
| current_strategy, targets) |
| |
| distributed_test_function = K.Function( |
| all_inputs, all_outputs, |
| updates=all_updates, |
| name='distributed_test_function', |
| **all_session_args) |
| |
| # We need to set sample_weights to None since there are sample weight |
| # placeholders that are created with default values. |
| sample_weights = [None for _ in range(len(model.outputs) * |
| current_strategy.num_towers)] |
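  # If the learning phase is a placeholder rather than a fixed integer, feed
  # 0 so that layers such as dropout run in inference mode.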
| if model.uses_learning_phase and not isinstance(K.learning_phase(), int): |
| ins = dataset_inputs + dataset_targets + sample_weights + [0] |
| else: |
| ins = dataset_inputs + dataset_targets |
| |
| outs = [] |
| if verbose == 1: |
| progbar = Progbar(target=steps) |
| |
| # Copy the weights from the original model to each of the replicated models. |
| orig_model_weights = model.get_weights() |
| with current_strategy.scope(): |
| distributed_model = current_strategy.unwrap(model._grouped_model)[0] |
| distributed_training_utils.set_weights( |
| current_strategy, distributed_model, orig_model_weights) |
| |
| assert steps is not None |
| for step in range(steps): |
| batch_outs = distributed_test_function(ins) |
| batch_outs = _aggregate_metrics_across_towers( |
| current_strategy.num_towers, model.metrics_names, batch_outs) |
| if isinstance(batch_outs, list): |
| if step == 0: |
| outs = [0.] * len(batch_outs) |
| for i, batch_out in enumerate(batch_outs): |
| outs[i] += batch_out |
| else: |
| if step == 0: |
| outs.append(0.) |
| outs[0] += batch_outs |
    if verbose == 1:
      progbar.update(step + 1)
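  # Convert the accumulated per-step totals into means over all steps.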
| for i in range(len(outs)): |
| outs[i] /= steps |
| |
| if len(outs) == 1: |
| return outs[0] |
| return outs |
| |
| |
| def _experimental_test_loop(model, iterator, verbose=0, steps=None): |
| """Test loop for evaluating with TPU DistributionStrategy. |
| |
| Arguments: |
| model: Keras Model instance. |
| iterator: Iterator for input data. |
    verbose: Integer. Verbosity mode: 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring the evaluation round finished.
| Ignored with the default value of `None`. |
| |
| Returns: |
| Scalar loss (if the model has a single output and no metrics) |
| or list of scalars (if the model has multiple outputs |
| and/or metrics). The attribute `model.metrics_names` will give you |
| the display labels for the outputs. |
| """ |
| current_strategy = model._distribution_strategy |
| K.get_session().run(current_strategy.initialize()) |
| |
| def _per_device_test_function(model): |
| model._make_test_function() |
| return (model.test_function.inputs, |
| model.test_function.outputs, |
| model.test_function.updates_op, |
| model.test_function.session_kwargs) |
| |
| # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. |
| K.set_learning_phase(0) |
| |
| def step_fn(ctx, inputs, targets): |
| """Clones the model and calls make_test_function.""" |
| # TODO(priyag, sourabhbajaj): The model gets cloned every time |
| # fit/test/predict is called. We should look into caching this keyed on |
| # input shapes. |
| clone_model_on_towers( |
| model, |
| current_strategy, |
| make_callback_model=False, |
| inputs=inputs, |
| targets=targets) |
| |
| (grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) = current_strategy.call_for_each_tower( |
| _per_device_test_function, model._grouped_model) |
| |
| (all_inputs, all_outputs, all_updates, |
| all_session_args) = distributed_training_utils.unwrap_values( |
| current_strategy, grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) |
| |
| combined_fn = K.Function( |
| all_inputs, all_outputs, |
| updates=all_updates, |
| name='distributed_test_function', |
| **all_session_args) |
| |
| for label, output in zip(model.metrics_names, combined_fn.outputs): |
| if label == 'loss': |
| aggregation = distribute_lib.get_loss_reduction() |
| else: |
        # We aggregate all other metrics using mean for now. This is a
        # temporary workaround until new metrics are in place.
| aggregation = variable_scope.VariableAggregation.MEAN |
| ctx.set_last_step_output(label, output, aggregation) |
| |
| return combined_fn.updates_op |
| |
| # Add initial dummy values for loss and other metric tensors. |
| initial_loop_values = {} |
| initial_loop_values['loss'] = constant_op.constant(1e7) |
| for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors): |
| initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype) |
| |
| with current_strategy.scope(): |
| # TODO(priyag): Use steps_per_run when we use new metrics as they will |
| # allow handling metric computation at each step using variables. |
| ctx = current_strategy.run_steps_on_dataset( |
| step_fn, iterator, iterations=1, |
| initial_loop_values=initial_loop_values) |
| |
| test_op = ctx.run_op |
| output_tensors = ctx.last_step_outputs |
| |
| if verbose == 1: |
| progbar = Progbar(target=steps) |
| |
| # Copy the weights from the original model to each of the replicated models. |
| orig_model_weights = model.get_weights() |
| with current_strategy.scope(): |
| distributed_model = current_strategy.unwrap(model._grouped_model)[0] |
| distributed_training_utils.set_weights( |
| current_strategy, distributed_model, orig_model_weights) |
| |
| assert steps is not None |
| outs = [0.] * len(model.metrics_names) |
| for step in range(steps): |
| _, batch_outs = K.get_session().run([test_op, output_tensors]) |
| for i, label in enumerate(model.metrics_names): |
| outs[i] += batch_outs[label] |
    if verbose == 1:
      progbar.update(step + 1)
| for i in range(len(outs)): |
    outs[i] /= steps
| |
| K.get_session().run(current_strategy.finalize()) |
| |
| if len(outs) == 1: |
| return outs[0] |
| return outs |
| |
| |
| def predict_loop(model, iterator, verbose=0, steps=None): |
| """Predict loop for predicting with DistributionStrategy. |
| |
| Arguments: |
| model: Keras Model instance. |
| iterator: Iterator for input data. |
    verbose: Integer. Verbosity mode: 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring the prediction loop finished.
| Ignored with the default value of `None`. |
| |
| Returns: |
| Array of predictions (if the model has a single output) |
| or list of arrays of predictions |
| (if the model has multiple outputs). |
| """ |
| current_strategy = model._distribution_strategy |
| |
| # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. |
| if current_strategy.__class__.__name__ == 'TPUStrategy': |
| return _experimental_predict_loop(model, iterator, verbose, steps) |
| |
| if not model._grouped_model: |
| clone_model_on_towers(model, current_strategy) |
| |
| def _per_device_predict_function(model): |
| model._make_predict_function() |
| return (model.predict_function.inputs, |
| model.predict_function.outputs, |
| model.predict_function.updates_op, |
| model.predict_function.session_kwargs) |
| |
| inputs, _ = _get_input_from_iterator(iterator, model) |
| with current_strategy.scope(): |
| (grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) = current_strategy.call_for_each_tower( |
| _per_device_predict_function, model._grouped_model) |
| |
| (all_inputs, all_outputs, all_updates, |
| all_session_args) = distributed_training_utils.unwrap_values( |
| current_strategy, grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) |
| |
| dataset_inputs = distributed_training_utils.flatten_perdevice_values( |
| current_strategy, inputs) |
| |
| distributed_predict_function = K.Function( |
| all_inputs, all_outputs, |
| updates=all_updates, |
| name='distributed_predict_function', |
| **all_session_args) |
| |
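  # If the learning phase is a placeholder rather than a fixed integer, feed
  # 0 so that layers such as dropout run in inference mode.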
| if model.uses_learning_phase and not isinstance(K.learning_phase(), int): |
| ins = dataset_inputs + [0] |
| else: |
| ins = dataset_inputs |
| |
| if verbose == 1: |
| progbar = Progbar(target=steps) |
| |
| # Copy the weights from the original model to each of the replicated models. |
| orig_model_weights = model.get_weights() |
| with current_strategy.scope(): |
| distributed_model = current_strategy.unwrap(model._grouped_model)[0] |
| distributed_training_utils.set_weights( |
| current_strategy, distributed_model, orig_model_weights) |
| |
| if steps is not None: |
| # Since we do not know how many samples we will see, we cannot pre-allocate |
| # the returned Numpy arrays. Instead, we store one array per batch seen |
| # and concatenate them upon returning. |
| unconcatenated_outs = [] |
| for step in range(steps): |
| batch_outs = distributed_predict_function(ins) |
| if not isinstance(batch_outs, list): |
| batch_outs = [batch_outs] |
| if step == 0: |
| for _ in batch_outs: |
| unconcatenated_outs.append([]) |
| # TODO(anjalisridhar): Should combine the outputs from multiple towers |
| # correctly here. |
| for i, batch_out in enumerate(batch_outs): |
| unconcatenated_outs[i].append(batch_out) |
      if verbose == 1:
        progbar.update(step + 1)
| if len(unconcatenated_outs) == 1: |
| return np.concatenate(unconcatenated_outs[0], axis=0) |
| return [ |
| np.concatenate(unconcatenated_outs[i], axis=0) |
| for i in range(len(unconcatenated_outs)) |
| ] |
| |
| |
| def _experimental_predict_loop(model, iterator, verbose=0, steps=None): |
| """Predict loop for predicting with TPU DistributionStrategy. |
| |
| Arguments: |
| model: Keras Model instance. |
| iterator: Iterator for input data. |
    verbose: Integer. Verbosity mode: 0 or 1.
    steps: Total number of steps (batches of samples)
      before declaring the prediction loop finished.
| Ignored with the default value of `None`. |
| |
| Returns: |
| Array of predictions (if the model has a single output) |
| or list of arrays of predictions |
| (if the model has multiple outputs). |
| """ |
| current_strategy = model._distribution_strategy |
| K.get_session().run(current_strategy.initialize()) |
| |
| # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. |
| K.set_learning_phase(0) |
| |
| def _per_device_predict_function(model): |
| model._make_predict_function() |
| return (model.predict_function.inputs, |
| model.predict_function.outputs, |
| model.predict_function.updates_op, |
| model.predict_function.session_kwargs) |
| |
| def step_fn(ctx, inputs, targets): |
| """Clones the model and calls make_predict_function.""" |
| |
| # TODO(anjalisridhar): Support predict input correctly as it will not |
| # contain targets, only inputs. |
| del targets |
| |
| # TODO(priyag, sourabhbajaj): The model gets cloned every time |
| # fit/test/predict is called. We should look into caching this keyed on |
| # input shapes. |
| clone_model_on_towers( |
| model, |
| current_strategy, |
| make_callback_model=False, |
| inputs=inputs) |
| |
| (grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) = current_strategy.call_for_each_tower( |
| _per_device_predict_function, model._grouped_model) |
| |
| (all_inputs, all_outputs, all_updates, |
| all_session_args) = distributed_training_utils.unwrap_values( |
| current_strategy, grouped_inputs, grouped_outputs, grouped_updates, |
| grouped_session_args) |
| |
| combined_fn = K.Function( |
| all_inputs, all_outputs, |
| updates=all_updates, |
| name='distributed_predict_function', |
| **all_session_args) |
| |
| for label, output in zip(model.output_names, combined_fn.outputs): |
| ctx.set_last_step_output(label, output) |
| |
| return combined_fn.updates_op |
| |
| # Add initial dummy values for outputs. |
| initial_loop_values = {} |
| batch_dimension = distributed_training_utils.get_batch_dimension(iterator) |
| for name, tensor in zip(model.output_names, model.outputs): |
| # TODO(priyag): This is a workaround as we do not know the batch dimension |
| # of the model's output at this point. |
| shape = tensor_shape.TensorShape(tensor.shape.dims) |
| shape.dims = [batch_dimension] + shape.dims[1:] |
| initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype) |
| |
| with current_strategy.scope(): |
| # TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed. |
| ctx = current_strategy.run_steps_on_dataset( |
| step_fn, iterator, iterations=1, |
| initial_loop_values=initial_loop_values) |
| |
| predict_op = ctx.run_op |
| output_tensors = ctx.last_step_outputs |
| |
| if verbose == 1: |
| progbar = Progbar(target=steps) |
| |
| # Copy the weights from the original model to each of the replicated models. |
| orig_model_weights = model.get_weights() |
| with current_strategy.scope(): |
| distributed_model = current_strategy.unwrap(model._grouped_model)[0] |
| distributed_training_utils.set_weights( |
| current_strategy, distributed_model, orig_model_weights) |
| |
| assert steps is not None |
| # Since we do not know how many samples we will see, we cannot pre-allocate |
| # the returned Numpy arrays. Instead, we store one array per batch seen |
| # and concatenate them upon returning. |
| unconcatenated_outs = [[] for _ in model.outputs] |
| for step in range(steps): |
| _, batch_outs = K.get_session().run([predict_op, output_tensors]) |
| # TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy. |
| for i, label in enumerate(model.output_names): |
| unconcatenated_outs[i].extend(batch_outs[label]) |
    if verbose == 1:
      progbar.update(step + 1)
| |
| K.get_session().run(current_strategy.finalize()) |
| |
| if len(unconcatenated_outs) == 1: |
| return np.concatenate(unconcatenated_outs[0], axis=0) |
| return [ |
| np.concatenate(unconcatenated_outs[i], axis=0) |
| for i in range(len(unconcatenated_outs)) |
| ] |
| |
| |
| def _clone_and_build_model(model, inputs=None, targets=None): |
| """Clone and build the given keras_model.""" |
  # The import is done here, rather than at module level, to avoid a circular
  # dependency.
| from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top |
| cloned_model = models.clone_model(model, input_tensors=inputs) |
| |
| # Compile and build model. |
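  # A `TFOptimizer` wraps a raw `tf.train` optimizer and is reused directly;
  # native Keras optimizers are re-instantiated from their config so that
  # each cloned model gets its own optimizer instance.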
| if isinstance(model.optimizer, optimizers.TFOptimizer): |
| optimizer = model.optimizer |
| else: |
| optimizer_config = model.optimizer.get_config() |
| optimizer = model.optimizer.__class__.from_config(optimizer_config) |
| |
| # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a |
| # single tensor should be OK but it throws an error in that case. |
| if (targets is not None and not isinstance(targets, list) and |
| not isinstance(targets, dict)): |
| targets = [targets] |
| cloned_model.compile( |
| optimizer, |
| model.loss, |
| metrics=model.metrics, |
| loss_weights=model.loss_weights, |
| sample_weight_mode=model.sample_weight_mode, |
| weighted_metrics=model.weighted_metrics, |
| target_tensors=targets) |
| return cloned_model |
| |
| |
| def clone_model_on_towers( |
| model, strategy, make_callback_model=False, inputs=None, targets=None): |
| """Create a cloned model on each tower.""" |
| with strategy.scope(): |
| model._grouped_model = strategy.call_for_each_tower( |
| _clone_and_build_model, model, inputs, targets) |
| if make_callback_model: |
| model._make_callback_model() |
| |
| |
| def _aggregate_metrics_across_towers(num_devices, out_labels, outs): |
| """Aggregate metrics values across all towers. |
| |
  When using `MirroredStrategy`, the number of towers is equal to the
  number of devices over which training is distributed; other strategies
  may not preserve this one-to-one mapping.
| |
| Args: |
| num_devices: Number of devices over which the model is being distributed. |
| out_labels: The list of metric names passed to `compile`. |
| outs: The output from all the towers. |
| |
| Returns: |
| The average value of each metric across the towers. |
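
  Example:
    With `num_devices=2` and `out_labels=['loss', 'acc']`, `outs` is laid
    out as `[loss, acc_tower_0, acc_tower_1]` and the merged result is
    `[loss, mean(acc_tower_0, acc_tower_1)]`.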
| """ |
| # TODO(anjalisridhar): Temporary workaround for aggregating metrics |
| # across towers. Replace with the new metrics module eventually. |
| merged_output = [] |
| # The first output is the total loss. |
| merged_output.append(outs[0]) |
| current_index = 1 |
  # Each label in `out_labels[1:]` corresponds to one metric, with one value
  # reported per device. We currently take the mean of these per-device
  # values.
| for _ in out_labels[1:]: |
| m = np.mean(outs[current_index:current_index + num_devices]) |
| merged_output.append(m) |
| current_index += num_devices |
| return merged_output |
| |
| |
| def _get_input_from_iterator(iterator, model): |
| """Get elements from the iterator and verify the input shape and type.""" |
| next_element = iterator.get_next() |
| |
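  # `next_element` is an `(inputs, targets)` tuple for fit/eval style
  # datasets, and only the inputs (no targets) for predict.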
| if isinstance(next_element, tuple): |
| x, y = next_element |
| else: |
| x = next_element |
| y = None |
| # Validate that all the elements in x and y are of the same type and shape. |
| # We can then pass the first element of x and y to `_standardize_weights` |
| # below and be confident of the output. |
  x_values, y_values = (
      distributed_training_utils.validate_distributed_dataset_inputs(
          model._distribution_strategy, x, y))
| # TODO(sourabhbajaj): Add support for sample weights in distribution |
| # strategy. |
| model._standardize_weights(x_values, y_values) |
| return x, y |