[Resubmit] Fix for incorrect usage of logging in torch/distributed/distributed_c10d.py (#52757)
Summary:
Resubmit of https://github.com/pytorch/pytorch/pull/51739
Fixes https://github.com/pytorch/pytorch/issues/51428
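The root cause: `distributed_c10d.py` logged through the module-level helpers (e.g. `logging.info(...)`), which go to the root logger and, when the root logger has no handlers configured yet, implicitly call `logging.basicConfig()`. That silently attaches a `StreamHandler` to `logging.root` and changes logging behavior for the entire user program. The fix switches the module to its own logger obtained via `logging.getLogger(__name__)`, which never installs handlers on the root logger. A minimal standalone sketch of the two patterns (illustrative only, not the actual library code):

```python
import logging

# Problematic pattern: the module-level helper logs on the root logger.
# If the root logger has no handlers yet, logging.info() implicitly calls
# logging.basicConfig(), adding a StreamHandler to logging.root as a side
# effect visible to the importing application.
logging.info("Added key to store")       # may mutate logging.root.handlers

# Fixed pattern: a logger scoped to this module. Records still propagate
# upward, but no handlers are ever installed on the root logger; output
# appears only once the application configures logging itself.
logger = logging.getLogger(__name__)
logger.info("Added key to store")        # leaves logging.root.handlers alone
```

The new `test_logging_init` test below pins this down by asserting that `logging.root.handlers` is unchanged across `c10d.init_process_group`.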
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52757
Reviewed By: cbalioglu
Differential Revision: D26646843
fbshipit-source-id: df4962ef86ea465307e39878860b9fbbcc958d52
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index c9afbc5..5c31b17 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -1,4 +1,5 @@
import copy
+import logging
import math
import operator
import os
@@ -595,6 +596,24 @@
        # check with get
        self.assertEqual(b"value0", store0.get("key0"))
+    @retry_on_connect_failures
+    def test_logging_init(self):
+        os.environ["WORLD_SIZE"] = "1"
+        os.environ["MASTER_ADDR"] = "127.0.0.1"
+        os.environ["MASTER_PORT"] = str(common.find_free_port())
+        os.environ["RANK"] = "0"
+
+        previous_handlers = logging.root.handlers
+
+        c10d.init_process_group(backend="gloo", init_method="env://")
+
+        current_handlers = logging.root.handlers
+        self.assertEqual(len(previous_handlers), len(current_handlers))
+        for current, previous in zip(current_handlers, previous_handlers):
+            self.assertEqual(current, previous)
+
+        c10d.destroy_process_group()
+
class RendezvousFileTest(TestCase):
    def test_common_errors(self):
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index f641b23..5c6555f 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -53,6 +53,10 @@
except ImportError:
    _GLOO_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
# Some reduce ops are not supported by complex numbers and will result in an error.
# We currently provide complex support to the distributed API by viewing
# complex tensors as real (torch.view_as_real), meaning that calling
@@ -188,7 +192,7 @@
"""
store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
store.add(store_key, 1)
- logging.info('Added key: {} to store for rank: {}'.format(store_key, rank))
+ logger.info('Added key: {} to store for rank: {}'.format(store_key, rank))
# Now wait for all workers to check in with the store.
world_size = get_world_size()
@@ -206,7 +210,7 @@
        # Print status periodically to keep track.
        if timedelta(seconds=(time.time() - log_time)) > timedelta(seconds=10):
-            logging.info(
+            logger.info(
                "Waiting in store based barrier to initialize process group for "
                "rank: {}, key: {} (world_size={}, worker_count={}, timeout={})".format(
                    rank, store_key, world_size, worker_count, timeout))