blob: 58f348034fdcaadd8d738517aef2a7e2f0172c13 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.python.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python import train
from tensorflow.contrib.gan.python.features.python import random_tensor_pool
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
def generator_model(inputs):
return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs
class Generator(object):
def __call__(self, inputs):
return generator_model(inputs)
def infogan_generator_model(inputs):
return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs[0]
class InfoGANGenerator(object):
def __call__(self, inputs):
return infogan_generator_model(inputs)
def discriminator_model(inputs, _):
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
class Discriminator(object):
def __call__(self, inputs, _):
return discriminator_model(inputs, _)
def infogan_discriminator_model(inputs, _):
return (variable_scope.get_variable('dummy_d', initializer=2.0) * inputs,
[categorical.Categorical([1.0])])
class InfoGANDiscriminator(object):
def __call__(self, inputs, _):
return infogan_discriminator_model(inputs, _)
def acgan_discriminator_model(inputs, _, num_classes=10):
return (
discriminator_model(inputs, _),
array_ops.one_hot(
# TODO(haeusser): infer batch size from input
random_ops.random_uniform(
[3], maxval=num_classes, dtype=dtypes.int32),
num_classes))
class ACGANDiscriminator(object):
def __call__(self, inputs, _, num_classes=10):
return (
discriminator_model(inputs, _),
array_ops.one_hot(
# TODO(haeusser): infer batch size from input
random_ops.random_uniform(
[3], maxval=num_classes, dtype=dtypes.int32),
num_classes))
def stargan_generator_model(inputs, _):
"""Dummy generator for StarGAN."""
return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs
class StarGANGenerator(object):
def __call__(self, inputs, _):
return stargan_generator_model(inputs, _)
def stargan_discriminator_model(inputs, num_domains):
"""Differentiable dummy discriminator for StarGAN."""
hidden = layers.flatten(inputs)
output_src = math_ops.reduce_mean(hidden, axis=1)
output_cls = layers.fully_connected(
inputs=hidden,
num_outputs=num_domains,
activation_fn=None,
normalizer_fn=None,
biases_initializer=None)
return output_src, output_cls
class StarGANDiscriminator(object):
def __call__(self, inputs, num_domains):
return stargan_discriminator_model(inputs, num_domains)
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
pass
with variable_scope.variable_scope('discriminator') as dis_scope:
pass
return namedtuples.GANModel(
generator_inputs=None,
generated_data=None,
generator_variables=None,
generator_scope=gen_scope,
generator_fn=generator_model,
real_data=array_ops.ones([1, 2, 3]),
discriminator_real_outputs=array_ops.ones([1, 2, 3]),
discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
discriminator_variables=None,
discriminator_scope=dis_scope,
discriminator_fn=discriminator_model)
def get_callable_gan_model():
ganmodel = get_gan_model()
return ganmodel._replace(
generator_fn=Generator(), discriminator_fn=Discriminator())
def create_gan_model():
return train.gan_model(
generator_model,
discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]))
def create_callable_gan_model():
return train.gan_model(
Generator(),
Discriminator(),
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]))
def get_infogan_model():
return namedtuples.InfoGANModel(
*get_gan_model(),
structured_generator_inputs=[constant_op.constant(0)],
predicted_distributions=[categorical.Categorical([1.0])],
discriminator_and_aux_fn=infogan_discriminator_model)
def get_callable_infogan_model():
return namedtuples.InfoGANModel(
*get_callable_gan_model(),
structured_generator_inputs=[constant_op.constant(0)],
predicted_distributions=[categorical.Categorical([1.0])],
discriminator_and_aux_fn=infogan_discriminator_model)
def create_infogan_model():
return train.infogan_model(
infogan_generator_model,
infogan_discriminator_model,
real_data=array_ops.zeros([1, 2]),
unstructured_generator_inputs=[],
structured_generator_inputs=[random_ops.random_normal([1, 2])])
def create_callable_infogan_model():
return train.infogan_model(
InfoGANGenerator(),
InfoGANDiscriminator(),
real_data=array_ops.zeros([1, 2]),
unstructured_generator_inputs=[],
structured_generator_inputs=[random_ops.random_normal([1, 2])])
def get_acgan_model():
return namedtuples.ACGANModel(
*get_gan_model(),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
def get_callable_acgan_model():
return namedtuples.ACGANModel(
*get_callable_gan_model(),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
def create_acgan_model():
return train.acgan_model(
generator_model,
acgan_discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
def create_callable_acgan_model():
return train.acgan_model(
Generator(),
ACGANDiscriminator(),
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
def get_cyclegan_model():
return namedtuples.CycleGANModel(
model_x2y=get_gan_model(),
model_y2x=get_gan_model(),
reconstructed_x=array_ops.ones([1, 2, 3]),
reconstructed_y=array_ops.zeros([1, 2, 3]))
def get_callable_cyclegan_model():
return namedtuples.CycleGANModel(
model_x2y=get_callable_gan_model(),
model_y2x=get_callable_gan_model(),
reconstructed_x=array_ops.ones([1, 2, 3]),
reconstructed_y=array_ops.zeros([1, 2, 3]))
def create_cyclegan_model():
return train.cyclegan_model(
generator_model,
discriminator_model,
data_x=array_ops.zeros([1, 2]),
data_y=array_ops.ones([1, 2]))
def create_callable_cyclegan_model():
return train.cyclegan_model(
Generator(),
Discriminator(),
data_x=array_ops.zeros([1, 2]),
data_y=array_ops.ones([1, 2]))
def get_stargan_model():
"""Similar to get_gan_model()."""
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
pass
with variable_scope.variable_scope('discriminator') as dis_scope:
pass
return namedtuples.StarGANModel(
input_data=array_ops.ones([1, 2, 2, 3]),
input_data_domain_label=array_ops.ones([1, 2]),
generated_data=array_ops.ones([1, 2, 2, 3]),
generated_data_domain_target=array_ops.ones([1, 2]),
reconstructed_data=array_ops.ones([1, 2, 2, 3]),
discriminator_input_data_source_predication=array_ops.ones([1]),
discriminator_generated_data_source_predication=array_ops.ones([1]),
discriminator_input_data_domain_predication=array_ops.ones([1, 2]),
discriminator_generated_data_domain_predication=array_ops.ones([1, 2]),
generator_variables=None,
generator_scope=gen_scope,
generator_fn=stargan_generator_model,
discriminator_variables=None,
discriminator_scope=dis_scope,
discriminator_fn=stargan_discriminator_model)
def get_callable_stargan_model():
model = get_stargan_model()
return model._replace(
generator_fn=StarGANGenerator(), discriminator_fn=StarGANDiscriminator())
def create_stargan_model():
return train.stargan_model(
stargan_generator_model, stargan_discriminator_model,
array_ops.ones([1, 2, 2, 3]), array_ops.ones([1, 2]))
def create_callable_stargan_model():
return train.stargan_model(StarGANGenerator(), StarGANDiscriminator(),
array_ops.ones([1, 2, 2, 3]),
array_ops.ones([1, 2]))
def get_sync_optimizer():
return sync_replicas_optimizer.SyncReplicasOptimizer(
gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
replicas_to_aggregate=1)
class GANModelTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_model`."""
@parameterized.named_parameters(
('gan', get_gan_model, namedtuples.GANModel),
('callable_gan', get_callable_gan_model, namedtuples.GANModel),
('infogan', get_infogan_model, namedtuples.InfoGANModel),
('callable_infogan', get_callable_infogan_model,
namedtuples.InfoGANModel),
('acgan', get_acgan_model, namedtuples.ACGANModel),
('callable_acgan', get_callable_acgan_model, namedtuples.ACGANModel),
('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel),
('callable_cyclegan', get_callable_cyclegan_model,
namedtuples.CycleGANModel),
('stargan', get_stargan_model, namedtuples.StarGANModel),
('callabel_stargan', get_callable_stargan_model, namedtuples.StarGANModel)
)
def test_output_type(self, create_fn, expected_tuple_type):
"""Test that output type is as expected."""
self.assertIsInstance(create_fn(), expected_tuple_type)
def test_no_shape_check(self):
def dummy_generator_model(_):
return (None, None)
def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument
return 1
with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
train.gan_model(
dummy_generator_model,
dummy_discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=array_ops.zeros([1]),
check_shapes=True)
train.gan_model(
dummy_generator_model,
dummy_discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=array_ops.zeros([1]),
check_shapes=False)
class StarGANModelTest(test.TestCase):
"""Tests for `stargan_model`."""
@staticmethod
def create_input_and_label_tensor(batch_size, img_size, c_size, num_domains):
input_tensor_list = []
label_tensor_list = []
for _ in range(num_domains):
input_tensor_list.append(
random_ops.random_uniform((batch_size, img_size, img_size, c_size)))
domain_idx = random_ops.random_uniform(
[batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32)
label_tensor_list.append(array_ops.one_hot(domain_idx, num_domains))
return input_tensor_list, label_tensor_list
def test_generate_stargan_random_domain_target(self):
batch_size = 8
domain_numbers = 3
target_tensor = train._generate_stargan_random_domain_target(
batch_size, domain_numbers)
with self.test_session() as sess:
targets = sess.run(target_tensor)
self.assertTupleEqual((batch_size, domain_numbers), targets.shape)
for target in targets:
self.assertEqual(1, np.sum(target))
self.assertEqual(1, np.max(target))
def test_stargan_model_output_type(self):
batch_size = 2
img_size = 16
c_size = 3
num_domains = 5
input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
batch_size, img_size, c_size, num_domains)
model = train.stargan_model(
generator_fn=stargan_generator_model,
discriminator_fn=stargan_discriminator_model,
input_data=input_tensor,
input_data_domain_label=label_tensor)
self.assertIsInstance(model, namedtuples.StarGANModel)
self.assertTrue(isinstance(model.discriminator_variables, list))
self.assertTrue(isinstance(model.generator_variables, list))
self.assertIsInstance(model.discriminator_scope,
variable_scope.VariableScope)
self.assertTrue(model.generator_scope, variable_scope.VariableScope)
self.assertTrue(callable(model.discriminator_fn))
self.assertTrue(callable(model.generator_fn))
def test_stargan_model_generator_output(self):
batch_size = 2
img_size = 16
c_size = 3
num_domains = 5
input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
batch_size, img_size, c_size, num_domains)
model = train.stargan_model(
generator_fn=stargan_generator_model,
discriminator_fn=stargan_discriminator_model,
input_data=input_tensor,
input_data_domain_label=label_tensor)
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
input_data, generated_data, reconstructed_data = sess.run(
[model.input_data, model.generated_data, model.reconstructed_data])
self.assertTupleEqual(
(batch_size * num_domains, img_size, img_size, c_size),
input_data.shape)
self.assertTupleEqual(
(batch_size * num_domains, img_size, img_size, c_size),
generated_data.shape)
self.assertTupleEqual(
(batch_size * num_domains, img_size, img_size, c_size),
reconstructed_data.shape)
def test_stargan_model_discriminator_output(self):
batch_size = 2
img_size = 16
c_size = 3
num_domains = 5
input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
batch_size, img_size, c_size, num_domains)
model = train.stargan_model(
generator_fn=stargan_generator_model,
discriminator_fn=stargan_discriminator_model,
input_data=input_tensor,
input_data_domain_label=label_tensor)
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
disc_input_data_source_pred, disc_gen_data_source_pred = sess.run([
model.discriminator_input_data_source_predication,
model.discriminator_generated_data_source_predication
])
self.assertEqual(1, len(disc_input_data_source_pred.shape))
self.assertEqual(batch_size * num_domains,
disc_input_data_source_pred.shape[0])
self.assertEqual(1, len(disc_gen_data_source_pred.shape))
self.assertEqual(batch_size * num_domains,
disc_gen_data_source_pred.shape[0])
input_label, disc_input_label, gen_label, disc_gen_label = sess.run([
model.input_data_domain_label,
model.discriminator_input_data_domain_predication,
model.generated_data_domain_target,
model.discriminator_generated_data_domain_predication
])
self.assertTupleEqual((batch_size * num_domains, num_domains),
input_label.shape)
self.assertTupleEqual((batch_size * num_domains, num_domains),
disc_input_label.shape)
self.assertTupleEqual((batch_size * num_domains, num_domains),
gen_label.shape)
self.assertTupleEqual((batch_size * num_domains, num_domains),
disc_gen_label.shape)
class GANLossTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_loss`."""
@parameterized.named_parameters(
('gan', get_gan_model),
('callable_gan', get_callable_gan_model),
('infogan', get_infogan_model),
('callable_infogan', get_callable_infogan_model),
('acgan', get_acgan_model),
('callable_acgan', get_callable_acgan_model),
)
def test_output_type(self, get_gan_model_fn):
"""Test output type."""
loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
self.assertIsInstance(loss, namedtuples.GANLoss)
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
@parameterized.named_parameters(
('cyclegan', create_cyclegan_model),
('callable_cyclegan', create_callable_cyclegan_model),
)
def test_cyclegan_output_type(self, get_gan_model_fn):
loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True)
self.assertIsInstance(loss, namedtuples.CycleGANLoss)
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
@parameterized.named_parameters(
('gan', create_gan_model, False),
('gan_one_sided', create_gan_model, True),
('callable_gan', create_callable_gan_model, False),
('callable_gan_one_sided', create_callable_gan_model, True),
('infogan', create_infogan_model, False),
('infogan_one_sided', create_infogan_model, True),
('callable_infogan', create_callable_infogan_model, False),
('callable_infogan_one_sided', create_callable_infogan_model, True),
('acgan', create_acgan_model, False),
('acgan_one_sided', create_acgan_model, True),
('callable_acgan', create_callable_acgan_model, False),
('callable_acgan_one_sided', create_callable_acgan_model, True),
)
def test_grad_penalty(self, create_gan_model_fn, one_sided):
"""Test gradient penalty option."""
model = create_gan_model_fn()
loss = train.gan_loss(model)
loss_gp = train.gan_loss(
model,
gradient_penalty_weight=1.0,
gradient_penalty_one_sided=one_sided)
self.assertIsInstance(loss_gp, namedtuples.GANLoss)
# Check values.
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
loss_gen_np, loss_gen_gp_np = sess.run(
[loss.generator_loss, loss_gp.generator_loss])
loss_dis_np, loss_dis_gp_np = sess.run(
[loss.discriminator_loss, loss_gp.discriminator_loss])
self.assertEqual(loss_gen_np, loss_gen_gp_np)
self.assertLess(loss_dis_np, loss_dis_gp_np)
@parameterized.named_parameters(
('infogan', get_infogan_model),
('callable_infogan', get_callable_infogan_model),
)
def test_mutual_info_penalty(self, create_gan_model_fn):
"""Test mutual information penalty option."""
train.gan_loss(
create_gan_model_fn(),
mutual_information_penalty_weight=constant_op.constant(1.0))
@parameterized.named_parameters(
('gan', get_gan_model),
('callable_gan', get_callable_gan_model),
('infogan', get_infogan_model),
('callable_infogan', get_callable_infogan_model),
('acgan', get_acgan_model),
('callable_acgan', get_callable_acgan_model),
)
def test_regularization_helper(self, get_gan_model_fn):
"""Test regularization loss."""
# Evaluate losses without regularization.
no_reg_loss = train.gan_loss(get_gan_model_fn())
with self.test_session(use_gpu=True):
no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()
with ops.name_scope(get_gan_model_fn().generator_scope.name):
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
constant_op.constant(3.0))
with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
constant_op.constant(2.0))
# Check that losses now include the correct regularization values.
reg_loss = train.gan_loss(get_gan_model_fn())
with self.test_session(use_gpu=True):
reg_loss_gen_np = reg_loss.generator_loss.eval()
reg_loss_dis_np = reg_loss.discriminator_loss.eval()
self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)
@parameterized.named_parameters(
('notcallable', create_acgan_model),
('callable', create_callable_acgan_model),
)
def test_acgan(self, create_gan_model_fn):
"""Test that ACGAN models work."""
model = create_gan_model_fn()
loss = train.gan_loss(model)
loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
self.assertIsInstance(loss, namedtuples.GANLoss)
self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss)
self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss)
# Check values.
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run([
loss.generator_loss, loss_ac_gen.generator_loss,
loss_ac_dis.generator_loss
])
loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run([
loss.discriminator_loss, loss_ac_gen.discriminator_loss,
loss_ac_dis.discriminator_loss
])
self.assertLess(loss_gen_np, loss_dis_np)
self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
@parameterized.named_parameters(
('notcallable', create_cyclegan_model),
('callable', create_callable_cyclegan_model),
)
def test_cyclegan(self, create_gan_model_fn):
"""Test that CycleGan models work."""
model = create_gan_model_fn()
loss = train.cyclegan_loss(model)
self.assertIsInstance(loss, namedtuples.CycleGANLoss)
# Check values.
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
(loss_x2y_gen_np, loss_x2y_dis_np, loss_y2x_gen_np,
loss_y2x_dis_np) = sess.run([
loss.loss_x2y.generator_loss, loss.loss_x2y.discriminator_loss,
loss.loss_y2x.generator_loss, loss.loss_y2x.discriminator_loss
])
self.assertGreater(loss_x2y_gen_np, loss_x2y_dis_np)
self.assertGreater(loss_y2x_gen_np, loss_y2x_dis_np)
self.assertTrue(np.isscalar(loss_x2y_gen_np))
self.assertTrue(np.isscalar(loss_x2y_dis_np))
self.assertTrue(np.isscalar(loss_y2x_gen_np))
self.assertTrue(np.isscalar(loss_y2x_dis_np))
@parameterized.named_parameters(
('notcallable', create_stargan_model),
('callable', create_callable_stargan_model),
)
def test_stargan(self, create_gan_model_fn):
model = create_gan_model_fn()
model_loss = train.stargan_loss(model)
self.assertIsInstance(model_loss, namedtuples.GANLoss)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
gen_loss, disc_loss = sess.run(
[model_loss.generator_loss, model_loss.discriminator_loss])
self.assertTrue(np.isscalar(gen_loss))
self.assertTrue(np.isscalar(disc_loss))
@parameterized.named_parameters(
('gan', create_gan_model),
('callable_gan', create_callable_gan_model),
('infogan', create_infogan_model),
('callable_infogan', create_callable_infogan_model),
('acgan', create_acgan_model),
('callable_acgan', create_callable_acgan_model),
)
def test_tensor_pool(self, create_gan_model_fn):
"""Test tensor pool option."""
model = create_gan_model_fn()
tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5)
loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
self.assertIsInstance(loss, namedtuples.GANLoss)
# Check values.
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
for _ in range(10):
sess.run([loss.generator_loss, loss.discriminator_loss])
def test_discriminator_only_sees_pool(self):
"""Checks that discriminator only sees pooled values."""
def checker_gen_fn(_):
return constant_op.constant(0.0)
model = train.gan_model(
checker_gen_fn,
discriminator_model,
real_data=array_ops.zeros([]),
generator_inputs=random_ops.random_normal([]))
def tensor_pool_fn(_):
return (random_ops.random_uniform([]), random_ops.random_uniform([]))
def checker_dis_fn(inputs, _):
"""Discriminator that checks that it only sees pooled Tensors."""
self.assertFalse(constant_op.is_constant(inputs))
return inputs
model = model._replace(
discriminator_fn=checker_dis_fn)
train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
def test_doesnt_crash_when_in_nested_scope(self):
with variable_scope.variable_scope('outer_scope'):
gan_model = train.gan_model(
generator_model,
discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]))
# This should work inside a scope.
train.gan_loss(gan_model, gradient_penalty_weight=1.0)
# This should also work outside a scope.
train.gan_loss(gan_model, gradient_penalty_weight=1.0)
class TensorPoolAdjusteModelTest(test.TestCase):
def _check_tensor_pool_adjusted_model_outputs(
self, tensor1, tensor2, pool_size):
history_values = []
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
for i in range(2 * pool_size):
t1, t2 = sess.run([tensor1, tensor2])
history_values.append(t1)
if i < pool_size:
# For [0, pool_size), the pool is not full, tensor1 should be equal
# to tensor2 as the pool.
self.assertAllEqual(t1, t2)
else:
# For [pool_size, ?), the pool is full, tensor2 must be equal to some
# historical values of tensor1 (which is previously stored in the
# pool).
self.assertTrue(any([(v == t2).all() for v in history_values]))
def _make_new_model_and_check(self, model, pool_size):
pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size)
new_model = train._tensor_pool_adjusted_model(model, pool_fn)
# 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
self.assertIsNot(new_model.discriminator_gen_outputs,
model.discriminator_gen_outputs)
return new_model
def test_tensor_pool_adjusted_model_gan(self):
"""Test `_tensor_pool_adjusted_model` for gan model."""
pool_size = 5
model = create_gan_model()
new_model = self._make_new_model_and_check(model, pool_size)
# Check values.
self._check_tensor_pool_adjusted_model_outputs(
model.discriminator_gen_outputs, new_model.discriminator_gen_outputs,
pool_size)
def test_tensor_pool_adjusted_model_infogan(self):
"""Test _tensor_pool_adjusted_model for infogan model."""
pool_size = 5
model = create_infogan_model()
new_model = self._make_new_model_and_check(model, pool_size)
# Check values.
self.assertIsNot(new_model.predicted_distributions,
model.predicted_distributions)
self._check_tensor_pool_adjusted_model_outputs(
model.discriminator_gen_outputs, new_model.discriminator_gen_outputs,
pool_size)
def test_tensor_pool_adjusted_model_acgan(self):
"""Test _tensor_pool_adjusted_model for acgan model."""
pool_size = 5
model = create_acgan_model()
new_model = self._make_new_model_and_check(model, pool_size)
# Check values.
self.assertIsNot(new_model.discriminator_gen_classification_logits,
model.discriminator_gen_classification_logits)
self._check_tensor_pool_adjusted_model_outputs(
model.discriminator_gen_outputs, new_model.discriminator_gen_outputs,
pool_size)
class GANTrainOpsTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_train_ops`."""
@parameterized.named_parameters(
('gan', create_gan_model),
('callable_gan', create_callable_gan_model),
('infogan', create_infogan_model),
('callable_infogan', create_callable_infogan_model),
('acgan', create_acgan_model),
('callable_acgan', create_callable_acgan_model),
)
def test_output_type(self, create_gan_model_fn):
model = create_gan_model_fn()
loss = train.gan_loss(model)
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
train_ops = train.gan_train_ops(
model,
loss,
g_opt,
d_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True)
self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
# TODO(joelshor): Add a test to check that custom update op is run.
@parameterized.named_parameters(
('gan', create_gan_model, False),
('gan_provideupdates', create_gan_model, True),
('callable_gan', create_callable_gan_model, False),
('callable_gan_provideupdates', create_callable_gan_model, True),
('infogan', create_infogan_model, False),
('infogan_provideupdates', create_infogan_model, True),
('callable_infogan', create_callable_infogan_model, False),
('callable_infogan_provideupdates', create_callable_infogan_model, True),
('acgan', create_acgan_model, False),
('acgan_provideupdates', create_acgan_model, True),
('callable_acgan', create_callable_acgan_model, False),
('callable_acgan_provideupdates', create_callable_acgan_model, True),
)
def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
model = create_gan_model_fn()
loss = train.gan_loss(model)
# Add generator and discriminator update ops.
with variable_scope.variable_scope(model.generator_scope):
gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
gen_update_op = gen_update_count.assign_add(1)
ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
with variable_scope.variable_scope(model.discriminator_scope):
dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
dis_update_op = dis_update_count.assign_add(1)
ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)
# Add an update op outside the generator and discriminator scopes.
if provide_update_ops:
kwargs = {
'update_ops': [
constant_op.constant(1.0), gen_update_op, dis_update_op
]
}
else:
ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
kwargs = {}
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
train.gan_train_ops(
model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
train_ops = train.gan_train_ops(
model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(0, gen_update_count.eval())
self.assertEqual(0, dis_update_count.eval())
train_ops.generator_train_op.eval()
self.assertEqual(1, gen_update_count.eval())
self.assertEqual(0, dis_update_count.eval())
train_ops.discriminator_train_op.eval()
self.assertEqual(1, gen_update_count.eval())
self.assertEqual(1, dis_update_count.eval())
@parameterized.named_parameters(
('gan', create_gan_model, False),
('callable_gan', create_callable_gan_model, False),
('infogan', create_infogan_model, False),
('callable_infogan', create_callable_infogan_model, False),
('acgan', create_acgan_model, False),
('callable_acgan', create_callable_acgan_model, False),
('gan_canbeint32', create_gan_model, True),
)
def test_sync_replicas(self, create_gan_model_fn, create_global_step):
model = create_gan_model_fn()
loss = train.gan_loss(model)
num_trainable_vars = len(variables_lib.get_trainable_variables())
if create_global_step:
gstep = variable_scope.get_variable(
'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False)
ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep)
g_opt = get_sync_optimizer()
d_opt = get_sync_optimizer()
train_ops = train.gan_train_ops(
model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
# No new trainable variables should have been added.
self.assertEqual(num_trainable_vars,
len(variables_lib.get_trainable_variables()))
g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)
# Check that update op is run properly.
global_step = training_util.get_or_create_global_step()
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
variables.local_variables_initializer().run()
g_opt.chief_init_op.run()
d_opt.chief_init_op.run()
gstep_before = global_step.eval()
# Start required queue runner for SyncReplicasOptimizer.
coord = coordinator.Coordinator()
g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)
g_sync_init_op.run()
d_sync_init_op.run()
train_ops.generator_train_op.eval()
# Check that global step wasn't incremented.
self.assertEqual(gstep_before, global_step.eval())
train_ops.discriminator_train_op.eval()
# Check that global step wasn't incremented.
self.assertEqual(gstep_before, global_step.eval())
coord.request_stop()
coord.join(g_threads + d_threads)
class GANTrainTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_train`."""
def _gan_train_ops(self, generator_add, discriminator_add):
step = training_util.create_global_step()
# Increment the global count every time a train op is run so we can count
# the number of times they're run.
# NOTE: `use_locking=True` is required to avoid race conditions with
# joint training.
train_ops = namedtuples.GANTrainOps(
generator_train_op=step.assign_add(generator_add, use_locking=True),
discriminator_train_op=step.assign_add(
discriminator_add, use_locking=True),
global_step_inc_op=step.assign_add(1))
return train_ops
@parameterized.named_parameters(
('gan', create_gan_model),
('callable_gan', create_callable_gan_model),
('infogan', create_infogan_model),
('callable_infogan', create_callable_infogan_model),
('acgan', create_acgan_model),
('callable_acgan', create_callable_acgan_model),
)
def test_run_helper(self, create_gan_model_fn):
random_seed.set_random_seed(1234)
model = create_gan_model_fn()
loss = train.gan_loss(model)
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
final_step = train.gan_train(
train_ops,
logdir='',
hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
self.assertTrue(np.isscalar(final_step))
self.assertEqual(2, final_step)
@parameterized.named_parameters(
('seq_train_steps', train.get_sequential_train_hooks),
('efficient_seq_train_steps', train.get_joint_train_hooks),
)
def test_multiple_steps(self, get_hooks_fn_fn):
"""Test multiple train steps."""
train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
train_steps = namedtuples.GANTrainSteps(
generator_train_steps=3, discriminator_train_steps=4)
final_step = train.gan_train(
train_ops,
get_hooks_fn=get_hooks_fn_fn(train_steps),
logdir='',
hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])
self.assertTrue(np.isscalar(final_step))
self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
step = training_util.create_global_step()
train_ops = namedtuples.GANTrainOps(
generator_train_op=constant_op.constant(3.0),
discriminator_train_op=constant_op.constant(2.0),
global_step_inc_op=step.assign_add(1))
train_steps = namedtuples.GANTrainSteps(
generator_train_steps=3, discriminator_train_steps=4)
final_loss = slim_learning.train(
train_op=train_ops,
logdir='',
global_step=step,
number_of_steps=1,
train_step_fn=train.get_sequential_train_steps(train_steps))
self.assertTrue(np.isscalar(final_loss))
self.assertEqual(17.0, final_loss)
class PatchGANTest(test.TestCase, parameterized.TestCase):
"""Tests that functions work on PatchGAN style output."""
@parameterized.named_parameters(
('gan', create_gan_model),
('callable_gan', create_callable_gan_model),
('infogan', create_infogan_model),
('callable_infogan', create_callable_infogan_model),
('acgan', create_acgan_model),
('callable_acgan', create_callable_acgan_model),
)
def test_patchgan(self, create_gan_model_fn):
"""Ensure that patch-based discriminators work end-to-end."""
random_seed.set_random_seed(1234)
model = create_gan_model_fn()
loss = train.gan_loss(model)
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
final_step = train.gan_train(
train_ops,
logdir='',
hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
self.assertTrue(np.isscalar(final_step))
self.assertEqual(2, final_step)
if __name__ == '__main__':
test.main()