TF API Fixit: Fork create_in_process_cluster into Keras to avoid using the private multi_worker_test_base API.
PiperOrigin-RevId: 367498737
Change-Id: I8579d267efce3be855729e9c0911e2917b16f676
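
The forked helper is intended as a drop-in replacement for the private
multi_worker_test_base.create_in_process_cluster at the Keras call sites
changed below. A minimal sketch of the intended usage, mirroring those call
sites; the helper and its signature come from this change, while the
standalone snippet and the specific num_workers/num_ps values are only
illustrative:

    from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
    from tensorflow.python.keras.distribute import multi_worker_testing_utils
    from tensorflow.python.training import server_lib

    # Bring up an in-process cluster with 2 workers and 1 parameter server,
    # then wrap the returned cluster_spec dict in a resolver, as the updated
    # tests below do.
    cluster_def = multi_worker_testing_utils.create_in_process_cluster(
        num_workers=2, num_ps=1, rpc_layer="grpc")
    resolver = SimpleClusterResolver(
        server_lib.ClusterSpec(cluster_def), rpc_layer="grpc")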
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 42ed89f..732badf 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -323,6 +323,7 @@
    ],
    srcs_version = "PY3",
    deps = [
+        ":multi_worker_testing_utils",
        ":optimizer_combinations",
        ":strategy_combinations",
        "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index f966ce9..2c17fbf 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -31,7 +31,6 @@
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import reduce_util
@@ -47,6 +46,7 @@
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
+from tensorflow.python.keras.distribute import multi_worker_testing_utils
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
@@ -2429,7 +2429,7 @@

  @ds_combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def test_unimplemented_parameter_server_strategy(self):
-    cluster_spec = multi_worker_test_base.create_in_process_cluster(
+    cluster_spec = multi_worker_testing_utils.create_in_process_cluster(
        num_workers=3, num_ps=2)
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=server_lib.ClusterSpec(cluster_spec),
diff --git a/tensorflow/python/keras/distribute/multi_worker_testing_utils.py b/tensorflow/python/keras/distribute/multi_worker_testing_utils.py
index 5a162e7..c406228 100644
--- a/tensorflow/python/keras/distribute/multi_worker_testing_utils.py
+++ b/tensorflow/python/keras/distribute/multi_worker_testing_utils.py
@@ -14,17 +14,34 @@
# ==============================================================================
"""Utilities for testing multi-worker distribution strategies with Keras."""
+import threading
+import unittest
+
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec
+_portpicker_import_error = None
+try:
+  import portpicker  # pylint: disable=g-import-not-at-top
+except (ImportError, ModuleNotFoundError) as _error:  # pylint: disable=invalid-name
+  _portpicker_import_error = _error
+  portpicker = None
+
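+# Ports handed out by pick_unused_port() in this process, guarded by `lock`
+# so servers created concurrently never receive the same port.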
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
def mnist_synthetic_dataset(batch_size, steps_per_epoch):
"""Generate synthetic MNIST dataset for testing."""
# train dataset
@@ -76,6 +93,151 @@
def make_parameter_server_cluster(num_workers, num_ps):
-  cluster_def = multi_worker_test_base.create_in_process_cluster(
+  cluster_def = create_in_process_cluster(
      num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
  return SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
+
+
+def pick_unused_port():
+  """Returns an unused and unassigned local port."""
+  if _portpicker_import_error:
+    raise _portpicker_import_error  # pylint: disable=raising-bad-type
+
+  global ASSIGNED_PORTS
+  with lock:
+    while True:
+      try:
+        port = portpicker.pick_unused_port()
+      except portpicker.NoFreePortFoundError:
+        raise unittest.SkipTest("Flakes in portpicker library do not represent "
+                                "TensorFlow errors.")
+      if port > 10000 and port not in ASSIGNED_PORTS:
+        ASSIGNED_PORTS.add(port)
+        logging.info("Using local port %r", port)
+        return port
+
+
+def _create_cluster(num_workers,
+                    num_ps,
+                    has_chief=False,
+                    has_eval=False,
+                    protocol="grpc",
+                    worker_config=None,
+                    ps_config=None,
+                    eval_config=None,
+                    worker_name="worker",
+                    ps_name="ps",
+                    chief_name="chief"):
+  """Creates and starts local servers and returns the cluster_spec dict."""
+  if _portpicker_import_error:
+    raise _portpicker_import_error  # pylint: disable=raising-bad-type
+  worker_ports = [pick_unused_port() for _ in range(num_workers)]
+  ps_ports = [pick_unused_port() for _ in range(num_ps)]
+
+  cluster_dict = {}
+  if num_workers > 0:
+    cluster_dict[worker_name] = ["localhost:%s" % port for port in worker_ports]
+  if num_ps > 0:
+    cluster_dict[ps_name] = ["localhost:%s" % port for port in ps_ports]
+  if has_eval:
+    cluster_dict["evaluator"] = ["localhost:%s" % pick_unused_port()]
+  if has_chief:
+    cluster_dict[chief_name] = ["localhost:%s" % pick_unused_port()]
+
+  cs = server_lib.ClusterSpec(cluster_dict)
+
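+  # Start a standard TF server for every task so the whole cluster is running
+  # inside this test process before any strategy connects to it.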
+  for i in range(num_workers):
+    server_lib.Server(
+        cs,
+        job_name=worker_name,
+        protocol=protocol,
+        task_index=i,
+        config=worker_config,
+        start=True)
+
+  for i in range(num_ps):
+    server_lib.Server(
+        cs,
+        job_name=ps_name,
+        protocol=protocol,
+        task_index=i,
+        config=ps_config,
+        start=True)
+
+  if has_chief:
+    server_lib.Server(
+        cs,
+        job_name=chief_name,
+        protocol=protocol,
+        task_index=0,
+        config=worker_config,
+        start=True)
+
+  if has_eval:
+    server_lib.Server(
+        cs,
+        job_name="evaluator",
+        protocol=protocol,
+        task_index=0,
+        config=eval_config,
+        start=True)
+
+  return cluster_dict
+
+
+def create_in_process_cluster(num_workers,
+                              num_ps,
+                              has_chief=False,
+                              has_eval=False,
+                              rpc_layer="grpc"):
+  """Creates an in-process cluster that consists of standard servers only."""
+  # Leave some GPU memory for the CUDA runtime.
+  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
+  worker_config = config_pb2.ConfigProto()
+  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+  # The cluster may hang if workers don't have enough inter_op threads. See
+  # b/172296720 for more details.
+  if worker_config.inter_op_parallelism_threads < num_workers + 1:
+    worker_config.inter_op_parallelism_threads = num_workers + 1
+
+  # Enable collective ops; this has no impact on non-collective ops.
+  if has_chief:
+    worker_config.experimental.collective_group_leader = (
+        "/job:chief/replica:0/task:0")
+  else:
+    worker_config.experimental.collective_group_leader = (
+        "/job:worker/replica:0/task:0")
+
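+  # Parameter servers run on CPU only; no GPU devices are made visible to them.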
+  ps_config = config_pb2.ConfigProto()
+  ps_config.device_count["GPU"] = 0
+
+  eval_config = config_pb2.ConfigProto()
+  eval_config.experimental.collective_group_leader = ""
+
+  # Create in-process servers. Once an in-process TensorFlow server is created,
+  # there is no way to terminate it, so we create one cluster per test process.
+  # We could have started each server in a separate process and killed that
+  # process to terminate the server, but we don't want multiple processes
+  # because:
+  # 1) it is more difficult to manage these processes;
+  # 2) there is something global in CUDA such that if we initialize CUDA in the
+  # parent process, the child process cannot initialize it again and thus cannot
+  # use GPUs (https://stackoverflow.com/questions/22950047).
+  cluster = None
+  try:
+    cluster = _create_cluster(
+        num_workers,
+        num_ps=num_ps,
+        has_chief=has_chief,
+        has_eval=has_eval,
+        worker_config=worker_config,
+        ps_config=ps_config,
+        eval_config=eval_config,
+        protocol=rpc_layer)
+  except errors.UnknownError as e:
+    if "Could not start gRPC server" in e.message:
+      raise unittest.SkipTest("Cannot start std servers.")
+    else:
+      raise
+  return cluster
diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD
index c1a5787..929badd 100644
--- a/tensorflow/python/keras/utils/BUILD
+++ b/tensorflow/python/keras/utils/BUILD
@@ -276,6 +276,7 @@
":dataset_creator",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/keras:combinations",
+ "//tensorflow/python/keras/distribute:multi_worker_testing_utils",
"//tensorflow/python/keras/engine",
"//tensorflow/python/keras/layers:core",
],
diff --git a/tensorflow/python/keras/utils/dataset_creator_test.py b/tensorflow/python/keras/utils/dataset_creator_test.py
index 84f33a6..9646344 100644
--- a/tensorflow/python/keras/utils/dataset_creator_test.py
+++ b/tensorflow/python/keras/utils/dataset_creator_test.py
@@ -19,11 +19,11 @@
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.coordinator import cluster_coordinator
from tensorflow.python.keras import combinations
+from tensorflow.python.keras.distribute import multi_worker_testing_utils
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core as core_layers
@@ -80,7 +80,7 @@
    self.assertLen(history.history["loss"], 10)

  def _get_parameter_server_strategy(self):
-    cluster_def = multi_worker_test_base.create_in_process_cluster(
+    cluster_def = multi_worker_testing_utils.create_in_process_cluster(
        num_workers=2, num_ps=1, rpc_layer="grpc")
    return parameter_server_strategy_v2.ParameterServerStrategyV2(
        SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))