# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test to demonstrate custom training loop with ParameterServerStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
from absl import logging
import portpicker
import tensorflow as tf
NUM_EPOCHS = 10
NUM_STEPS = 100
STEPS_PER_EXECUTION = 10


class ParameterServerCustomTrainingLoopTest(tf.test.TestCase):
  """Test to demonstrate custom training loop with ParameterServerStrategy."""

  def create_in_process_cluster(self, num_workers, num_ps):
    """Creates and starts local servers and returns the cluster spec."""
    worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
    ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

    cluster_dict = {}
    cluster_dict["worker"] = ["localhost:%s" % port for port in worker_ports]
    if num_ps > 0:
      cluster_dict["ps"] = ["localhost:%s" % port for port in ps_ports]

    cluster_spec = tf.train.ClusterSpec(cluster_dict)

    # Workers need some inter_ops threads to work properly.
    worker_config = tf.compat.v1.ConfigProto()
    if multiprocessing.cpu_count() < num_workers + 1:
      worker_config.inter_op_parallelism_threads = num_workers + 1
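
    # Start an in-process gRPC server for each worker task; tf.distribute.Server
    # begins serving as soon as it is constructed.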
    for i in range(num_workers):
      tf.distribute.Server(
          cluster_spec,
          job_name="worker",
          task_index=i,
          config=worker_config,
          protocol="grpc")
    for i in range(num_ps):
      tf.distribute.Server(
          cluster_spec, job_name="ps", task_index=i, protocol="grpc")

    return cluster_spec

  def setUp(self):
    super(ParameterServerCustomTrainingLoopTest, self).setUp()

    cluster_spec = self.create_in_process_cluster(num_workers=3, num_ps=2)
    cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
        cluster_spec, rpc_layer="grpc")
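    # ParameterServerStrategy places variables on the "ps" tasks, and the
    # ClusterCoordinator dispatches training functions to the "worker" tasks.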
    self.strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    self.coordinator = (
        tf.distribute.experimental.coordinator.ClusterCoordinator(
            self.strategy))

  def testCustomTrainingLoop(self):
    coordinator, strategy = self.coordinator, self.strategy
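
    # Each worker builds its own dataset of random features and random labels;
    # the test only checks that the distributed loop runs end to end, not that
    # the model learns anything.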
    def per_worker_dataset_fn():

      def dataset_fn(_):
        return tf.data.Dataset.from_tensor_slices((tf.random.uniform(
            (6, 10)), tf.random.uniform((6, 10)))).batch(2).repeat()

      return strategy.distribute_datasets_from_function(dataset_fn)

    per_worker_dataset = coordinator.create_per_worker_dataset(
        per_worker_dataset_fn)
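
    # Variables (model weights, optimizer slots, metric state) must be created
    # under strategy.scope() so that ParameterServerStrategy places them on the
    # parameter servers.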
    with strategy.scope():
      model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
      optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001)
      train_accuracy = tf.keras.metrics.CategoricalAccuracy(
          name="train_accuracy")
    @tf.function
    def worker_train_fn(iterator):

      def replica_fn(inputs):
        """Training loop function."""
        batch_data, labels = inputs
        with tf.GradientTape() as tape:
          predictions = model(batch_data, training=True)
          loss = tf.keras.losses.CategoricalCrossentropy(
              reduction=tf.keras.losses.Reduction.NONE)(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_accuracy.update_state(labels, predictions)

      for _ in tf.range(STEPS_PER_EXECUTION):
        strategy.run(replica_fn, args=(next(iterator),))
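
    # coordinator.schedule() enqueues worker_train_fn and returns without
    # waiting for a result; coordinator.join() blocks until every scheduled
    # function has finished, so the values logged below are complete for the
    # epoch.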
    for epoch in range(NUM_EPOCHS):
      distributed_iterator = iter(per_worker_dataset)
      for step in range(0, NUM_STEPS, STEPS_PER_EXECUTION):
        coordinator.schedule(worker_train_fn, args=(distributed_iterator,))
        logging.info("Epoch %d, step %d scheduled.", epoch, step)

      logging.info("Now joining at epoch %d.", epoch)
      coordinator.join()
      logging.info(
          "Finished joining at epoch %d. Training accuracy: %f. "
          "Total iterations: %d", epoch, train_accuracy.result(),
          optimizer.iterations.value())
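
      # Reset the accuracy between epochs, keeping the value from the final
      # epoch available after training ends.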
      if epoch < NUM_EPOCHS - 1:
        train_accuracy.reset_states()
if __name__ == "__main__":
tf.__internal__.distribute.multi_process_runner.test_main()