blob: cc4e4b6bfd13bf7ccecbeb224c3d94d63291d0c0 [file] [log] [blame]
import torch.utils.data.sharding
def worker_init_fn(worker_id):
    """Apply sharding to the dataset held by the current DataLoader worker.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader``.
    It reads the current worker's info and forwards the worker's dataset
    (datapipe), the total worker count and this worker's id to
    ``torch.utils.data.sharding.apply_sharding``.

    Args:
        worker_id: Index of the worker process calling this function.

    Raises:
        RuntimeError: If called outside a DataLoader worker process. In that
            case ``get_worker_info()`` returns ``None`` and there is no
            dataset to shard.
    """
    info = torch.utils.data.get_worker_info()
    if info is None:
        # Without this guard the failure surfaces as an unhelpful
        # AttributeError on ``None.num_workers``.
        raise RuntimeError(
            "worker_init_fn must be called from within a DataLoader worker "
            "process; torch.utils.data.get_worker_info() returned None"
        )
    num_workers = info.num_workers
    datapipe = info.dataset
    torch.utils.data.sharding.apply_sharding(datapipe, num_workers, worker_id)