blob: ed96cfd198b16840d8640c21380cd82cea04779b [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilites for `Model.compile`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import six
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.keras import losses as losses_mod
from tensorflow.python.keras import metrics as metrics_mod
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.util import nest
class LossesContainer(object):
  """A container class for losses passed to `Model.compile`.

  The user-supplied `losses` and `loss_weights` are converted into flat lists
  (one entry per Model output) the first time the container is called, since
  the structure of the outputs is needed to do so.
  """

  def __init__(self, losses, loss_weights=None, output_names=None):
    """Stores the loss configuration; conversion is deferred to `_build`.

    Arguments:
      losses: A loss (string, function, or `Loss` object) or a structure of
        losses.
      loss_weights: Optional weight, or structure of weights, multiplied into
        the per-output loss values.
      output_names: Optional list of output names. When `None`, pseudo names
        are created from the outputs on first use.
    """
    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_losses = losses
    self._user_loss_weights = loss_weights

    self._losses = losses
    self._loss_weights = loss_weights
    self._output_names = output_names
    self._per_output_metrics = None  # Per-output losses become metrics.
    self._loss_metric = metrics_mod.Mean(name='loss')  # Total loss.
    self._built = False

  @property
  def metrics(self):
    """Total loss metric followed by the per-output loss metrics.

    Returns an empty list until the container has been built.
    """
    if not self._built:
      return []
    per_output_metrics = [
        metric_obj for metric_obj in nest.flatten(self._per_output_metrics)
        if metric_obj is not None
    ]
    return [self._loss_metric] + per_output_metrics

  def _build(self, y_pred):
    """One-time setup of loss objects.

    Arguments:
      y_pred: Sample (unflattened) outputs of the Model; only its structure
        is used here.
    """
    if self._output_names is None:
      # In Subclass API, output names like 'output_1' are used for
      # `Metric` names.
      self._output_names = create_pseudo_output_names(y_pred)

    # Accept a dict of losses keyed by output_name when outputs are a flat
    # list.
    self._losses = map_to_output_names(y_pred, self._output_names, self._losses)
    self._loss_weights = map_to_output_names(y_pred, self._output_names,
                                             self._loss_weights)

    # Broadcast single config values to apply to each output.
    if not nest.is_sequence(self._losses):
      single_loss = self._losses
      self._losses = nest.map_structure(lambda output: single_loss, y_pred)
    if not nest.is_sequence(self._loss_weights):
      single_weight = self._loss_weights
      self._loss_weights = nest.map_structure(lambda output: single_weight,
                                              y_pred)

    self._losses = nest.map_structure(self._get_loss_object, self._losses)

    # Now that structures have been checked, it is safe to flatten.
    self._losses = nest.flatten(self._losses)
    self._loss_weights = nest.flatten(self._loss_weights)

    # Create per-output loss metrics, but only for multi-output Models.
    if len(self._output_names) == 1:
      self._per_output_metrics = [None]
    else:
      self._per_output_metrics = []
      for loss_obj, output_name in zip(self._losses, self._output_names):
        if loss_obj is None:
          self._per_output_metrics.append(None)
        else:
          self._per_output_metrics.append(
              metrics_mod.Mean(output_name + '_loss'))

    self._built = True

  def __call__(self,
               y_true,
               y_pred,
               sample_weight=None,
               regularization_losses=None):
    """Computes the overall loss.

    As a side effect, the total-loss metric and any per-output loss metrics
    are updated with the values computed for this call.

    Arguments:
      y_true: An arbitrary structure of Tensors representing the ground truth.
      y_pred: An arbitrary structure of Tensors representing a Model's outputs.
      sample_weight: An arbitrary structure of Tensors representing the
        per-sample loss weights. If one Tensor is passed, it is used for all
        losses. If multiple Tensors are passed, the structure should match
        `y_pred`.
      regularization_losses: Additional losses to be added to the total loss.

    Returns:
      The total loss `Tensor`: the sum of the (weighted) per-output losses
      and the regularization losses. A zero scalar is returned when there is
      nothing to sum.
    """
    y_true = map_to_output_names(y_pred, self._output_names, y_true)
    sample_weight = map_to_output_names(y_pred, self._output_names,
                                        sample_weight)

    if not self._built:
      self._build(y_pred)

    y_true = nest.flatten(y_true) if y_true is not None else []
    y_pred = nest.flatten(y_pred)

    # TODO(omalleyt): Remove ambiguity here.
    # This is currently needed to support passing only 1 loss and 1 target
    # to a Functional Model with multiple outputs. However, this is
    # ambiguous, especially with subclass, and we should reconsider how we
    # support this.
    if len(y_true) == 1 and len(y_pred) > 1:
      y_true = y_true * len(y_pred)

    sample_weight = nest.flatten(sample_weight)
    # Allows passing one sample-weight array for all outputs.
    if len(sample_weight) == 1 and len(y_pred) > 1:
      sample_weight = sample_weight * len(y_pred)

    loss_values = []  # Used for gradient calculation.
    loss_metric_values = []  # Used for loss metric calculation.
    zip_args = (y_true, y_pred, sample_weight, self._losses, self._loss_weights,
                self._per_output_metrics)
    for y_t, y_p, sw, loss_obj, loss_weight, metric_obj in zip(*zip_args):
      if loss_obj is None:  # Ok to have no loss for an output.
        continue

      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      sw = apply_mask(y_p, sw)

      loss_value = loss_obj(y_t, y_p, sample_weight=sw)

      loss_metric_value = loss_value
      # Correct for the `Mean` loss metrics counting each replica as a batch.
      if loss_obj.reduction == losses_utils.ReductionV2.SUM:
        loss_metric_value *= ds_context.get_strategy().num_replicas_in_sync
      # The per-output metric is updated before `loss_weight` is applied.
      if metric_obj is not None:
        metric_obj.update_state(loss_metric_value)

      if loss_weight is not None:
        loss_value *= loss_weight
        loss_metric_value *= loss_weight

      if (loss_obj.reduction == losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE or
          loss_obj.reduction == losses_utils.ReductionV2.AUTO):
        loss_value = losses_utils.scale_loss_for_distribution(loss_value)

      loss_values.append(loss_value)
      loss_metric_values.append(loss_metric_value)

    if regularization_losses:
      regularization_losses = losses_utils.cast_losses_to_common_dtype(
          regularization_losses)
      reg_loss = math_ops.add_n(regularization_losses)
      loss_metric_values.append(reg_loss)
      loss_values.append(losses_utils.scale_loss_for_distribution(reg_loss))

    if loss_values:
      loss_metric_values = losses_utils.cast_losses_to_common_dtype(
          loss_metric_values)
      total_loss_metric_value = math_ops.add_n(loss_metric_values)
      self._loss_metric.update_state(total_loss_metric_value)

      loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
      total_loss = math_ops.add_n(loss_values)
      return total_loss
    else:
      # Ok for a model to have no compiled loss.
      return array_ops.zeros(shape=())

  def _get_loss_object(self, loss):
    """Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Arguments:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object, or `None` when `loss` is `None`.
    """
    if loss is None:
      return None  # Ok to have no loss for an output.

    loss = losses_mod.get(loss)
    if not isinstance(loss, losses_mod.Loss):
      # Callables such as `functools.partial` objects or callable instances
      # have no `__name__`; fall back to the class name instead of raising
      # `AttributeError`.
      loss_name = getattr(loss, '__name__', None) or type(loss).__name__
      loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
    loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
    return loss
class MetricsContainer(object):
  """A container class for metrics passed to `Model.compile`.

  The user-supplied `metrics` and `weighted_metrics` are converted into
  per-output lists of `Metric` objects the first time `update_state` is
  called with targets, since samples of the outputs and labels are needed.
  """

  def __init__(self, metrics=None, weighted_metrics=None, output_names=None):
    """Stores the metric configuration; conversion is deferred to `_build`.

    Arguments:
      metrics: A metric (string, function, or `Metric` object) or a structure
        of metrics, updated without sample weights.
      weighted_metrics: Same as `metrics`, but updated with sample weights.
      output_names: Optional list of output names. When `None`, pseudo names
        are created from the outputs on first use.
    """
    # Keep user-supplied values untouched for recompiling and serialization.
    self._user_metrics = metrics
    self._user_weighted_metrics = weighted_metrics

    self._metrics = metrics
    self._weighted_metrics = weighted_metrics
    self._output_names = output_names
    # Flat, ordered list of all created `Metric` objects; filled by `_build`.
    self._metrics_in_order = []
    self._built = False

  @property
  def metrics(self):
    """Metrics created by this container (empty until built)."""
    if not self._built:
      return []
    return self._metrics_in_order

  def _build(self, y_pred, y_true):
    """One-time setup of metric objects.

    Arguments:
      y_pred: Unflattened sample outputs of the Model.
      y_true: Unflattened sample labels.
    """
    if self._output_names is None:
      # Subclass output names like 'output_1' are used for `Metric` names.
      self._output_names = create_pseudo_output_names(y_pred)

    # If a single metric or flat list of metrics, apply to all outputs.
    self._metrics = self._maybe_broadcast(self._metrics, y_pred)
    self._weighted_metrics = self._maybe_broadcast(self._weighted_metrics,
                                                   y_pred)

    # Accept a dict of metrics keyed by output_name when outputs are a flat
    # list.
    self._metrics = map_to_output_names(y_pred, self._output_names,
                                        self._metrics)
    self._weighted_metrics = map_to_output_names(y_pred, self._output_names,
                                                 self._weighted_metrics)

    # Standardize on tuple since `tf.data` turns lists into `Tensor`s.
    # pylint: disable=protected-access
    y_pred = nest._list_to_tuple(y_pred)
    y_true = nest._list_to_tuple(y_true)
    self._metrics = nest._list_to_tuple(self._metrics)
    self._weighted_metrics = nest._list_to_tuple(self._weighted_metrics)
    # pylint: enable=protected-access

    # Convert to `Metric` objects, potentially disambiguating based on output
    # properties.
    self._metrics = nest.map_structure_up_to(y_pred, self._get_metric_objects,
                                             self._metrics, y_true, y_pred)
    self._weighted_metrics = nest.map_structure_up_to(y_pred,
                                                      self._get_metric_objects,
                                                      self._weighted_metrics,
                                                      y_true, y_pred)

    self._metrics = nest.flatten_up_to(y_pred, self._metrics, check_types=False)
    self._weighted_metrics = nest.flatten_up_to(
        y_pred, self._weighted_metrics, check_types=False)

    # Assumes metrics, weighted_metrics have been flattened up to outputs.
    self._set_metric_names()

    # Cache the flat order needed when returning metrics, for backwards compat.
    self._metrics_in_order = []
    for output_metrics, output_weighted_metrics in zip(self._metrics,
                                                       self._weighted_metrics):
      for m in nest.flatten(output_metrics):
        if m is not None:
          self._metrics_in_order.append(m)
      for wm in nest.flatten(output_weighted_metrics):
        if wm is not None:
          self._metrics_in_order.append(wm)

    self._built = True

  def _set_metric_names(self):
    """Sets unique metric names.

    For multi-output models, the output name is prepended to the metric name.
    For weighted metrics, "weighted_" is prepended if the name would
    otherwise be non-unique.

    Raises:
      ValueError: If two metrics still end up with the same name.
    """
    # pylint: disable=protected-access
    metric_names = set()

    def _claim(name):
      # Registers `name`, refusing duplicates.
      if name in metric_names:
        raise ValueError(
            'Found two metrics with the same name: {}'.format(name))
      metric_names.add(name)

    is_multi_output = len(self._output_names) > 1
    zip_args = (self._output_names, self._metrics, self._weighted_metrics)
    for output_name, output_metrics, weighted_output_metrics in zip(*zip_args):
      for m in output_metrics:
        if m is None:
          continue
        if is_multi_output:
          m._name = output_name + '_' + m._name
        _claim(m._name)

      for wm in weighted_output_metrics:
        if wm is None:
          continue
        if is_multi_output:
          if output_name + '_' + wm._name in metric_names:
            wm._name = output_name + '_weighted_' + wm._name
          else:
            wm._name = output_name + '_' + wm._name
        elif wm._name in metric_names:
          wm._name = 'weighted_' + wm._name
        _claim(wm._name)
    # pylint: enable=protected-access

  def update_state(self, y_true, y_pred, sample_weight=None):
    """Updates the state of per-output metrics.

    Arguments:
      y_true: Structure of label Tensors. Nothing happens when no targets are
        passed.
      y_pred: Structure of output Tensors.
      sample_weight: Optional structure of sample-weight Tensors; only
        forwarded to the weighted metrics.
    """
    y_true = map_to_output_names(y_pred, self._output_names, y_true)
    sample_weight = map_to_output_names(y_pred, self._output_names,
                                        sample_weight)

    flat_y_true = nest.flatten(y_true) if y_true is not None else []
    flat_y_pred = nest.flatten(y_pred)

    if not flat_y_true:
      return  # Handle case where no targets are passed.

    # TODO(omalleyt): Remove ambiguity here (see LossesContainer).
    if len(flat_y_true) == 1 and len(flat_y_pred) > 1:
      y_true = nest.map_structure(lambda _: flat_y_true[0], y_pred)
      flat_y_true = nest.flatten(y_true)

    if not self._built:
      # `_build` needs unflattened outputs and labels.
      self._build(y_pred, y_true)

    y_true = flat_y_true
    y_pred = flat_y_pred

    sample_weight = nest.flatten(sample_weight)
    # Allows passing one sample-weight array for all outputs.
    if len(sample_weight) == 1 and len(y_pred) > 1:
      sample_weight = sample_weight * len(y_pred)

    zip_args = (y_true, y_pred, sample_weight, self._metrics,
                self._weighted_metrics)
    for y_t, y_p, sw, metric_objs, weighted_metric_objs in zip(*zip_args):
      y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
      sw = apply_mask(y_p, sw)

      for metric_obj in metric_objs:
        if metric_obj is None:
          continue
        metric_obj.update_state(y_t, y_p)

      for weighted_metric_obj in weighted_metric_objs:
        if weighted_metric_obj is None:
          continue
        weighted_metric_obj.update_state(y_t, y_p, sample_weight=sw)

  def _get_metric_objects(self, metrics, y_t, y_p):
    """Convert user-supplied metrics to `Metric` objects."""
    metrics = nest.flatten(metrics)
    return [self._get_metric_object(m, y_t, y_p) for m in metrics]

  def _get_metric_object(self, metric, y_t, y_p):
    """Converts user-supplied metric to a `Metric` object.

    Arguments:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object, or `None` when `metric` is `None`.
    """
    if metric is None:
      return None  # Ok to have no metric for an output.

    # Convenience feature for selecting b/t binary, categorical,
    # and sparse categorical.
    if metric not in ['accuracy', 'acc', 'crossentropy', 'ce']:
      metric_obj = metrics_mod.get(metric)
    else:
      y_t_rank = len(y_t.shape.as_list())
      y_p_rank = len(y_p.shape.as_list())
      y_t_last_dim = y_t.shape.as_list()[-1]
      y_p_last_dim = y_p.shape.as_list()[-1]

      is_binary = y_p_last_dim == 1
      # NOTE(review): presumably `as_list()` yields `None` for statically
      # unknown dimensions, in which case `y_p_last_dim > 1` would raise on
      # Python 3 -- confirm whether that case needs handling.
      is_sparse_categorical = (
          y_t_rank < y_p_rank or (y_t_last_dim == 1 and y_p_last_dim > 1))

      if metric in ['accuracy', 'acc']:
        if is_binary:
          metric_obj = metrics_mod.binary_accuracy
        elif is_sparse_categorical:
          metric_obj = metrics_mod.sparse_categorical_accuracy
        else:
          metric_obj = metrics_mod.categorical_accuracy
      else:
        if is_binary:
          metric_obj = metrics_mod.binary_crossentropy
        elif is_sparse_categorical:
          metric_obj = metrics_mod.sparse_categorical_crossentropy
        else:
          metric_obj = metrics_mod.categorical_crossentropy

    if not isinstance(metric_obj, metrics_mod.Metric):
      if isinstance(metric, six.string_types):
        metric_name = metric
      elif hasattr(metric, 'name'):
        metric_name = metric.name  # TODO(omalleyt): Is this needed?
      else:
        # A callable was passed. `functools.partial` objects and callable
        # instances have no `__name__`; fall back to the class name instead
        # of raising `AttributeError`.
        metric_name = (
            getattr(metric, '__name__', None) or type(metric).__name__)
      metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)

    return metric_obj

  def _maybe_broadcast(self, metrics, y_pred):
    """If a flat list of Metrics is supplied, apply them to all outputs."""

    def _should_broadcast(metrics):
      # e.g. 'mse'.
      if not nest.is_sequence(metrics):
        return True
      # e.g. ['mse'] or ['mse', 'mae'].
      return (isinstance(metrics, (list, tuple)) and
              not any(nest.is_sequence(m) for m in metrics))

    if _should_broadcast(metrics):
      # With several outputs each one gets its own copy of a `Metric`
      # instance, rebuilt from its config.
      copy_metrics = len(nest.flatten(y_pred)) > 1

      def _maybe_copy(m):
        if copy_metrics and isinstance(m, metrics_mod.Metric):
          return m.__class__.from_config(m.get_config())
        return m

      metrics = nest.flatten(metrics)
      return nest.map_structure(lambda _: [_maybe_copy(m) for m in metrics],
                                y_pred)

    return metrics
def create_pseudo_output_names(outputs):
  """Create pseudo output names for a subclassed Model.

  Arguments:
    outputs: `Model`'s outputs.

  Returns:
    Flattened list of pseudo names, using the 'output_' prefix.
  """
  return _create_pseudo_names(outputs, prefix='output_')
def create_pseudo_input_names(inputs):
  """Create pseudo input names for a subclassed Model.

  Arguments:
    inputs: `Model`'s inputs.

  Returns:
    Flattened list of pseudo names, using the 'input_' prefix.
  """
  return _create_pseudo_names(inputs, prefix='input_')
def _create_pseudo_names(tensors, prefix):
  """Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Arguments:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def one_index(ele):
    # Start with "output_1" instead of "output_0".
    if isinstance(ele, int):
      return ele + 1
    return ele

  flat_paths = list(nest.yield_flat_paths(tensors))
  flat_paths = nest.map_structure(one_index, flat_paths)
  names = []
  for path in flat_paths:
    if not path:
      name = prefix + '1'  # Single output.
    else:
      name = '_'.join(str(p) for p in path)
      # Only paths that start with a positional index get the prefix; paths
      # that start with a dict key are used as-is.
      if isinstance(path[0], int):
        name = prefix + name
    names.append(name)
  return names
def map_to_output_names(y_pred, output_names, struct):
  """Maps a dict to a list using `output_names` as keys.

  This is a convenience feature only. When a `Model`'s outputs
  are a list, you can specify per-output losses and metrics as
  a dict, where the keys are the output names. If you specify
  per-output losses and metrics via the same structure as the
  `Model`'s outputs (recommended), no mapping is performed.

  For the Functional API, the output names are the names of the
  last layer of each output. For the Subclass API, the output names
  are determined by `create_pseudo_output_names` (For example:
  `['output_1', 'output_2']` for a list of outputs).

  This mapping preserves backwards compatibility for `compile` and
  `fit`.

  Arguments:
    y_pred: Sample outputs of the Model, to determine if this convenience
      feature should be applied (`struct` is returned unmodified if `y_pred`
      isn't a flat list).
    output_names: List. The names of the outputs of the Model.
    struct: The structure to map.

  Returns:
    `struct` mapped to a list in same order as `output_names` (or the single
    mapped value when there is exactly one output name). `struct` is returned
    unchanged when it is not a dict or when `y_pred` is neither a single
    output nor a flat list.

  Raises:
    ValueError: If `struct` contains keys that are not in `output_names`.
  """
  # Only dicts are ever remapped, so skip inspecting `y_pred` otherwise.
  if not isinstance(struct, dict):
    return struct

  outputs_are_flat_list = (
      isinstance(y_pred, (list, tuple)) and
      not any(nest.is_sequence(y_p) for y_p in y_pred))
  single_output = not nest.is_sequence(y_pred)
  if not (single_output or outputs_are_flat_list):
    return struct

  output_names = output_names or create_pseudo_output_names(y_pred)
  struct = copy.copy(struct)  # Shallow copy so the caller's dict is untouched.
  new_struct = [struct.pop(name, None) for name in output_names]
  if struct:
    # Whatever is left was not claimed by any output name. Format the keys as
    # a list so the message reads the same on Python 2 and 3.
    raise ValueError('Found unexpected keys that do not correspond '
                     'to any Model output: {}. Expected: {}'.format(
                         list(struct), output_names))
  if len(new_struct) == 1:
    return new_struct[0]
  return new_struct
def match_dtype_and_rank(y_t, y_p, sw):
  """Match dtype and rank of predictions.

  Arguments:
    y_t: Label Tensor.
    y_p: Prediction Tensor; its dtype and rank are the reference.
    sw: Sample-weight Tensor, or `None`.

  Returns:
    Tuple `(y_t, y_p, sw)` where `y_t` and `sw` (if given) are cast to the
    dtype of `y_p`, and gain a trailing axis when they are rank 1 while `y_p`
    is rank 2. `y_p` is returned unchanged.
  """
  # Rank.
  if y_t.shape.rank == 1 and y_p.shape.rank == 2:
    y_t = array_ops.expand_dims_v2(y_t, axis=-1)
  if sw is not None:
    if sw.shape.rank == 1 and y_p.shape.rank == 2:
      sw = array_ops.expand_dims_v2(sw, axis=-1)

  # Dtype.
  y_t = math_ops.cast(y_t, y_p.dtype)
  if sw is not None:
    sw = math_ops.cast(sw, y_p.dtype)
  return y_t, y_p, sw
def apply_mask(y_p, sw):
  """Applies any mask on predictions to sample weights.

  Arguments:
    y_p: Prediction Tensor, possibly carrying a `_keras_mask` attribute.
    sw: Sample-weight Tensor, or `None`.

  Returns:
    `sw` unchanged when `y_p` has no mask; the mask (cast to the dtype of
    `y_p`) when there are no sample weights; otherwise the sample weights
    multiplied by the mask.
  """
  # Handle Keras mask on outputs.
  mask = getattr(y_p, '_keras_mask', None)
  if mask is None:
    return sw

  mask = math_ops.cast(mask, y_p.dtype)
  if sw is None:
    return mask

  mask, _, sw = (
      tf_losses_utils.squeeze_or_expand_dimensions(mask, sample_weight=sw))
  sw *= mask
  return sw