Reset worker cycle iterator for determinism across runs (#73675)

Summary:
Reset worker cycle iterator for determinism across runs

Fixes https://github.com/pytorch/pytorch/issues/73603
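
Before this change, `_worker_queue_idx_cycle` was created once in `_MultiProcessingDataLoaderIter.__init__`; with `persistent_workers=True` the same iterator object is reused across epochs, so the cycle's position carried over and the worker that received the first fetch task of each epoch drifted. Re-creating the cycle in `_reset()` makes every epoch dispatch starting from worker 0 again. A minimal standalone sketch of the underlying `itertools.cycle` behavior (variable names below are illustrative, not DataLoader internals):

    import itertools

    num_workers = 3
    epoch_batches = 4  # pretend each epoch dispatches 4 fetch tasks

    # Cycle created once (old behavior): its position carries over, so the
    # second epoch starts dispatching at worker 1 instead of worker 0.
    cycle_once = itertools.cycle(range(num_workers))
    for epoch in range(2):
        print([next(cycle_once) for _ in range(epoch_batches)])
    # [0, 1, 2, 0]
    # [1, 2, 0, 1]

    # Cycle re-created every epoch (what _reset now does): each epoch starts
    # dispatching at worker 0 again.
    for epoch in range(2):
        cycle_fresh = itertools.cycle(range(num_workers))
        print([next(cycle_fresh) for _ in range(epoch_batches)])
    # [0, 1, 2, 0]
    # [0, 1, 2, 0]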

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73675

Reviewed By: bdhirsh

Differential Revision: D34688704

Pulled By: ejguan

fbshipit-source-id: 7bab11f0b9f59645d9b168fa11d92dc7c2c4d34e
(cherry picked from commit eb5fd559224988f9967528e154cf37c5031fe7c2)
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 9a1e829..c00cebd 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -842,6 +842,21 @@
         return int(math.ceil(len(self.dataset) / float(self.batch_size)))
 
 
+class TestMultiEpochDataset(IterableDataset):
+    def __init__(self, length):
+        self.length = length
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        assert worker_info is not None
+        worker_id = worker_info.id
+        for _ in range(self.length // worker_info.num_workers):
+            yield worker_id
+
+    def __len__(self):
+        return self.length
+
+
 class CustomList(list):
     pass
 
@@ -1426,6 +1441,19 @@
         dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
         self.assertEqual(set(int(batch) for batch in get_dataloader()), set(int(batch) for batch in get_dataloader()))
 
+    def test_multi_epochs_reproducibility(self):
+        num_workers = 2
+        batch_size = 10
+        num_epochs = 3
+
+        dataset = TestMultiEpochDataset(batch_size * num_workers)
+        dataloader = self._get_data_loader(dataset, batch_size=batch_size,
+                                           shuffle=False, num_workers=num_workers)
+
+        for ind in range(num_epochs):
+            for batch_idx, sample in enumerate(dataloader):
+                self.assertEqual(sample.tolist(), [batch_idx % num_workers] * batch_size)
+
     def test_worker_init_fn(self):
         dataset = SeedDataset(4)
         dataloader = self._get_data_loader(dataset, batch_size=2, num_workers=2,
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index d105769..06adf84 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -895,7 +895,6 @@
             multiprocessing_context = loader.multiprocessing_context
 
         self._worker_init_fn = loader.worker_init_fn
-        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
         # No certainty which module multiprocessing_context is
         self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
         self._worker_pids_set = False
@@ -981,6 +980,8 @@
         # It does not mean that a worker is dead. In case of `_persistent_workers`,
         # the worker will be reset to available in the next epoch.
         self._workers_status = [True for i in range(self._num_workers)]
+        # Reset the worker queue cycle so it resumes next epoch at worker 0
+        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
         # We resume the prefetching in case it was enabled
         if not first_iter:
             for idx in range(self._num_workers):
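
For reference, a hedged user-level sketch of the behavior the new test exercises (the class and variable names here are illustrative, not part of the patch): with `persistent_workers=True` the same iterator state is reused across epochs, and the cycle reset above keeps the epoch-to-epoch worker dispatch order identical.

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class WorkerIdDataset(IterableDataset):
        """Each worker yields its own id for its share of the samples."""

        def __init__(self, length):
            self.length = length

        def __iter__(self):
            info = torch.utils.data.get_worker_info()
            assert info is not None
            for _ in range(self.length // info.num_workers):
                yield info.id

    if __name__ == "__main__":
        loader = DataLoader(WorkerIdDataset(20), batch_size=10,
                            num_workers=2, persistent_workers=True)
        for epoch in range(3):
            # Batch k is always fetched by worker k % num_workers, so every
            # epoch prints the same worker-id sequence: [0, 1].
            print(epoch, [int(batch[0]) for batch in loader])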