Sampler API described for customization. (#97338)

Explanation with examples of sampler customization added.

* fixed TypeVar
* removed unused init from Sampler class
* added examples for custom sampler and batch sampler
* Distributed sampler typing fixed.
* _InfiniteConstantSampler fixed

Fixes #92268

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97338
Approved by: https://github.com/NivekT
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index e29ece5..7ab90e5 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -13,7 +13,7 @@
 import threading
 import warnings
 
-from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional, Union
+from typing import Any, Callable, Iterable, TypeVar, Generic, List, Optional, Union
 
 import multiprocessing as python_multiprocessing
 import torch
@@ -83,14 +83,8 @@
 class _InfiniteConstantSampler(Sampler):
     r"""Analogous to ``itertools.repeat(None, None)``.
     Used as sampler for :class:`~torch.utils.data.IterableDataset`.
-
-    Args:
-        data_source (Dataset): dataset to sample from
     """
 
-    def __init__(self):
-        super().__init__(None)
-
     def __iter__(self):
         while True:
             yield None
@@ -223,8 +217,8 @@
     __initialized = False
 
     def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
-                 shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
-                 batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
+                 shuffle: Optional[bool] = None, sampler: Union[Sampler[int], Iterable[int], None] = None,
+                 batch_sampler: Union[Sampler[List[int]], Iterable[List[int]], None] = None,
                  num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                  pin_memory: bool = False, drop_last: bool = False,
                  timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py
index 8358204..4a2ba43 100644
--- a/torch/utils/data/distributed.py
+++ b/torch/utils/data/distributed.py
@@ -1,5 +1,5 @@
 import math
-from typing import TypeVar, Optional, Iterator
+from typing import Optional, Iterator
 
 import torch
 from . import Sampler, Dataset
@@ -7,10 +7,8 @@
 
 __all__ = ["DistributedSampler", ]
 
-T_co = TypeVar('T_co', covariant=True)
 
-
-class DistributedSampler(Sampler[T_co]):
+class DistributedSampler(Sampler[int]):
     r"""Sampler that restricts data loading to a subset of the dataset.
 
     It is especially useful in conjunction with
@@ -94,7 +92,7 @@
         self.shuffle = shuffle
         self.seed = seed
 
-    def __iter__(self) -> Iterator[T_co]:
+    def __iter__(self) -> Iterator[int]:
         if self.shuffle:
             # deterministically shuffle based on epoch and seed
             g = torch.Generator()
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 7e2bbee..e8243d9 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -12,23 +12,57 @@
     "WeightedRandomSampler",
 ]
 
-T_co = TypeVar('T_co', covariant=True)
+T_co = TypeVar('T_co', int, List[int], covariant=True)
 
 
 class Sampler(Generic[T_co]):
     r"""Base class for all Samplers.
 
     Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
-    way to iterate over indices of dataset elements, and a :meth:`__len__` method
+    way to iterate over indices or lists of indices (batches) of dataset elements, and a :meth:`__len__` method
     that returns the length of the returned iterators.
 
+    Args:
+        data_source (Dataset): This argument is not used and will be removed in 2.2.0.
+            You may still have custom implementation that utilizes it.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> class AccedingSequenceLengthSampler(Sampler[int]):
+        >>>     def __init__(self, data: List[str]) -> None:
+        >>>         self.data = data
+        >>>
+        >>>     def __len__(self) -> int:
+        >>>         return len(self.data)
+        >>>
+        >>>     def __iter__(self) -> Iterator[int]:
+        >>>         sizes = torch.tensor([len(x) for x in self.data])
+        >>>         yield from torch.argsort(sizes).tolist()
+        >>>
+        >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
+        >>>     def __init__(self, data: List[str], batch_size: int) -> None:
+        >>>         self.data = data
+        >>>         self.batch_size = batch_size
+        >>>
+        >>>     def __len__(self) -> int:
+        >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
+        >>>
+        >>>     def __iter__(self) -> Iterator[List[int]]:
+        >>>         sizes = torch.tensor([len(x) for x in self.data])
+        >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
+        >>>             yield batch.tolist()
+
     .. note:: The :meth:`__len__` method isn't strictly required by
               :class:`~torch.utils.data.DataLoader`, but is expected in any
               calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
     """
 
-    def __init__(self, data_source: Optional[Sized]) -> None:
-        pass
+    def __init__(self, data_source: Optional[Sized] = None) -> None:
+        if data_source is not None:
+            import warnings
+
+            warnings.warn("`data_source` argument is not used and will be removed in 2.2.0."
+                          "You may still have custom implementation that utilizes it.")
 
     def __iter__(self) -> Iterator[T_co]:
         raise NotImplementedError