Revert D28240105: [pytorch][PR] Fix DistributedSampler mem usage on large datasets
Test Plan: revert-hammer
Differential Revision: D28240105 (https://github.com/pytorch/pytorch/commit/a0ce8da26ee5d5b2842d4eacd94f6e26b610c777)
Original commit changeset: 4c6aa493d0f7
fbshipit-source-id: 8a0e17764c2f26c8316f88ad6c8772b08883ceee
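For context, the `__iter__` restored by this revert materializes the full index list, pads it (or truncates it, when `drop_last=True`) to `total_size = num_samples * num_replicas`, and then takes every `num_replicas`-th index starting at `rank`. The standalone sketch below reproduces that arithmetic outside the class; `shard_indices` and its parameters are illustrative names, not part of the patch.

import math

def shard_indices(dataset_len, num_replicas, rank, drop_last=False):
    # Mirrors the restored DistributedSampler bookkeeping (no shuffling here).
    if drop_last and dataset_len % num_replicas != 0:
        # Drop the tail so every rank gets the same number of samples.
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(dataset_len))
    if not drop_last:
        # Pad by repeating indices until the list length equals total_size.
        padding_size = total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
    else:
        # Truncate instead of padding.
        indices = indices[:total_size]
    assert len(indices) == total_size

    # Subsample: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    shard = indices[rank:total_size:num_replicas]
    assert len(shard) == num_samples
    return shard

# Example: 10 samples over 3 replicas with padding; rank 0 sees [0, 3, 6, 9].
assert shard_indices(10, 3, rank=0) == [0, 3, 6, 9]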
diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py
index 9b4fe75..7ef638a 100644
--- a/torch/utils/data/distributed.py
+++ b/torch/utils/data/distributed.py
@@ -1,12 +1,15 @@
 import math
-from typing import Optional, Iterator
+from typing import TypeVar, Optional, Iterator
 
 import torch
 from . import Sampler, Dataset
 import torch.distributed as dist
 
 
-class DistributedSampler(Sampler[int]):
+T_co = TypeVar('T_co', covariant=True)
+
+
+class DistributedSampler(Sampler[T_co]):
     r"""Sampler that restricts data loading to a subset of the dataset.
 
     It is especially useful in conjunction with
@@ -73,39 +76,49 @@
         self.rank = rank
         self.epoch = 0
         self.drop_last = drop_last
-        if self.drop_last:
-            self.num_samples = len(self.dataset) // self.num_replicas  # type: ignore[arg-type]
-        else:
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
             self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas  # type: ignore[arg-type]
+                # `type:ignore` is required because Dataset cannot provide a default __len__
+                # see NOTE in pytorch/torch/utils/data/sampler.py
+                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
             )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
         self.seed = seed
 
-    def __iter__(self) -> Iterator[int]:
+    def __iter__(self) -> Iterator[T_co]:
         if self.shuffle:
             # deterministically shuffle based on epoch and seed
             g = torch.Generator()
             g.manual_seed(self.seed + self.epoch)
-            indices = torch.randperm(len(self.dataset), generator=g)  # type: ignore[arg-type]
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
         else:
-            indices = range(len(self.dataset))  # type: ignore[arg-type]
-
-        offset = 0
-        while (offset + self.num_replicas) <= len(indices):
-            yield int(indices[offset + self.rank])
-            offset += self.num_replicas
+            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
 
         if not self.drop_last:
-            # find the number of samples remaining
-            num_rem = len(indices) % self.num_replicas
-            if num_rem:
-                if self.rank < num_rem:
-                    yield int(indices[offset + self.rank])
-                else:
-                    # wraparound, but mod in the case of self.rank >= len(indices)
-                    yield int(indices[self.rank % len(indices)])
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[:self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
 
     def __len__(self) -> int:
         return self.num_samples
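For reference, the restored sampler is consumed the same way as before the reverted change. The snippet below is a minimal usage sketch (the `TensorDataset` and the explicit `num_replicas`/`rank` values are placeholders; in real distributed training they are normally inferred from the initialized process group), showing `set_epoch`, which feeds the `seed + epoch` shuffle in `__iter__`.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 10 items; rank and world size are passed explicitly so this
# runs without torch.distributed being initialized.
dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
sampler = DistributedSampler(dataset, num_replicas=3, rank=0, shuffle=True, seed=0)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    # Re-seed the deterministic shuffle each epoch; otherwise every epoch
    # yields the same ordering.
    sampler.set_epoch(epoch)
    for batch in loader:
        pass  # training step would go here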