# blob: bcb4720275bb01b78feefb985fac6bf3c5c4ffa9 [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Correctness tests for tf.keras DNN model using DistributionStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.eager import context
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.training import gradient_descent
def all_strategy_combinations_with_eager_and_graph_modes():
  """Returns the strategy/mode combinations used by the metric tests.

  The result is the concatenation of:
    * `strategy_combinations.all_strategies` crossed with the 'graph' and
      'eager' modes, and
    * `strategy_combinations.multi_worker_mirrored_strategies` in 'eager'
      mode only.
  """
  all_strategies_both_modes = combinations.combine(
      distribution=strategy_combinations.all_strategies,
      mode=['graph', 'eager'])
  multi_worker_eager_only = combinations.combine(
      distribution=strategy_combinations.multi_worker_mirrored_strategies,
      mode='eager')
  return all_strategies_both_modes + multi_worker_eager_only
def all_strategy_combinations_with_graph_mode():
  """Returns `keras_correctness_test_base.all_strategies` in 'graph' mode."""
  graph_mode_combinations = combinations.combine(
      distribution=keras_correctness_test_base.all_strategies,
      mode=['graph'])
  return graph_mode_combinations
def is_default_strategy(strategy):
  """Returns True if no strategy is reported while inside `strategy.scope()`.

  Args:
    strategy: Object whose `scope()` context manager is entered for the check.

  Returns:
    The negation of `distribution_strategy_context.has_strategy()` as evaluated
    inside the scope.
  """
  with strategy.scope():
    strategy_reported = distribution_strategy_context.has_strategy()
  return not strategy_reported
@testing_utils.run_all_without_tensor_float_32(
    'Uses Dense layers, which call matmul')
class TestDistributionStrategyDnnCorrectness(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
  """Correctness tests for a small Dense regression model.

  The `get_model`/`get_data*` hooks defined here are presumably consumed by
  `run_correctness_test` and `run_dynamic_lr_test`, which are not defined in
  this class (verify against the base class).
  """

  def get_model(self,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles the four-layer Dense model.

    Args:
      initial_weights: Optional weights; when truthy they are applied with
        `model.set_weights`.
      distribution: Passed to `MaybeDistributionScope`, inside which the model
        is created and compiled.
      input_shapes: Not used by this implementation.

    Returns:
      The compiled `keras.Sequential` model.
    """
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # A few non-linear layers are stacked so the model is non-trivial.
      model = keras.Sequential([
          keras.layers.Dense(10, activation='relu', input_shape=(1,)),
          keras.layers.Dense(
              10,
              activation='relu',
              kernel_regularizer=keras.regularizers.l2(1e-4)),
          keras.layers.Dense(10, activation='relu'),
          keras.layers.Dense(1),
      ])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          loss=keras.losses.mean_squared_error,
          optimizer=gradient_descent_keras.SGD(0.05),
          metrics=['mse'])
      return model

  def get_data(self):
    """Returns `(x_train, y_train, x_predict)` with `y_train = 3 * x_train`.

    The training set has 9984 random samples (presumably chosen so that no
    partial final batch occurs -- verify against the batch size in the base
    class).
    """
    features = np.random.rand(9984, 1).astype('float32')
    targets = 3 * features
    predict_inputs = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
    return features, targets, predict_inputs

  def get_data_with_partial_last_batch(self):
    """Returns train/eval/predict data with 10000 train and eval samples.

    Returns:
      `(x_train, y_train, x_eval, y_eval, x_predict)`, where each `y` is three
      times its `x`.
    """
    train_features = np.random.rand(10000, 1).astype('float32')
    train_targets = 3 * train_features
    eval_features = np.random.rand(10000, 1).astype('float32')
    eval_targets = 3 * eval_features
    predict_inputs = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
    return (train_features, train_targets, eval_features, eval_targets,
            predict_inputs)

  def get_data_with_partial_last_batch_eval(self):
    """Returns data with 9984 training samples and 10000 eval samples.

    Returns:
      `(x_train, y_train, x_eval, y_eval, x_predict)`, where each `y` is three
      times its `x`.
    """
    train_features = np.random.rand(9984, 1).astype('float32')
    train_targets = 3 * train_features
    eval_features = np.random.rand(10000, 1).astype('float32')
    eval_targets = 3 * eval_features
    predict_inputs = np.array([[1.], [2.], [3.], [4.]], dtype=np.float32)
    return (train_features, train_targets, eval_features, eval_targets,
            predict_inputs)

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations() +
      keras_correctness_test_base.multi_worker_mirrored_eager())
  def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
    # Plain correctness run with the default options.
    self.run_correctness_test(distribution, use_numpy, use_validation_data)

  @ds_combinations.generate(
      keras_correctness_test_base
      .test_combinations_with_tpu_strategies_graph() +
      keras_correctness_test_base.multi_worker_mirrored_eager())
  def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
                                                        use_numpy,
                                                        use_validation_data):
    # Correctness run requesting the 'eval' partial-last-batch variant.
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data, partial_last_batch='eval')

  @ds_combinations.generate(
      keras_correctness_test_base
      .strategy_minus_tpu_and_input_config_combinations_eager() +
      keras_correctness_test_base.multi_worker_mirrored_eager())
  def test_dnn_correctness_with_partial_last_batch(self, distribution,
                                                   use_numpy,
                                                   use_validation_data):
    # The flag is switched on before the run; presumably it is what lets the
    # strategy cope with a partial final training batch (verify).
    distribution.extended.experimental_enable_get_next_as_optional = True
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        partial_last_batch='train_and_eval',
        training_epochs=1)

  @ds_combinations.generate(all_strategy_combinations_with_graph_mode())
  def test_dnn_with_dynamic_learning_rate(self, distribution):
    # Dynamic learning-rate check, exercised in graph mode only.
    self.run_dynamic_lr_test(distribution)
class TestDistributionStrategyDnnMetricCorrectness(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
  """Checks the `binary_accuracy` reported by `fit` on a one-unit model."""

  def get_model(self,
                distribution=None,
                input_shapes=None):
    """Builds and compiles a single Dense(1) model with a ones kernel.

    Args:
      distribution: Strategy whose `scope()` the model is created and compiled
        under. Must not be None, since `scope()` is called unconditionally.
      input_shapes: Not used by this implementation.

    Returns:
      The compiled `keras.Sequential` model, tracking `BinaryAccuracy`.
    """
    with distribution.scope():
      model = keras.Sequential()
      ones_dense = keras.layers.Dense(
          1, input_shape=(1,), kernel_initializer='ones')
      model.add(ones_dense)
      model.compile(
          loss=keras.losses.mean_squared_error,
          optimizer=gradient_descent_keras.SGD(0.05),
          metrics=[keras.metrics.BinaryAccuracy()])
    return model

  def run_metric_correctness_test(self, distribution):
    """Fits for two epochs and asserts binary accuracy is 1.0 in both.

    Args:
      distribution: Strategy forwarded to `get_model` and `get_batch_size`.
    """
    with self.cached_session():
      self.set_up_test_config()

      # Only the training split of `get_data()` is needed here.
      features, targets, _ = self.get_data()
      model = self.get_model(distribution=distribution)

      batch_size = keras_correctness_test_base.get_batch_size(64, distribution)
      unbatched = dataset_ops.Dataset.from_tensor_slices((features, targets))
      train_dataset = keras_correctness_test_base.batch_wrapper(
          unbatched, batch_size)

      history = model.fit(x=train_dataset, epochs=2, steps_per_epoch=10)
      self.assertEqual(history.history['binary_accuracy'], [1.0, 1.0])

  @ds_combinations.generate(
      all_strategy_combinations_with_eager_and_graph_modes())
  def test_simple_dnn_metric_correctness(self, distribution):
    # Runs the metric check for every generated strategy/mode combination.
    self.run_metric_correctness_test(distribution)
class TestDistributionStrategyDnnMetricEvalCorrectness(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):
  """Checks the metric values returned by `evaluate` on fixed inputs."""

  def get_model(self,
                distribution=None,
                input_shapes=None):
    """Builds and compiles a two-layer model whose kernels start at ones.

    Args:
      distribution: Strategy whose `scope()` the model is created and compiled
        under. Must not be None, since `scope()` is called unconditionally.
      input_shapes: Not used by this implementation.

    Returns:
      The compiled model, tracking 'accuracy' and `BinaryAccuracy`.
    """
    with distribution.scope():
      hidden_layer = keras.layers.Dense(
          3, activation='relu', input_dim=4, kernel_initializer='ones')
      output_layer = keras.layers.Dense(
          1, activation='sigmoid', kernel_initializer='ones')
      model = keras.Sequential()
      model.add(hidden_layer)
      model.add(output_layer)
      model.compile(
          loss='mae',
          metrics=['accuracy', keras.metrics.BinaryAccuracy()],
          optimizer=gradient_descent.GradientDescentOptimizer(0.001))
    return model

  def run_eval_metrics_correctness_test(self, distribution):
    """Evaluates on all-ones inputs against all-ones, then all-zeros, labels.

    Both reported metrics are asserted to be exactly 1.0 for the all-ones
    labels and exactly 0.0 for the all-zeros labels.

    Args:
      distribution: Strategy forwarded to `get_model`.
    """
    with self.cached_session():
      self.set_up_test_config()

      model = self.get_model(distribution=distribution)

      # Verify correctness of stateful and stateless metrics.
      x = np.ones((100, 4)).astype('float32')
      # Each pass pairs a label constructor with the metric value expected
      # for it; the passes run in the listed order.
      for make_labels, expected_metric in ((np.ones, 1.), (np.zeros, 0.)):
        y = make_labels((100, 1)).astype('float32')
        repeated = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat()
        batched = keras_correctness_test_base.batch_wrapper(repeated, 4)
        outs = model.evaluate(batched, steps=10)
        self.assertEqual(outs[1], expected_metric)
        self.assertEqual(outs[2], expected_metric)

  @ds_combinations.generate(
      all_strategy_combinations_with_eager_and_graph_modes())
  def test_identity_model_metric_eval_correctness(self, distribution):
    # Runs the evaluate check for every generated strategy/mode combination.
    self.run_eval_metrics_correctness_test(distribution)
class SubclassedModel(keras.Model):
  """Subclassed model with three relu Dense(10) layers and a Dense(1) output."""

  def __init__(self, initial_weights, input_shapes):
    """Creates the layers, builds the model and optionally sets weights.

    Args:
      initial_weights: When truthy, applied with `set_weights` after building.
      input_shapes: Shape passed to `build`; when falsy, `(None, 1)` is used.
    """
    super(SubclassedModel, self).__init__()
    self.dense1 = keras.layers.Dense(10, activation='relu', input_shape=(1,))
    self.dense2 = keras.layers.Dense(
        10, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-4))
    self.dense3 = keras.layers.Dense(10, activation='relu')
    self.dense4 = keras.layers.Dense(1)

    # The `(None, 1)` fallback covers cases when the input is
    # DatasetV1Adapter.
    self.build(input_shapes or (None, 1))

    if initial_weights:
      self.set_weights(initial_weights)

  def call(self, inputs):
    """Applies dense1 through dense4 in order and returns the final output."""
    hidden = self.dense3(self.dense2(self.dense1(inputs)))
    return self.dense4(hidden)
@testing_utils.run_all_without_tensor_float_32(
    'Uses Dense layers, which call matmul')
class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
    TestDistributionStrategyDnnCorrectness):
  """Repeats the DNN correctness tests using `SubclassedModel`.

  In some strategy/mode combinations the tests expect a `ValueError` rather
  than a successful run; the two expected messages are kept as class constants.
  """

  # Regex expected when a TPU strategy is given the subclassed model.
  _FUNCTIONAL_MODEL_REQUIRED_REGEX = (
      'Expected `model` argument to be a functional `Model` instance, '
      'but got a subclass model instead.')
  # Regex expected in the remaining non-eager, non-default-strategy cases.
  _UNSUPPORTED_MODEL_REGEX = (
      'We currently do not support distribution strategy with a '
      '`Sequential` model that is created without `input_shape`/'
      '`input_dim` set in its first layer or a subclassed model.')

  def get_model(self,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles a `SubclassedModel`.

    Args:
      initial_weights: Forwarded to `SubclassedModel`.
      distribution: Passed to `MaybeDistributionScope`, inside which the model
        is created and compiled.
      input_shapes: Forwarded to `SubclassedModel`.

    Returns:
      The compiled `SubclassedModel`.
    """
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      subclassed = SubclassedModel(initial_weights, input_shapes)

      subclassed.compile(
          loss=keras.losses.mean_squared_error,
          optimizer=gradient_descent_keras.SGD(0.05),
          metrics=['mse'])
      return subclassed

  @ds_combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations() +
      keras_correctness_test_base.multi_worker_mirrored_eager())
  def test_dnn_correctness(self, distribution, use_numpy, use_validation_data):
    # Eager execution or the default strategy: the run is expected to succeed.
    if (context.executing_eagerly()) or is_default_strategy(distribution):
      self.run_correctness_test(distribution, use_numpy, use_validation_data)
      return

    # Otherwise a ValueError is expected; the message depends on whether the
    # strategy is a TPU strategy.
    if K.is_tpu_strategy(distribution) and not context.executing_eagerly():
      expected_regex = self._FUNCTIONAL_MODEL_REQUIRED_REGEX
    else:
      expected_regex = self._UNSUPPORTED_MODEL_REGEX
    with self.assertRaisesRegex(ValueError, expected_regex):
      self.run_correctness_test(distribution, use_numpy, use_validation_data)

  @ds_combinations.generate(all_strategy_combinations_with_graph_mode())
  def test_dnn_with_dynamic_learning_rate(self, distribution):
    # Eager non-TPU execution or the default strategy: expected to succeed.
    if ((context.executing_eagerly() and not K.is_tpu_strategy(distribution)) or
        is_default_strategy(distribution)):
      self.run_dynamic_lr_test(distribution)
      return

    # Otherwise a ValueError is expected; the message depends on whether the
    # strategy is a TPU strategy.
    if K.is_tpu_strategy(distribution):
      expected_regex = self._FUNCTIONAL_MODEL_REQUIRED_REGEX
    else:
      expected_regex = self._UNSUPPORTED_MODEL_REGEX
    with self.assertRaisesRegex(ValueError, expected_regex):
      self.run_dynamic_lr_test(distribution)

  @ds_combinations.generate(
      keras_correctness_test_base.test_combinations_with_tpu_strategies_graph())
  def test_dnn_correctness_with_partial_last_batch_eval(self, distribution,
                                                        use_numpy,
                                                        use_validation_data):
    # Every combination generated here is expected to reject the subclassed
    # model with the "functional model required" error.
    with self.assertRaisesRegex(
        ValueError, self._FUNCTIONAL_MODEL_REQUIRED_REGEX):
      self.run_correctness_test(
          distribution,
          use_numpy,
          use_validation_data,
          partial_last_batch='eval')
if __name__ == '__main__':
  # Tests are launched through multi_process_runner's entry point (presumably
  # required by the multi-worker strategy combinations above -- verify).
  multi_process_runner.test_main()