Make portpicking unique.
PiperOrigin-RevId: 448112092
diff --git a/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
index c91be2f..9c8b4e3 100644
--- a/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
+++ b/tensorflow/python/distribute/coordinator/cluster_coordinator_test.py
@@ -447,7 +447,7 @@
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer='grpc')
cluster_def['chief'] = [
- 'localhost:%d' % multi_worker_test_base.pick_unused_port()
+ 'localhost:%d' % test_util.pick_unused_port()
]
cluster_resolver = SimpleClusterResolver(
ClusterSpec(cluster_def), rpc_layer='grpc')
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index afedb5f..3a77dfe 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -24,13 +24,6 @@
import six
-_portpicker_import_error = None
-try:
- import portpicker # pylint: disable=g-import-not-at-top
-except ImportError as _error: # pylint: disable=invalid-name
- _portpicker_import_error = _error
- portpicker = None
-
# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
@@ -195,23 +188,20 @@
num_workers=1,
num_ps=0,
has_eval=False):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
-
cluster_spec = {}
if has_chief:
- cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
+ cluster_spec[CHIEF] = ["localhost:%s" % test_util.pick_unused_port()]
if num_workers:
cluster_spec[WORKER] = [
- "localhost:%s" % portpicker.pick_unused_port()
+ "localhost:%s" % test_util.pick_unused_port()
for _ in range(num_workers)
]
if num_ps:
cluster_spec[PS] = [
- "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
+ "localhost:%s" % test_util.pick_unused_port() for _ in range(num_ps)
]
if has_eval:
- cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
+ cluster_spec[EVALUATOR] = ["localhost:%s" % test_util.pick_unused_port()]
return cluster_spec
def _in_graph_worker_fn(self, strategy):
diff --git a/tensorflow/python/distribute/multi_worker_test_base.py b/tensorflow/python/distribute/multi_worker_test_base.py
index b55336d..9615a8e 100644
--- a/tensorflow/python/distribute/multi_worker_test_base.py
+++ b/tensorflow/python/distribute/multi_worker_test_base.py
@@ -25,12 +25,6 @@
import six
-_portpicker_import_error = None
-try:
- import portpicker # pylint: disable=g-import-not-at-top
-except (ImportError, ModuleNotFoundError) as _error: # pylint: disable=invalid-name
- _portpicker_import_error = _error
- portpicker = None
# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
@@ -56,28 +50,7 @@
original_run_std_server = dc._run_std_server # pylint: disable=protected-access
-
-ASSIGNED_PORTS = set()
-lock = threading.Lock()
-
-
-def pick_unused_port():
- """Returns an unused and unassigned local port."""
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
-
- global ASSIGNED_PORTS
- with lock:
- while True:
- try:
- port = portpicker.pick_unused_port()
- except portpicker.NoFreePortFoundError:
- raise unittest.SkipTest('Flakes in portpicker library do not represent '
- 'TensorFlow errors.')
- if port > 10000 and port not in ASSIGNED_PORTS:
- ASSIGNED_PORTS.add(port)
- logging.info('Using local port %r', port)
- return port
+pick_unused_port = test_util.pick_unused_port
def _create_cluster(num_workers,
@@ -92,8 +65,7 @@
ps_name='ps',
chief_name='chief'):
"""Creates and starts local servers and returns the cluster_spec dict."""
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
+
worker_ports = [pick_unused_port() for _ in range(num_workers)]
ps_ports = [pick_unused_port() for _ in range(num_ps)]
@@ -424,9 +396,6 @@
# {'evaluator': ['localhost:23381']}
```
"""
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
-
cluster_spec = {}
if has_chief:
cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 55442ff..388bd7e 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -3717,6 +3717,28 @@
return self._cached_session
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
+def pick_unused_port():
+ """Returns an unused and unassigned local port."""
+ import portpicker # pylint: disable=g-import-not-at-top
+
+ global ASSIGNED_PORTS
+ with lock:
+ while True:
+ try:
+ port = portpicker.pick_unused_port()
+ except portpicker.NoFreePortFoundError as porterror:
+ raise unittest.SkipTest("Flakes in portpicker library do not represent"
+ " TensorFlow errors.") from porterror
+ if port > 10000 and port not in ASSIGNED_PORTS:
+ ASSIGNED_PORTS.add(port)
+ logging.info("Using local port %r", port)
+ return port
+
+
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
num_ps,
@@ -3775,9 +3797,8 @@
Raises:
ImportError: if portpicker module was not found at load time
"""
- import portpicker # pylint: disable=g-import-not-at-top
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ worker_ports = [pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [pick_unused_port() for _ in range(num_ps)]
cluster_dict = {
"worker": ["localhost:%s" % port for port in worker_ports],
"ps": ["localhost:%s" % port for port in ps_ports]