TF API Fixit: Fork create_in_process_cluster into Keras to stop relying on the private multi_worker_test_base helper.

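For reference, a minimal sketch of how Keras tests are expected to use the
forked helper after this change (mirroring the updated call site in
dataset_creator_test.py; the strategy construction is shown only as an
illustration):

  from tensorflow.python.distribute import parameter_server_strategy_v2
  from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
  from tensorflow.python.keras.distribute import multi_worker_testing_utils
  from tensorflow.python.training.server_lib import ClusterSpec

  # Start an in-process cluster via the Keras-local fork instead of the
  # private multi_worker_test_base helper.
  cluster_def = multi_worker_testing_utils.create_in_process_cluster(
      num_workers=2, num_ps=1, rpc_layer="grpc")
  strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
      SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))
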
PiperOrigin-RevId: 367498737
Change-Id: I8579d267efce3be855729e9c0911e2917b16f676
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index 42ed89f..732badf 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -323,6 +323,7 @@
     ],
     srcs_version = "PY3",
     deps = [
+        ":multi_worker_testing_utils",
         ":optimizer_combinations",
         ":strategy_combinations",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index f966ce9..2c17fbf 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -31,7 +31,6 @@
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_process_runner
-from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute import reduce_util
@@ -47,6 +46,7 @@
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.distribute import distributed_training_utils
 from tensorflow.python.keras.distribute import distributed_training_utils_v1
+from tensorflow.python.keras.distribute import multi_worker_testing_utils
 from tensorflow.python.keras.distribute import optimizer_combinations
 from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
 from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
@@ -2429,7 +2429,7 @@
 
   @ds_combinations.generate(combinations.combine(mode=['graph', 'eager']))
   def test_unimplemented_parameter_server_strategy(self):
-    cluster_spec = multi_worker_test_base.create_in_process_cluster(
+    cluster_spec = multi_worker_testing_utils.create_in_process_cluster(
         num_workers=3, num_ps=2)
     cluster_resolver = SimpleClusterResolver(
         cluster_spec=server_lib.ClusterSpec(cluster_spec),
diff --git a/tensorflow/python/keras/distribute/multi_worker_testing_utils.py b/tensorflow/python/keras/distribute/multi_worker_testing_utils.py
index 5a162e7..c406228 100644
--- a/tensorflow/python/keras/distribute/multi_worker_testing_utils.py
+++ b/tensorflow/python/keras/distribute/multi_worker_testing_utils.py
@@ -14,17 +14,34 @@
 # ==============================================================================
 """Utilities for testing multi-worker distribution strategies with Keras."""
 
+import threading
+import unittest
+
+from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import server_lib
 from tensorflow.python.training.server_lib import ClusterSpec
 
 
+_portpicker_import_error = None
+try:
+  import portpicker  # pylint: disable=g-import-not-at-top
+except (ImportError, ModuleNotFoundError) as _error:  # pylint: disable=invalid-name
+  _portpicker_import_error = _error
+  portpicker = None
+
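+# Ports already handed out by pick_unused_port(); guarded by `lock` so the
+# same port is never returned twice within one test process.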
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
 def mnist_synthetic_dataset(batch_size, steps_per_epoch):
   """Generate synthetic MNIST dataset for testing."""
   # train dataset
@@ -76,6 +93,151 @@
 
 
 def make_parameter_server_cluster(num_workers, num_ps):
-  cluster_def = multi_worker_test_base.create_in_process_cluster(
+  cluster_def = create_in_process_cluster(
       num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
   return SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
+
+
+def pick_unused_port():
+  """Returns an unused and unassigned local port."""
+  if _portpicker_import_error:
+    raise _portpicker_import_error  # pylint: disable=raising-bad-type
+
+  global ASSIGNED_PORTS
+  with lock:
+    while True:
+      try:
+        port = portpicker.pick_unused_port()
+      except portpicker.NoFreePortFoundError:
+        raise unittest.SkipTest("Flakes in portpicker library do not represent "
+                                "TensorFlow errors.")
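+      # Skip low ports and ports this process has already handed out.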
+      if port > 10000 and port not in ASSIGNED_PORTS:
+        ASSIGNED_PORTS.add(port)
+        logging.info("Using local port %r", port)
+        return port
+
+
+def _create_cluster(num_workers,
+                    num_ps,
+                    has_chief=False,
+                    has_eval=False,
+                    protocol="grpc",
+                    worker_config=None,
+                    ps_config=None,
+                    eval_config=None,
+                    worker_name="worker",
+                    ps_name="ps",
+                    chief_name="chief"):
+  """Creates and starts local servers and returns the cluster_spec dict."""
+  if _portpicker_import_error:
+    raise _portpicker_import_error  # pylint: disable=raising-bad-type
+  worker_ports = [pick_unused_port() for _ in range(num_workers)]
+  ps_ports = [pick_unused_port() for _ in range(num_ps)]
+
+  cluster_dict = {}
+  if num_workers > 0:
+    cluster_dict[worker_name] = ["localhost:%s" % port for port in worker_ports]
+  if num_ps > 0:
+    cluster_dict[ps_name] = ["localhost:%s" % port for port in ps_ports]
+  if has_eval:
+    cluster_dict["evaluator"] = ["localhost:%s" % pick_unused_port()]
+  if has_chief:
+    cluster_dict[chief_name] = ["localhost:%s" % pick_unused_port()]
+
+  cs = server_lib.ClusterSpec(cluster_dict)
+
+  for i in range(num_workers):
+    server_lib.Server(
+        cs,
+        job_name=worker_name,
+        protocol=protocol,
+        task_index=i,
+        config=worker_config,
+        start=True)
+
+  for i in range(num_ps):
+    server_lib.Server(
+        cs,
+        job_name=ps_name,
+        protocol=protocol,
+        task_index=i,
+        config=ps_config,
+        start=True)
+
+  if has_chief:
+    server_lib.Server(
+        cs,
+        job_name=chief_name,
+        protocol=protocol,
+        task_index=0,
+        config=worker_config,
+        start=True)
+
+  if has_eval:
+    server_lib.Server(
+        cs,
+        job_name="evaluator",
+        protocol=protocol,
+        task_index=0,
+        config=eval_config,
+        start=True)
+
+  return cluster_dict
+
+
+def create_in_process_cluster(num_workers,
+                              num_ps,
+                              has_chief=False,
+                              has_eval=False,
+                              rpc_layer="grpc"):
+  """Create an in-process cluster that consists of only standard server."""
+  # Leave some memory for cuda runtime.
+  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
+  worker_config = config_pb2.ConfigProto()
+  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+  # The cluster may hang if workers don't have enough inter_op threads. See
+  # b/172296720 for more details.
+  if worker_config.inter_op_parallelism_threads < num_workers + 1:
+    worker_config.inter_op_parallelism_threads = num_workers + 1
+
+  # Enable collective ops, which have no impact on non-collective ops.
+  if has_chief:
+    worker_config.experimental.collective_group_leader = (
+        "/job:chief/replica:0/task:0")
+  else:
+    worker_config.experimental.collective_group_leader = (
+        "/job:worker/replica:0/task:0")
+
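+  # Parameter server tasks run on CPU only.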
+  ps_config = config_pb2.ConfigProto()
+  ps_config.device_count["GPU"] = 0
+
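+  # The evaluator does not join the workers' collective group.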
+  eval_config = config_pb2.ConfigProto()
+  eval_config.experimental.collective_group_leader = ""
+
+  # Create in-process servers. Once an in-process TensorFlow server is created,
+  # there is no way to terminate it, so we create one cluster per test process.
+  # We could instead start each server in a separate process and kill that
+  # process to terminate the server, but we don't want multiple processes
+  # because:
+  # 1) they are more difficult to manage;
+  # 2) CUDA keeps global state, so if the parent process initializes CUDA, a
+  # child process cannot initialize it again and thus cannot use GPUs
+  # (https://stackoverflow.com/questions/22950047).
+  cluster = None
+  try:
+    cluster = _create_cluster(
+        num_workers,
+        num_ps=num_ps,
+        has_chief=has_chief,
+        has_eval=has_eval,
+        worker_config=worker_config,
+        ps_config=ps_config,
+        eval_config=eval_config,
+        protocol=rpc_layer)
+  except errors.UnknownError as e:
+    if "Could not start gRPC server" in e.message:
+      raise unittest.SkipTest("Cannot start std servers.")
+    else:
+      raise
+  return cluster
diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD
index c1a5787..929badd 100644
--- a/tensorflow/python/keras/utils/BUILD
+++ b/tensorflow/python/keras/utils/BUILD
@@ -276,6 +276,7 @@
         ":dataset_creator",
         "//tensorflow/python/distribute:multi_worker_test_base",
         "//tensorflow/python/keras:combinations",
+        "//tensorflow/python/keras/distribute:multi_worker_testing_utils",
         "//tensorflow/python/keras/engine",
         "//tensorflow/python/keras/layers:core",
     ],
diff --git a/tensorflow/python/keras/utils/dataset_creator_test.py b/tensorflow/python/keras/utils/dataset_creator_test.py
index 84f33a6..9646344 100644
--- a/tensorflow/python/keras/utils/dataset_creator_test.py
+++ b/tensorflow/python/keras/utils/dataset_creator_test.py
@@ -19,11 +19,11 @@
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.distribute.coordinator import cluster_coordinator
 from tensorflow.python.keras import combinations
+from tensorflow.python.keras.distribute import multi_worker_testing_utils
 from tensorflow.python.keras.engine import data_adapter
 from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.layers import core as core_layers
@@ -80,7 +80,7 @@
     self.assertLen(history.history["loss"], 10)
 
   def _get_parameter_server_strategy(self):
-    cluster_def = multi_worker_test_base.create_in_process_cluster(
+    cluster_def = multi_worker_testing_utils.create_in_process_cluster(
         num_workers=2, num_ps=1, rpc_layer="grpc")
     return parameter_server_strategy_v2.ParameterServerStrategyV2(
         SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc"))