[BE][PT-D] Fix race on checkpoint file (#84881)
Without a `dist.barrier()` before the file removal, rank 0 may run ahead and delete the checkpoint file before the nonzero ranks have loaded from it.
This PR adds a `dist.barrier()` to ensure all ranks can load the checkpoint before rank 0 deletes it.
For example, this test now includes the added `dist.barrier()`:
https://github.com/pytorch/pytorch/blob/037e8eefcf0b669430211b83d19aedf2185ed6fc/torch/testing/_internal/distributed/distributed_test.py#L5068-L5098
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84881
Approved by: https://github.com/rohan-varma
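The pattern being fixed can be sketched as follows. This is a minimal, hypothetical reproduction (the function name, file path, and single-process gloo setup are illustrative, not from the patched tests): rank 0 saves a checkpoint, every rank loads it, and only after a second barrier does rank 0 delete the file.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def checkpoint_roundtrip(rank: int) -> torch.Tensor:
    """Save on rank 0, load on all ranks, then delete on rank 0."""
    chkpt_file = os.path.join(tempfile.gettempdir(), "model.chkpt")
    tensor = torch.arange(4, dtype=torch.float32)
    if rank == 0:
        torch.save(tensor, chkpt_file)
    dist.barrier()  # all ranks wait until rank 0 has written the file
    loaded = torch.load(chkpt_file)
    dist.barrier()  # the fix: wait until every rank has loaded the file
    if rank == 0:
        os.remove(chkpt_file)  # now safe: no rank still needs the file
    return loaded


if __name__ == "__main__":
    # Single-process gloo group just to make the sketch runnable;
    # real tests get rank/world_size from the launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(checkpoint_roundtrip(dist.get_rank()).tolist())
    dist.destroy_process_group()
```

Without the second barrier, `os.remove` on rank 0 can race with `torch.load` on the other ranks, producing a flaky `FileNotFoundError` on slower ranks.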
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index cb05593..7f30129 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -5108,6 +5108,7 @@
self.assertEqual(averager2.step, 0)
+ dist.barrier()
if self.rank == 0:
os.remove(chkpt_file)
@@ -9026,6 +9027,7 @@
for orig_param, dummy_param in zip(ddp_model.parameters(), dummy_ddp_model.parameters()):
self.assertEqual(orig_param.grad, dummy_param.grad)
+ dist.barrier()
if rank == 0:
os.remove(chkpt_file)