| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| """Classes and methods related to model_fn.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| |
| import six |
| |
| from tensorflow.python.estimator.export import export_output as export_output_lib |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.keras.metrics import Metric |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.saved_model import signature_constants |
| from tensorflow.python.saved_model import tag_constants |
| from tensorflow.python.training import monitored_session |
| from tensorflow.python.training import session_run_hook |
| from tensorflow.python.util import nest |
| from tensorflow.python.util.tf_export import estimator_export |
| |
| |
| @estimator_export('estimator.ModeKeys') |
| class ModeKeys(object): |
| """Standard names for model modes. |
| |
| The following standard keys are defined: |
| |
| * `TRAIN`: training mode. |
| * `EVAL`: evaluation mode. |
| * `PREDICT`: inference mode. |
| """ |
| |
| TRAIN = 'train' |
| EVAL = 'eval' |
| PREDICT = 'infer' |
| |
| |
| LOSS_METRIC_KEY = 'loss' |
| AVERAGE_LOSS_METRIC_KEY = 'average_loss' |
| |
| # Mapping of the modes to appropriate tag_constants that are used for saving. |
| EXPORT_TAG_MAP = { |
| ModeKeys.PREDICT: [tag_constants.SERVING], |
| ModeKeys.TRAIN: [tag_constants.TRAINING], |
| ModeKeys.EVAL: [tag_constants.EVAL], |
| } |
| |
| |
| @estimator_export('estimator.EstimatorSpec') |
| class EstimatorSpec( |
| collections.namedtuple('EstimatorSpec', [ |
| 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', |
| 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold', |
| 'evaluation_hooks', 'prediction_hooks' |
| ])): |
| """Ops and objects returned from a `model_fn` and passed to an `Estimator`. |
| |
| `EstimatorSpec` fully defines the model to be run by an `Estimator`. |
| """ |
| |
| def __new__(cls, |
| mode, |
| predictions=None, |
| loss=None, |
| train_op=None, |
| eval_metric_ops=None, |
| export_outputs=None, |
| training_chief_hooks=None, |
| training_hooks=None, |
| scaffold=None, |
| evaluation_hooks=None, |
| prediction_hooks=None): |
| """Creates a validated `EstimatorSpec` instance. |
| |
| Depending on the value of `mode`, different arguments are required. Namely |
| |
| * For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`. |
| * For `mode == ModeKeys.EVAL`: required field is `loss`. |
| * For `mode == ModeKeys.PREDICT`: required fields are `predictions`. |
| |
| model_fn can populate all arguments independent of mode. In this case, some |
| arguments will be ignored by an `Estimator`. E.g. `train_op` will be |
| ignored in eval and infer modes. Example: |
| |
| ```python |
| def my_model_fn(features, labels, mode): |
| predictions = ... |
| loss = ... |
| train_op = ... |
| return tf.estimator.EstimatorSpec( |
| mode=mode, |
| predictions=predictions, |
| loss=loss, |
| train_op=train_op) |
| ``` |
| |
| Alternatively, model_fn can just populate the arguments appropriate to the |
| given mode. Example: |
| |
| ```python |
| def my_model_fn(features, labels, mode): |
| if (mode == tf.estimator.ModeKeys.TRAIN or |
| mode == tf.estimator.ModeKeys.EVAL): |
| loss = ... |
| else: |
| loss = None |
| if mode == tf.estimator.ModeKeys.TRAIN: |
| train_op = ... |
| else: |
| train_op = None |
| if mode == tf.estimator.ModeKeys.PREDICT: |
| predictions = ... |
| else: |
| predictions = None |
| |
| return tf.estimator.EstimatorSpec( |
| mode=mode, |
| predictions=predictions, |
| loss=loss, |
| train_op=train_op) |
| ``` |
| |
| Args: |
| mode: A `ModeKeys`. Specifies if this is training, evaluation or |
| prediction. |
| predictions: Predictions `Tensor` or dict of `Tensor`. |
| loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`. |
| train_op: Op for the training step. |
| eval_metric_ops: Dict of metric results keyed by name. |
| The values of the dict can be one of the following: |
| (1) instance of `Metric` class. |
| (2) Results of calling a metric function, namely a |
| `(metric_tensor, update_op)` tuple. `metric_tensor` should be |
| evaluated without any impact on state (typically is a pure computation |
| results based on variables.). For example, it should not trigger the |
| `update_op` or requires any input fetching. |
| export_outputs: Describes the output signatures to be exported to |
| `SavedModel` and used during serving. |
| A dict `{name: output}` where: |
| * name: An arbitrary name for this output. |
| * output: an `ExportOutput` object such as `ClassificationOutput`, |
| `RegressionOutput`, or `PredictOutput`. |
| Single-headed models only need to specify one entry in this dictionary. |
| Multi-headed models should specify one entry for each head, one of |
| which must be named using |
| signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. |
| If no entry is provided, a default `PredictOutput` mapping to |
| `predictions` will be created. |
| training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to |
| run on the chief worker during training. |
| training_hooks: Iterable of `tf.train.SessionRunHook` objects to run |
| on all workers during training. |
| scaffold: A `tf.train.Scaffold` object that can be used to set |
| initialization, saver, and more to be used in training. |
| evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to |
| run during evaluation. |
| prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to |
| run during predictions. |
| |
| Returns: |
| A validated `EstimatorSpec` object. |
| |
| Raises: |
| ValueError: If validation fails. |
| TypeError: If any of the arguments is not the expected type. |
| """ |
| # Validate train_op. |
| if train_op is None: |
| if mode == ModeKeys.TRAIN: |
| raise ValueError('Missing train_op.') |
| else: |
| _check_is_tensor_or_operation(train_op, 'train_op') |
| |
| # Validate loss. |
| if loss is None: |
| if mode in (ModeKeys.TRAIN, ModeKeys.EVAL): |
| raise ValueError('Missing loss.') |
| else: |
| loss = _check_is_tensor(loss, 'loss') |
| loss_shape = loss.get_shape() |
| if loss_shape.num_elements() not in (None, 1): |
| raise ValueError('Loss must be scalar, given: {}'.format(loss)) |
| if not loss_shape.is_compatible_with(tensor_shape.scalar()): |
| loss = array_ops.reshape(loss, []) |
| |
| # Validate predictions. |
| if predictions is None: |
| if mode == ModeKeys.PREDICT: |
| raise ValueError('Missing predictions.') |
| predictions = {} |
| else: |
| if isinstance(predictions, dict): |
| predictions = { |
| k: _check_is_tensor(v, 'predictions[{}]'.format(k)) |
| for k, v in six.iteritems(predictions) |
| } |
| else: |
| predictions = _check_is_tensor(predictions, 'predictions') |
| |
| # Validate eval_metric_ops. |
| if eval_metric_ops is None: |
| eval_metric_ops = {} |
| else: |
| if not isinstance(eval_metric_ops, dict): |
| raise TypeError( |
| 'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops)) |
| for key, value in six.iteritems(eval_metric_ops): |
| # TODO(psv): When we deprecate the old metrics, throw an error here if |
| # the value is not an instance of `Metric` class. |
| if isinstance(value, Metric): |
| if not value.updates: # Check if metrics updates are available. |
| raise ValueError( |
| 'Please call update_state(...) on the "{metric_name}" metric' |
| .format(metric_name=value.name)) |
| else: |
| if not isinstance(value, tuple) or len(value) != 2: |
| raise TypeError( |
| 'Values of eval_metric_ops must be (metric_value, update_op) ' |
| 'tuples, given: {} for key: {}'.format(value, key)) |
| metric_value, metric_update = value |
| for metric_value_member in nest.flatten(metric_value): |
| # Allow (possibly nested) tuples for metric values, but require that |
| # each of them be Tensors or Operations. |
| _check_is_tensor_or_operation(metric_value_member, |
| 'eval_metric_ops[{}]'.format(key)) |
| _check_is_tensor_or_operation(metric_update, |
| 'eval_metric_ops[{}]'.format(key)) |
| |
| # Validate the passed export outputs, or generate defaults. |
| if mode == ModeKeys.PREDICT: |
| export_outputs = _get_export_outputs(export_outputs, predictions) |
| |
| # Validate that all tensors and ops are from the default graph. |
| default_graph = ops.get_default_graph() |
| |
| # We enumerate possible error causes here to aid in debugging. |
| error_message_template = ( |
| '{0} with "{1}" must be from the default graph. ' |
| 'Possible causes of this error include: \n\n' |
| '1) {0} was created outside the context of the default graph.' |
| '\n\n' |
| '2) The object passed through to EstimatorSpec was not created ' |
| 'in the most recent call to "model_fn".') |
| |
| if isinstance(predictions, dict): |
| for key, value in six.iteritems(predictions): |
| if value.graph is not default_graph: |
| raise ValueError(error_message_template.format( |
| 'prediction values', |
| '{0}: {1}'.format(key, value.name))) |
| elif predictions is not None: |
| # 'predictions' must be a single Tensor. |
| if predictions.graph is not default_graph: |
| raise ValueError(error_message_template.format( |
| 'prediction values', predictions.name)) |
| |
| if loss is not None and loss.graph is not default_graph: |
| raise ValueError(error_message_template.format('loss', loss.name)) |
| if train_op is not None and train_op.graph is not default_graph: |
| raise ValueError(error_message_template.format('train_op', train_op.name)) |
| for key, value in list(six.iteritems(eval_metric_ops)): |
| if isinstance(value, Metric): |
| values_to_check = value.updates[:] |
| values_to_check.append(value.result()) |
| else: |
| values_to_check = nest.flatten(value) |
| for val in values_to_check: |
| if val.graph is not default_graph: |
| raise ValueError(error_message_template.format( |
| 'eval_metric_ops', |
| '{0}: {1}'.format(key, val.name))) |
| |
| # Validate hooks. |
| training_chief_hooks = tuple(training_chief_hooks or []) |
| training_hooks = tuple(training_hooks or []) |
| evaluation_hooks = tuple(evaluation_hooks or []) |
| prediction_hooks = tuple(prediction_hooks or []) |
| |
| for hook in (training_hooks + training_chief_hooks + evaluation_hooks + |
| prediction_hooks): |
| if not isinstance(hook, session_run_hook.SessionRunHook): |
| raise TypeError( |
| 'All hooks must be SessionRunHook instances, given: {}'.format( |
| hook)) |
| |
| # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables |
| # are by default not added to any collections. We are doing this here, so |
| # that metric variables get initialized. |
| local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) |
| vars_to_add = set() |
| for key, value in six.iteritems(eval_metric_ops): |
| if isinstance(value, Metric): |
| vars_to_add.update(value.variables) |
| # Remove variables that are in the local variables collection already. |
| vars_to_add = vars_to_add.difference(local_vars) |
| for v in vars_to_add: |
| ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v) |
| |
| scaffold = scaffold or monitored_session.Scaffold() |
| # Validate scaffold. |
| if not isinstance(scaffold, monitored_session.Scaffold): |
| raise TypeError( |
| 'scaffold must be tf.train.Scaffold. Given: {}'.format(scaffold)) |
| |
| return super(EstimatorSpec, cls).__new__( |
| cls, |
| mode=mode, |
| predictions=predictions, |
| loss=loss, |
| train_op=train_op, |
| eval_metric_ops=eval_metric_ops, |
| export_outputs=export_outputs, |
| training_chief_hooks=training_chief_hooks, |
| training_hooks=training_hooks, |
| scaffold=scaffold, |
| evaluation_hooks=evaluation_hooks, |
| prediction_hooks=prediction_hooks) |
| |
| def _replace(self, **kwds): |
| """Return a new EstimatorSpec replacing specified fields with new values.""" |
| if 'mode' in kwds: |
| if self.mode != kwds['mode']: |
| raise ValueError('mode of EstimatorSpec cannot be changed.') |
| new_fields = map(kwds.pop, self._fields, list(self)) |
| return EstimatorSpec(*new_fields) |
| |
| |
| def _get_export_outputs(export_outputs, predictions): |
| """Validate export_outputs or create default export_outputs. |
| |
| Args: |
| export_outputs: Describes the output signatures to be exported to |
| `SavedModel` and used during serving. Should be a dict or None. |
| predictions: Predictions `Tensor` or dict of `Tensor`. |
| |
| Returns: |
| Valid export_outputs dict |
| |
| Raises: |
| TypeError: if export_outputs is not a dict or its values are not |
| ExportOutput instances. |
| """ |
| if export_outputs is None: |
| default_output = export_output_lib.PredictOutput(predictions) |
| export_outputs = { |
| signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output} |
| |
| if not isinstance(export_outputs, dict): |
| raise TypeError('export_outputs must be dict, given: {}'.format( |
| export_outputs)) |
| for v in six.itervalues(export_outputs): |
| if not isinstance(v, export_output_lib.ExportOutput): |
| raise TypeError( |
| 'Values in export_outputs must be ExportOutput objects. ' |
| 'Given: {}'.format(export_outputs)) |
| |
| _maybe_add_default_serving_output(export_outputs) |
| |
| return export_outputs |
| |
| |
| def _maybe_add_default_serving_output(export_outputs): |
| """Add a default serving output to the export_outputs if not present. |
| |
| Args: |
| export_outputs: Describes the output signatures to be exported to |
| `SavedModel` and used during serving. Should be a dict. |
| |
| Returns: |
| export_outputs dict with default serving signature added if necessary |
| |
| Raises: |
| ValueError: if multiple export_outputs were provided without a default |
| serving key. |
| """ |
| if len(export_outputs) == 1: |
| (key, value), = export_outputs.items() |
| if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: |
| export_outputs[ |
| signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value |
| if len(export_outputs) > 1: |
| if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY |
| not in export_outputs): |
| raise ValueError( |
| 'Multiple export_outputs were provided, but none of them is ' |
| 'specified as the default. Do this by naming one of them with ' |
| 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.') |
| |
| return export_outputs |
| |
| |
| class _TPUEstimatorSpec( |
| collections.namedtuple('TPUEstimatorSpec', [ |
| 'mode', 'predictions', 'loss', 'train_op', 'eval_metrics', |
| 'export_outputs', 'scaffold_fn', 'host_call', 'training_hooks', |
| 'evaluation_hooks', 'prediction_hooks' |
| ])): |
| """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. |
| |
| This is a simplified implementation of `tf.contrib.tpu.EstimatorSpec`. See |
| tensorflow/contrib/tpu/python/tpu/tpu_estimator.py for more detailed |
| documentation. |
| """ |
| |
| def __new__(cls, |
| mode, |
| predictions=None, |
| loss=None, |
| train_op=None, |
| eval_metrics=None, |
| export_outputs=None, |
| scaffold_fn=None, |
| host_call=None, |
| training_hooks=None, |
| evaluation_hooks=None, |
| prediction_hooks=None): |
| """Creates a `_TPUEstimatorSpec` instance.""" |
| return super(_TPUEstimatorSpec, cls).__new__( |
| cls, |
| mode=mode, |
| predictions=predictions, |
| loss=loss, |
| train_op=train_op, |
| eval_metrics=eval_metrics, |
| export_outputs=export_outputs, |
| scaffold_fn=scaffold_fn, |
| host_call=host_call, |
| training_hooks=training_hooks, |
| evaluation_hooks=evaluation_hooks, |
| prediction_hooks=prediction_hooks) |
| |
| def as_estimator_spec(self): |
| """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" |
| if not self.eval_metrics: |
| eval_metric_ops = None |
| else: |
| metric_fn, tensors = self.eval_metrics |
| eval_metric_ops = metric_fn(**tensors) |
| return EstimatorSpec( |
| mode=self.mode, |
| predictions=self.predictions, |
| loss=self.loss, |
| train_op=self.train_op, |
| eval_metric_ops=eval_metric_ops, |
| export_outputs=self.export_outputs, |
| training_hooks=self.training_hooks, |
| evaluation_hooks=self.evaluation_hooks, |
| prediction_hooks=self.prediction_hooks) |
| |
| |
| def _check_is_tensor_or_operation(x, name): |
| if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)): |
| raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x)) |
| |
| |
| def _check_is_tensor(x, tensor_name): |
| """Returns `x` if it is a `Tensor`, raises TypeError otherwise.""" |
| if not isinstance(x, ops.Tensor): |
| raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x)) |
| return x |
| |
| |
| def export_outputs_for_mode( |
| mode, serving_export_outputs=None, predictions=None, loss=None, |
| metrics=None): |
| """Util function for constructing a `ExportOutput` dict given a mode. |
| |
| The returned dict can be directly passed to `build_all_signature_defs` helper |
| function as the `export_outputs` argument, used for generating a SignatureDef |
| map. |
| |
| Args: |
| mode: A `ModeKeys` specifying the mode. |
| serving_export_outputs: Describes the output signatures to be exported to |
| `SavedModel` and used during serving. Should be a dict or None. |
| predictions: A dict of Tensors or single Tensor representing model |
| predictions. This argument is only used if serving_export_outputs is not |
| set. |
| loss: A dict of Tensors or single Tensor representing calculated loss. |
| metrics: A dict of (metric_value, update_op) tuples, or a single tuple. |
| metric_value must be a Tensor, and update_op must be a Tensor or Op |
| |
| Returns: |
| Dictionary mapping the a key to an `tf.estimator.export.ExportOutput` object |
| The key is the expected SignatureDef key for the mode. |
| |
| Raises: |
| ValueError: if an appropriate ExportOutput cannot be found for the mode. |
| """ |
| # TODO(b/113185250): move all model export helper functions into an util file. |
| if mode == ModeKeys.PREDICT: |
| return _get_export_outputs(serving_export_outputs, predictions) |
| elif mode == ModeKeys.TRAIN: |
| return {mode: export_output_lib.TrainOutput( |
| loss=loss, predictions=predictions, metrics=metrics)} |
| elif mode == ModeKeys.EVAL: |
| return {mode: export_output_lib.EvalOutput( |
| loss=loss, predictions=predictions, metrics=metrics)} |
| else: |
| raise ValueError( |
| 'Export output type not found for mode: {}'.format(mode)) |