Describe the Sampler API for customization (#97338)
Added an explanation of sampler customization, with examples.
* Fixed the TypeVar
* Removed the unused __init__ from the Sampler class
* Added examples for a custom sampler and a custom batch sampler
* Fixed DistributedSampler typing
* Fixed _InfiniteConstantSampler
Fixes #92268
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97338
Approved by: https://github.com/NivekT
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index e29ece5..7ab90e5 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -13,7 +13,7 @@
import threading
import warnings
-from typing import Any, Callable, Iterable, TypeVar, Generic, Sequence, List, Optional, Union
+from typing import Any, Callable, Iterable, TypeVar, Generic, List, Optional, Union
import multiprocessing as python_multiprocessing
import torch
@@ -83,14 +83,8 @@
class _InfiniteConstantSampler(Sampler):
r"""Analogous to ``itertools.repeat(None, None)``.
Used as sampler for :class:`~torch.utils.data.IterableDataset`.
-
- Args:
- data_source (Dataset): dataset to sample from
"""
- def __init__(self):
- super().__init__(None)
-
def __iter__(self):
while True:
yield None
@@ -223,8 +217,8 @@
__initialized = False
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
- shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
- batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
+ shuffle: Optional[bool] = None, sampler: Union[Sampler[int], Iterable[int], None] = None,
+ batch_sampler: Union[Sampler[List[int]], Iterable[List[int]], None] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
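
With the narrowed hints above, anything passed as sampler is expected to yield int indices and anything passed as batch_sampler to yield List[int] batches. A minimal sketch (not part of this diff) using the existing SequentialSampler and BatchSampler, which already satisfy those types:

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.arange(10, dtype=torch.float32))

    # SequentialSampler yields int indices, matching Sampler[int]
    sampler = SequentialSampler(dataset)
    loader = DataLoader(dataset, sampler=sampler, batch_size=4)

    # BatchSampler yields List[int] batches, matching Sampler[List[int]];
    # when batch_sampler is given, batch_size/shuffle/sampler/drop_last stay at their defaults
    batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
    loader = DataLoader(dataset, batch_sampler=batch_sampler)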
diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py
index 8358204..4a2ba43 100644
--- a/torch/utils/data/distributed.py
+++ b/torch/utils/data/distributed.py
@@ -1,5 +1,5 @@
import math
-from typing import TypeVar, Optional, Iterator
+from typing import Optional, Iterator
import torch
from . import Sampler, Dataset
@@ -7,10 +7,8 @@
__all__ = ["DistributedSampler", ]
-T_co = TypeVar('T_co', covariant=True)
-
-class DistributedSampler(Sampler[T_co]):
+class DistributedSampler(Sampler[int]):
r"""Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with
@@ -94,7 +92,7 @@
self.shuffle = shuffle
self.seed = seed
- def __iter__(self) -> Iterator[T_co]:
+ def __iter__(self) -> Iterator[int]:
if self.shuffle:
# deterministically shuffle based on epoch and seed
g = torch.Generator()
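
Since DistributedSampler now derives from Sampler[int], it drops into DataLoader like any other per-index sampler. A rough usage sketch (a real setup would have an initialized process group; num_replicas and rank are passed explicitly here only for illustration):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(dataset, sampler=sampler, batch_size=8)

    for epoch in range(3):
        sampler.set_epoch(epoch)  # reseed the shuffle so each epoch sees a different ordering
        for (batch,) in loader:
            pass  # training step goes here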
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 7e2bbee..e8243d9 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -12,23 +12,57 @@
"WeightedRandomSampler",
]
-T_co = TypeVar('T_co', covariant=True)
+T_co = TypeVar('T_co', int, List[int], covariant=True)
class Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
- way to iterate over indices of dataset elements, and a :meth:`__len__` method
+ way to iterate over indices or lists of indices (batches) of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
+ Args:
+ data_source (Dataset): This argument is not used and will be removed in 2.2.0.
+            You may still have custom implementations that utilize it.
+
+ Example:
+ >>> # xdoctest: +SKIP
+        >>> class AscendingSequenceLengthSampler(Sampler[int]):
+ >>> def __init__(self, data: List[str]) -> None:
+ >>> self.data = data
+ >>>
+ >>> def __len__(self) -> int:
+ >>> return len(self.data)
+ >>>
+ >>> def __iter__(self) -> Iterator[int]:
+ >>> sizes = torch.tensor([len(x) for x in self.data])
+ >>> yield from torch.argsort(sizes).tolist()
+ >>>
+        >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
+ >>> def __init__(self, data: List[str], batch_size: int) -> None:
+ >>> self.data = data
+ >>> self.batch_size = batch_size
+ >>>
+ >>> def __len__(self) -> int:
+ >>> return (len(self.data) + self.batch_size - 1) // self.batch_size
+ >>>
+ >>> def __iter__(self) -> Iterator[List[int]]:
+ >>> sizes = torch.tensor([len(x) for x in self.data])
+ >>> for batch in torch.chunk(torch.argsort(sizes), len(self)):
+ >>> yield batch.tolist()
+
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
- def __init__(self, data_source: Optional[Sized]) -> None:
- pass
+ def __init__(self, data_source: Optional[Sized] = None) -> None:
+ if data_source is not None:
+ import warnings
+
+            warnings.warn("`data_source` argument is not used and will be removed in 2.2.0. "
+                          "You may still have custom implementations that utilize it.")
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
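
With data_source deprecated, subclasses no longer need to call super().__init__() with a dataset. A minimal sketch of the resulting pattern (RangeSampler is a made-up name for illustration, not part of this diff):

    from typing import Iterator

    from torch.utils.data import Sampler

    class RangeSampler(Sampler[int]):
        def __init__(self, n: int) -> None:
            # no super().__init__(data_source) call is needed any more
            self.n = n

        def __iter__(self) -> Iterator[int]:
            yield from range(self.n)

        def __len__(self) -> int:
            return self.n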