blob: 7243f150ce540cc96d1960511bc1500b7f917791 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TFGAN-backed GAN Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import enum
from tensorflow.contrib.framework.python.ops import variables as variable_lib
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_inspect as inspect
# Public API of this module.
__all__ = [
    'GANEstimator',
    'SummaryType'
]
class SummaryType(enum.IntEnum):
  """Types of summaries a `GANEstimator` can add to the graph.

  Each value (other than `NONE`) selects one of the TFGAN summary functions
  via `_summary_type_map` in this module.
  """
  NONE = 0
  VARIABLES = 1
  IMAGES = 2
  IMAGE_COMPARISON = 3
# Maps each `SummaryType` (other than `SummaryType.NONE`) to the TFGAN
# function that adds the corresponding summaries to the graph. Looked up in
# `_make_gan_model` for every entry of `add_summaries`.
_summary_type_map = {
    SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries,
    SummaryType.IMAGES: tfgan_summaries.add_gan_model_image_summaries,
    SummaryType.IMAGE_COMPARISON: tfgan_summaries.add_image_comparison_summaries,  # pylint:disable=line-too-long
}
class GANEstimator(estimator.Estimator):
  """An estimator for Generative Adversarial Networks (GANs).

  This Estimator is backed by TFGAN. The network functions follow the TFGAN API
  except for one exception: if either `generator_fn` or `discriminator_fn` have
  an argument called `mode`, then the tf.Estimator mode is passed in for that
  argument. This helps with operations like batch normalization, which have
  different train and evaluation behavior.

  Example:

  ```python
      import tensorflow as tf
      tfgan = tf.contrib.gan

      # See TFGAN's `train.py` for a description of the generator and
      # discriminator API.
      def generator_fn(generator_inputs):
        ...
        return generated_data

      def discriminator_fn(data, conditioning):
        ...
        return logits

      # Create GAN estimator.
      gan_estimator = tfgan.estimator.GANEstimator(
          model_dir,
          generator_fn=generator_fn,
          discriminator_fn=discriminator_fn,
          generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
          discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
          generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5),
          discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5))

      # Train estimator.
      gan_estimator.train(train_input_fn, steps)

      # Evaluate resulting estimator.
      gan_estimator.evaluate(eval_input_fn)

      # Generate samples from generator.
      predictions = np.array([
          x for x in gan_estimator.predict(predict_input_fn)])
  ```
  """

  def __init__(self,
               model_dir=None,
               generator_fn=None,
               discriminator_fn=None,
               generator_loss_fn=None,
               discriminator_loss_fn=None,
               generator_optimizer=None,
               discriminator_optimizer=None,
               get_hooks_fn=None,
               get_eval_metric_ops_fn=None,
               add_summaries=None,
               use_loss_summaries=True,
               config=None):
    """Initializes a GANEstimator instance.

    Args:
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into a estimator
        to continue training a previously saved model.
      generator_fn: A python function that takes a Tensor, Tensor list, or
        Tensor dictionary as inputs and returns the outputs of the GAN
        generator. See `TFGAN` for more details and examples. Additionally, if
        it has an argument called `mode`, the Estimator's `mode` will be passed
        in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
        normalization.
      discriminator_fn: A python function that takes the output of
        `generator_fn` or real data in the GAN setup, and `generator_inputs`.
        Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
        and examples.
      generator_loss_fn: The loss function on the generator. Takes a `GANModel`
        tuple.
      discriminator_loss_fn: The loss function on the discriminator. Takes a
        `GANModel` tuple.
      generator_optimizer: The optimizer for generator updates, or a function
        that takes no arguments and returns an optimizer. This function will
        be called when the default graph is the `GANEstimator`'s graph, so
        utilities like `tf.contrib.framework.get_or_create_global_step` will
        work.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
        list of hooks. These hooks are run on the generator and discriminator
        train ops, and can be used to implement the GAN training scheme.
        Defaults to `train.get_sequential_train_hooks()`.
      get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
        dict of metric results keyed by name. The output of this function is
        passed into `tf.estimator.EstimatorSpec` during evaluation.
      add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
      use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, uses defaults.
      config: `RunConfig` object to configure the runtime settings.

    Raises:
      ValueError: If loss functions aren't callable.
      ValueError: If `use_loss_summaries` isn't boolean or `None`.
      TypeError: If `get_hooks_fn` isn't callable or `None`.
    """
    # Validate configuration eagerly so misconfiguration fails at construction
    # time, not when the (lazily run) model function is first invoked.
    if not callable(generator_loss_fn):
      raise ValueError('generator_loss_fn must be callable.')
    if not callable(discriminator_loss_fn):
      raise ValueError('discriminator_loss_fn must be callable.')
    if use_loss_summaries not in [True, False, None]:
      raise ValueError('use_loss_summaries must be True, False or None.')
    if get_hooks_fn is not None and not callable(get_hooks_fn):
      raise TypeError('get_hooks_fn must be callable.')

    def _model_fn(features, labels, mode):
      """GANEstimator model function."""
      if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
                      model_fn_lib.ModeKeys.PREDICT]:
        raise ValueError('Mode not recognized: %s' % mode)
      real_data = labels  # rename inputs for clarity
      generator_inputs = features  # rename inputs for clarity

      # Make GANModel, which encapsulates the GAN model architectures.
      gan_model = _get_gan_model(
          mode, generator_fn, discriminator_fn, real_data, generator_inputs,
          add_summaries)

      # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
      # metrics, and optimizers (if required).
      return _get_estimator_spec(
          mode, gan_model, generator_loss_fn, discriminator_loss_fn,
          get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
          get_hooks_fn, use_loss_summaries)

    super(GANEstimator, self).__init__(
        model_fn=_model_fn, model_dir=model_dir, config=config)
def _get_gan_model(
    mode, generator_fn, discriminator_fn, real_data, generator_inputs,
    add_summaries, generator_scope='Generator'):
  """Makes the GANModel tuple, which encapsulates the GAN model architecture."""
  if mode != model_fn_lib.ModeKeys.PREDICT:
    # TRAIN or EVAL: build the full generator + discriminator graph.
    return _make_gan_model(
        generator_fn, discriminator_fn, real_data, generator_inputs,
        generator_scope, add_summaries, mode)

  # PREDICT: labels must be absent and only the generator is constructed.
  if real_data is not None:
    raise ValueError('`labels` must be `None` when mode is `predict`. '
                     'Instead, found %s' % real_data)
  return _make_prediction_gan_model(
      generator_inputs, generator_fn, generator_scope)
def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None, use_loss_summaries=True):
  """Get the EstimatorSpec for the current mode."""
  if mode == model_fn_lib.ModeKeys.PREDICT:
    # Prediction needs no losses, metrics, or optimizers.
    return model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  # TRAIN and EVAL both need the generator and discriminator losses.
  gan_loss = tfgan_tuples.GANLoss(
      generator_loss=generator_loss_fn(
          gan_model, add_summaries=use_loss_summaries),
      discriminator_loss=discriminator_loss_fn(
          gan_model, add_summaries=use_loss_summaries))

  if mode == model_fn_lib.ModeKeys.EVAL:
    return _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn)

  # model_fn_lib.ModeKeys.TRAIN: optimizers may be passed as zero-argument
  # factories so they are built while this (the Estimator's) graph is default.
  if callable(generator_optimizer):
    generator_optimizer = generator_optimizer()
  if callable(discriminator_optimizer):
    discriminator_optimizer = discriminator_optimizer()
  if get_hooks_fn is None:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(
      gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
      get_hooks_fn)
def _make_gan_model(generator_fn, discriminator_fn, real_data,
                    generator_inputs, generator_scope, add_summaries, mode):
  """Construct a `GANModel`, and optionally pass in `mode`."""

  def _bind_mode(network_fn):
    # Network functions that declare a `mode` argument get the Estimator mode
    # bound in (useful e.g. for batch norm's train/eval switch).
    if 'mode' in inspect.getargspec(network_fn).args:
      return functools.partial(network_fn, mode=mode)
    return network_fn

  gan_model = tfgan_train.gan_model(
      _bind_mode(generator_fn),
      _bind_mode(discriminator_fn),
      real_data,
      generator_inputs,
      generator_scope=generator_scope,
      check_shapes=False)

  if add_summaries:
    # Accept a single SummaryType or a tuple/list of them.
    summary_types = (add_summaries if isinstance(add_summaries, (tuple, list))
                     else [add_summaries])
    with ops.name_scope(None):
      for summary_type in summary_types:
        _summary_type_map[summary_type](gan_model)

  return gan_model
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
  """Make a `GANModel` from just the generator."""
  # Honor an optional `mode` argument on the generator network.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)

  with variable_scope.variable_scope(generator_scope) as gen_scope:
    # pylint:disable=protected-access
    generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs)
    # pylint:enable=protected-access
    generated_data = generator_fn(generator_inputs)
    generator_variables = variable_lib.get_trainable_variables(gen_scope)

  # Every discriminator-related field is absent at prediction time.
  return tfgan_tuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data=None,
      discriminator_real_outputs=None,
      discriminator_gen_outputs=None,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None,
                             name=None):
  """Return an EstimatorSpec for the eval case."""
  scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
  with ops.name_scope(None, 'metrics',
                      [gan_loss.generator_loss, gan_loss.discriminator_loss]):

    def _summary_key(head_name, val):
      # Metric names are suffixed with the head name, when one is given.
      return '%s/%s' % (val, head_name) if head_name else val

    eval_metric_ops = {}
    eval_metric_ops[_summary_key(name, 'generator_loss')] = (
        metrics_lib.mean(gan_loss.generator_loss))
    eval_metric_ops[_summary_key(name, 'discriminator_loss')] = (
        metrics_lib.mean(gan_loss.discriminator_loss))
    if get_eval_metric_ops_fn is not None:
      custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
      if not isinstance(custom_eval_metric_ops, dict):
        raise TypeError('get_eval_metric_ops_fn must return a dict, '
                        'received: {}'.format(custom_eval_metric_ops))
      eval_metric_ops.update(custom_eval_metric_ops)

  return model_fn_lib.EstimatorSpec(
      mode=model_fn_lib.ModeKeys.EVAL,
      predictions=gan_model.generated_data,
      loss=scalar_loss,
      eval_metric_ops=eval_metric_ops)
def _get_train_estimator_spec(
    gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
    get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops):
  """Return an EstimatorSpec for the train case."""
  # The Estimator reports a single scalar loss: generator + discriminator.
  total_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
  train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
                          discriminator_optimizer)
  return model_fn_lib.EstimatorSpec(
      mode=model_fn_lib.ModeKeys.TRAIN,
      loss=total_loss,
      train_op=train_ops.global_step_inc_op,
      training_hooks=get_hooks_fn(train_ops))