Revert D28240105: [pytorch][PR] Fix DistributedSampler mem usage on large datasets

Test Plan: revert-hammer

Differential Revision: D28240105 (https://github.com/pytorch/pytorch/commit/a0ce8da26ee5d5b2842d4eacd94f6e26b610c777)

Original commit changeset: 4c6aa493d0f7

fbshipit-source-id: 8a0e17764c2f26c8316f88ad6c8772b08883ceee

diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py
index 9b4fe75..7ef638a 100644
--- a/torch/utils/data/distributed.py
+++ b/torch/utils/data/distributed.py
@@ -1,12 +1,15 @@
 import math
-from typing import Optional, Iterator
+from typing import TypeVar, Optional, Iterator
 
 import torch
 from . import Sampler, Dataset
 import torch.distributed as dist
 
 
-class DistributedSampler(Sampler[int]):
+T_co = TypeVar('T_co', covariant=True)
+
+
+class DistributedSampler(Sampler[T_co]):
     r"""Sampler that restricts data loading to a subset of the dataset.
 
     It is especially useful in conjunction with
@@ -73,39 +76,49 @@
         self.rank = rank
         self.epoch = 0
         self.drop_last = drop_last
-        if self.drop_last:
-            self.num_samples = len(self.dataset) // self.num_replicas  # type: ignore[arg-type]
-        else:
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
             self.num_samples = math.ceil(
-                len(self.dataset) / self.num_replicas  # type: ignore[arg-type]
+                # `type:ignore` is required because Dataset cannot provide a default __len__
+                # see NOTE in pytorch/torch/utils/data/sampler.py
+                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
             )
+        else:
+            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
         self.seed = seed
 
-    def __iter__(self) -> Iterator[int]:
+    def __iter__(self) -> Iterator[T_co]:
         if self.shuffle:
             # deterministically shuffle based on epoch and seed
             g = torch.Generator()
             g.manual_seed(self.seed + self.epoch)
-            indices = torch.randperm(len(self.dataset), generator=g)  # type: ignore[arg-type]
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
         else:
-            indices = range(len(self.dataset))  # type: ignore[arg-type]
-
-        offset = 0
-        while (offset + self.num_replicas) <= len(indices):
-            yield int(indices[offset + self.rank])
-            offset += self.num_replicas
+            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
 
         if not self.drop_last:
-            # find the number of samples remaining
-            num_rem = len(indices) % self.num_replicas
-            if num_rem:
-                if self.rank < num_rem:
-                    yield int(indices[offset + self.rank])
-                else:
-                    # wraparound, but mod in the case of self.rank >= len(indices)
-                    yield int(indices[self.rank % len(indices)])
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[:self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
 
     def __len__(self) -> int:
         return self.num_samples
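
Note on the restored behavior: the removed lines above streamed indices with yield to avoid materializing a Python list; this revert goes back to building the full padded index list and taking a strided slice per rank. Below is a minimal, self-contained sketch of that arithmetic in plain Python (no torch needed); the dataset length, world size, and drop_last value are made-up example inputs.

    import math

    # Hypothetical example inputs; any dataset length / world size works.
    dataset_len = 10      # stands in for len(self.dataset)
    num_replicas = 4      # world size
    drop_last = False

    if drop_last and dataset_len % num_replicas != 0:
        # Drop the tail so the split is even across ranks.
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas      # 3 * 4 = 12 here

    indices = list(range(dataset_len))           # unshuffled for clarity
    if not drop_last:
        # Pad by repeating indices from the front until evenly divisible.
        padding_size = total_size - len(indices)
        if padding_size <= len(indices):
            indices += indices[:padding_size]
        else:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
    else:
        # Remove the tail of the data to make it evenly divisible.
        indices = indices[:total_size]
    assert len(indices) == total_size

    # Strided subsampling: every rank gets exactly num_samples indices.
    for rank in range(num_replicas):
        shard = indices[rank:total_size:num_replicas]
        assert len(shard) == num_samples
        print(rank, shard)
    # prints: 0 [0, 4, 8] / 1 [1, 5, 9] / 2 [2, 6, 0] / 3 [3, 7, 1]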
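
For completeness, a hedged usage sketch of the sampler this diff restores: the tiny TensorDataset and the explicit num_replicas/rank are placeholders (real DDP code would take them from torch.distributed), but the DataLoader wiring and the per-epoch set_epoch call follow the documented API.

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    # Placeholder dataset; in real DDP code num_replicas/rank would come from
    # dist.get_world_size() / dist.get_rank().
    dataset = TensorDataset(torch.arange(10))
    sampler = DistributedSampler(dataset, num_replicas=4, rank=0,
                                 shuffle=True, drop_last=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(2):
        # Without set_epoch, every epoch replays the same shuffled order.
        sampler.set_epoch(epoch)
        for (batch,) in loader:
            pass  # training step goes here

Because shuffling seeds the generator with self.seed + self.epoch, calling set_epoch each epoch is what changes the ordering across epochs while keeping it identical across ranks.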