Reset worker cycle iterator for determinism across runs (#73675)
Summary:
`_worker_queue_idx_cycle`, the round-robin iterator that assigns index
requests to workers, was created once in
`_MultiProcessingDataLoaderIter.__init__`. With `persistent_workers=True`
the iterator object is reused across epochs, so the cycle kept its position
and each epoch could start dispatching at a different worker, making the
per-epoch batch order non-deterministic across epochs and runs. Recreate the
cycle in `_reset` so that every epoch starts dispatching at worker 0.
Fixes https://github.com/pytorch/pytorch/issues/73603
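The symptom is easiest to see with a dataset whose samples encode the id of the worker that served them. A minimal repro sketch (illustrative only; `WorkerIdDataset` is hypothetical and not part of this patch, and `persistent_workers=True` is assumed so that the iterator, and with it the old cycle state, is reused across epochs):

    # Hypothetical repro sketch, not part of the patch: each worker yields its
    # own id, so the printed batches show which worker served each batch.
    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class WorkerIdDataset(IterableDataset):
        def __init__(self, length):
            self.length = length

        def __iter__(self):
            info = torch.utils.data.get_worker_info()
            for _ in range(self.length // info.num_workers):
                yield info.id

    if __name__ == "__main__":
        loader = DataLoader(WorkerIdDataset(20), batch_size=10, num_workers=2,
                            persistent_workers=True)
        for epoch in range(3):
            # With this patch every epoch prints batches from workers [0, 1]
            # in that order; before it, the starting worker could drift.
            print(epoch, [batch.tolist() for batch in loader])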
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73675
Reviewed By: bdhirsh
Differential Revision: D34688704
Pulled By: ejguan
fbshipit-source-id: 7bab11f0b9f59645d9b168fa11d92dc7c2c4d34e
(cherry picked from commit eb5fd559224988f9967528e154cf37c5031fe7c2)
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 9a1e829..c00cebd 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -842,6 +842,21 @@
         return int(math.ceil(len(self.dataset) / float(self.batch_size)))


+class TestMultiEpochDataset(IterableDataset):
+    def __init__(self, length):
+        self.length = length
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        assert worker_info is not None
+        worker_id = worker_info.id
+        for idx in range(self.length // worker_info.num_workers):
+            yield worker_id
+
+    def __len__(self):
+        return self.length
+
+
 class CustomList(list):
     pass
@@ -1426,6 +1441,19 @@
         dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
         self.assertEqual(set(int(batch) for batch in get_dataloader()), set(int(batch) for batch in get_dataloader()))

+    def test_multi_epochs_reproducibility(self):
+        num_workers = 2
+        batch_size = 10
+        num_epochs = 3
+
+        dataset = TestMultiEpochDataset(batch_size * num_workers)
+        dataloader = self._get_data_loader(dataset, batch_size=batch_size,
+                                           shuffle=False, num_workers=num_workers)
+
+        for ind in range(num_epochs):
+            for batch_idx, sample in enumerate(dataloader):
+                self.assertEqual(sample.tolist(), [batch_idx % num_workers] * batch_size)
+
     def test_worker_init_fn(self):
         dataset = SeedDataset(4)
         dataloader = self._get_data_loader(dataset, batch_size=2, num_workers=2,
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index d105769..06adf84 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -895,7 +895,6 @@
             multiprocessing_context = loader.multiprocessing_context

         self._worker_init_fn = loader.worker_init_fn
-        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
         # No certainty which module multiprocessing_context is
         self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
         self._worker_pids_set = False
@@ -981,6 +980,8 @@
         # It does not mean that a worker is dead. In case of `_persistent_workers`,
         # the worker will be reset to available in the next epoch.
         self._workers_status = [True for i in range(self._num_workers)]
+        # Reset the worker queue cycle so it resumes next epoch at worker 0
+        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
         # We resume the prefetching in case it was enabled
         if not first_iter:
             for idx in range(self._num_workers):
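
For intuition, a standalone sketch (plain Python, independent of the patch) of why the cycle has to be recreated per epoch: `itertools.cycle` keeps its position across uses, so a cycle shared between epochs can hand the next epoch's first index request to a worker other than 0.

    import itertools

    num_workers = 2
    cycle = itertools.cycle(range(num_workers))

    # An epoch that advances the cycle an odd number of times leaves it
    # mid-rotation, so the next epoch starts dispatching at worker 1 ...
    epoch1 = [next(cycle) for _ in range(3)]  # [0, 1, 0]
    epoch2 = [next(cycle) for _ in range(3)]  # [1, 0, 1]

    # ... whereas recreating the cycle in `_reset` (this patch) pins every
    # epoch's first dispatch to worker 0.
    cycle = itertools.cycle(range(num_workers))
    epoch3 = [next(cycle) for _ in range(3)]  # [0, 1, 0]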