|  | import math | 
|  | from typing import TypeVar, Optional, Iterator | 
|  |  | 
|  | import torch | 
|  | from . import Sampler, Dataset | 
|  | import torch.distributed as dist | 
|  |  | 
# Public API of this module: only the sampler class is exported by
# ``from <module> import *``.
__all__ = ["DistributedSampler"]

# Covariant type variable used to parameterize the ``Sampler`` base class and
# the iterator type returned by ``DistributedSampler.__iter__``.
T_co = TypeVar("T_co", covariant=True)
|  |  | 
|  |  | 
class DistributedSampler(Sampler[T_co]):
    r"""Sampler that gives each distributed process its own slice of a dataset.

    It is meant to be combined with
    :class:`torch.nn.parallel.DistributedDataParallel`: every process creates
    a :class:`~torch.utils.data.DistributedSampler`, hands it to its
    :class:`~torch.utils.data.DataLoader` as the ``sampler`` argument, and
    then loads only the rank-specific subset of the original dataset.

    .. note::
        The dataset is assumed to have a constant size, and every instance of
        it is assumed to return the same elements in the same order.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. When omitted, the world size is queried
            from :mod:`torch.distributed`.
        rank (int, optional): Rank of the current process within
            :attr:`num_replicas`. When omitted, the rank is queried from
            :mod:`torch.distributed`.
        shuffle (bool, optional): If ``True`` (default), the indices are
            shuffled before being split across replicas.
        seed (int, optional): Random seed used for shuffling when
            :attr:`shuffle=True`. It should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): If ``True``, the tail of the data is
            dropped so that it divides evenly across the replicas. If
            ``False``, extra (repeated) indices are appended so that it
            divides evenly. Default: ``False``.

    .. warning::
        In distributed mode, call :meth:`set_epoch` at the beginning of each
        epoch **before** creating the :class:`DataLoader` iterator. The
        shuffle is seeded with ``seed + epoch``, so without that call the
        same ordering is used for every epoch.

    Example::

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        """Validate the replica layout and precompute the per-rank sizes.

        Raises:
            RuntimeError: if ``num_replicas`` or ``rank`` is omitted and the
                distributed package is not available.
            ValueError: if ``rank`` lies outside ``[0, num_replicas - 1]``.
        """
        unavailable_msg = "Requires distributed package to be available"

        # Anything the caller left out is looked up through torch.distributed.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(unavailable_msg)
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(unavailable_msg)
            rank = dist.get_rank()

        if rank < 0 or rank >= num_replicas:
            highest_rank = num_replicas - 1
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, highest_rank))

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

        dataset_len = len(dataset)  # type: ignore[arg-type]
        if drop_last and dataset_len % num_replicas != 0:
            # Uneven split with drop_last: shrink to the nearest length that
            # divides evenly, so every rank receives the same amount of data.
            shareable_len = dataset_len - num_replicas
        else:
            # Either the split is already even (nothing to drop) or the
            # shortfall will be padded in __iter__.
            shareable_len = dataset_len
        self.num_samples = math.ceil(shareable_len / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __iter__(self) -> Iterator[T_co]:
        """Return an iterator over this rank's dataset indices."""
        dataset_len = len(self.dataset)  # type: ignore[arg-type]
        if not self.shuffle:
            order = list(range(dataset_len))
        else:
            # Seeding with seed + epoch makes the permutation deterministic
            # for a given (seed, epoch) pair.
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)
            order = torch.randperm(dataset_len, generator=rng).tolist()

        if self.drop_last:
            # Cut off the tail so the length divides evenly.
            order = order[:self.total_size]
        else:
            # Repeat indices from the front so the length divides evenly.
            shortfall = self.total_size - len(order)
            if shortfall > len(order):
                # One extra copy is not enough; tile the list as needed.
                repeats = math.ceil(shortfall / len(order))
                order.extend((order * repeats)[:shortfall])
            else:
                order.extend(order[:shortfall])
        assert len(order) == self.total_size

        # Each rank takes every num_replicas-th index, starting at its rank.
        per_rank = order[self.rank:self.total_size:self.num_replicas]
        assert len(per_rank) == self.num_samples

        return iter(per_rank)

    def __len__(self) -> int:
        """Return the number of indices yielded to this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Record the epoch number used by the next iteration.

        When :attr:`shuffle=True`, the epoch is added to the seed, so a new
        epoch number gives a new random ordering. If the epoch is not
        updated, the next iteration of this sampler yields the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch