# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that show that DistributionStrategy works with optimizer v2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def get_model():
  """Builds a small functional model: a 3-d input through one Dense(4) layer."""
  x = keras.layers.Input(shape=(3,), name='input')
  y = keras.layers.Dense(4, name='dense')(x)
  model = keras.Model(x, y)
  return model


class MirroredStrategyOptimizerV2Test(test.TestCase, parameterized.TestCase):

  @combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=['graph', 'eager']))
def testKerasOptimizerWithUnequalInput(self, distribution):
self.skipTest('b/130309197')
with distribution.scope():
var = variables.Variable(
2.0, name='var', aggregation=variable_scope.VariableAggregation.SUM)
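      # With aggregation=SUM, per-replica updates to this variable are summed
      # across replicas rather than averaged.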
optimizer = adam.Adam(learning_rate=0.01, beta_1=0.2, beta_2=0.2)
      all_vars = []

      def model_fn():

        def loss_fn():
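          # The per-replica loss is (replica_id + 1) * 0.5 * var, so the local
          # gradient w.r.t. var is 0.5 on replica 0 and 1.0 on replica 1; the
          # hand-computed expectations below assume a combined gradient of 1.5.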
replica_id = _replica_id()
return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * 0.5 * var
train_op = optimizer.minimize(loss_fn, var_list=[var])
        return train_op, optimizer

      def train_fn():
train_op, optimizer = distribution.extended.call_for_each_replica(
model_fn)
if not all_vars:
all_vars.append(var)
all_vars.append(optimizer.get_slot(var, 'm'))
all_vars.append(optimizer.get_slot(var, 'v'))
        return distribution.group(train_op)

      if not context.executing_eagerly():
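        # Build the graph-mode train op once and rebind train_fn to a session
        # callable, so each train_fn() call below runs a single optimizer step.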
with self.cached_session() as sess:
train_fn = sess.make_callable(train_fn())
      self.evaluate(variables.global_variables_initializer())

      # first step.
train_fn()
# var(1) = var(0) - lr * m(1) * sqrt(1 - beta2) / sqrt(v(1)) / (1 - beta1)
# = 2.0 - 0.01 * 1.2 * sqrt(0.8) / sqrt(1.8) / 0.8
self.assertAllClose(1.99, self.evaluate(all_vars[0]))
# m(1) = beta1 * m(0) + (1-beta1) * grad = 0.2 * 0 + 0.8 * (1 + 2) / 2
self.assertAllClose(1.2, self.evaluate(all_vars[1]))
# v(1) = beta2 * v(0) + (1-beta2) * grad^2 = 0.2 * 0 + 0.8 * 2.25
self.assertAllClose(1.8, self.evaluate(all_vars[2]))
# second step.
train_fn()
      # var(2) = var(1) - lr * m_hat(2) / sqrt(v_hat(2)) = 1.99 - 0.01 = 1.98
self.assertAllClose(1.98, self.evaluate(all_vars[0]))
# m(2) = beta1 * m(1) + (1-beta1) * grad = 0.2 * 1.2 + 0.8 * 1.5
self.assertAllClose(1.44, self.evaluate(all_vars[1]))
# v(2) = beta2 * v(1) + (1-beta2) * grad^2 = 0.2 * 1.8 + 0.8 * 2.25
      self.assertAllClose(2.16, self.evaluate(all_vars[2]))

  @combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=['graph', 'eager'],
run_distributed=[True, False]))
def testOptimizerWithKerasModelAndNumpyArrays(self, distribution,
run_distributed):
self.skipTest('b/130309197')
with self.cached_session():
with distribution.scope():
model = get_model()
optimizer = gradient_descent.SGD(0.001)
loss = 'mse'
metrics = ['mae']
model.compile(
optimizer, loss, metrics=metrics, run_distributed=run_distributed)
inputs = np.zeros((64, 3), dtype=np.float32)
targets = np.zeros((64, 4), dtype=np.float32)
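      # The data is all zeros, so the numerics are trivial; this test only
      # checks that compile/fit/evaluate/predict run end to end under the
      # distribution strategy.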
model.fit(
inputs,
targets,
epochs=1,
batch_size=2,
verbose=0,
validation_data=(inputs, targets))
model.evaluate(inputs, targets)
      model.predict(inputs)


def _replica_id():
  """Returns the current replica id as a tensor, in graph or eager mode."""
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
if not isinstance(replica_id, ops.Tensor):
replica_id = constant_op.constant(replica_id)
return replica_id
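

# Illustrative sketch, not part of the original test: it re-derives the
# hand-computed Adam expectations asserted in
# testKerasOptimizerWithUnequalInput, assuming the standard bias-corrected
# Adam update with negligible epsilon and a constant combined gradient of
# 1.5. The helper name _expected_adam_values is hypothetical.
def _expected_adam_values(num_steps, grad=1.5, var=2.0, lr=0.01, beta1=0.2,
                          beta2=0.2):
  m, v = 0.0, 0.0
  for t in range(1, num_steps + 1):
    m = beta1 * m + (1 - beta1) * grad     # m(1) = 1.2, m(2) = 1.44
    v = beta2 * v + (1 - beta2) * grad**2  # v(1) = 1.8, v(2) = 2.16
    m_hat = m / (1 - beta1**t)             # bias-corrected first moment
    v_hat = v / (1 - beta2**t)             # bias-corrected second moment
    var -= lr * m_hat / v_hat**0.5         # var(1) = 1.99, var(2) = 1.98
  return var, m, v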


if __name__ == '__main__':
test.main()