Fix: Make `__len__` of datapipes dynamic (#88302)
Fixes #88074
Several datapipes cache their length the first time `__len__` is called. However, source datapipes might change in length afterwards (most prominently, whenever `apply_sharding` is called). This behaviour is counter-intuitive because we do not expect `__len__` to have side-effects.
This PR makes `__len__` dynamically computed.
Changes:
- Add note to the `datapipes` README that `__len__` should be dynamic and why.
- Remove caching of length computations in `ConcaterIterDataPipe`, `MultiplexerIterDataPipe`, `ZipperIterDataPipe`, `BatcherIterDataPipe`, `ConcaterMapDataPipe`, and `BatcherMapDataPipe`.
- This required removing the `length` entry from `__getstate__`/`__setstate__` of `MultiplexerIterDataPipe`. I am unsure whether to remove it completely and risk breaking saved checkpoints (as I did), or to instead keep accepting — but ignore — the `length` field of a loaded `state`.
- This also means the classes above no longer have a `length` attribute. I have found no uses of this, though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88302
Approved by: https://github.com/NivekT
diff --git a/torch/utils/data/datapipes/README.md b/torch/utils/data/datapipes/README.md
index 39a0a93..76dfd0e 100644
--- a/torch/utils/data/datapipes/README.md
+++ b/torch/utils/data/datapipes/README.md
@@ -38,6 +38,8 @@
### Length
In the most common cases, as the example of `MapperIterDataPipe` above, the `__len__` method of DataPipe should return the length of source DataPipe.
+Take care that `__len__` must be computed dynamically, because the length of source data-pipes might change after initialization (for example if sharding is applied).
+
```py
class MapperIterDataPipe(IterDataPipe):
...
diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py
index 088fcbf..178f043 100644
--- a/torch/utils/data/datapipes/iter/combining.py
+++ b/torch/utils/data/datapipes/iter/combining.py
@@ -39,7 +39,6 @@
[0, 1, 2, 0, 1, 2, 3, 4]
"""
datapipes: Tuple[IterDataPipe]
- length: Optional[int]
def __init__(self, *datapipes: IterDataPipe):
if len(datapipes) == 0:
@@ -47,7 +46,6 @@
if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
raise TypeError("Expected all inputs to be `IterDataPipe`")
self.datapipes = datapipes # type: ignore[assignment]
- self.length = None
def __iter__(self) -> Iterator:
for dp in self.datapipes:
@@ -55,15 +53,10 @@
yield data
def __len__(self) -> int:
- if self.length is not None:
- if self.length == -1:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
- return self.length
if all(isinstance(dp, Sized) for dp in self.datapipes):
- self.length = sum(len(dp) for dp in self.datapipes)
+ return sum(len(dp) for dp in self.datapipes)
else:
- self.length = -1
- return len(self)
+ raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
@functional_datapipe('fork')
@@ -519,7 +512,6 @@
"""
def __init__(self, *datapipes):
self.datapipes = datapipes
- self.length: Optional[int] = None
self.buffer: List = [] # Store values to be yielded only when every iterator provides one
def __iter__(self):
@@ -537,15 +529,10 @@
self.buffer.clear()
def __len__(self):
- if self.length is not None:
- if self.length == -1:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
- return self.length
if all(isinstance(dp, Sized) for dp in self.datapipes):
- self.length = min(len(dp) for dp in self.datapipes) * len(self.datapipes)
+ return min(len(dp) for dp in self.datapipes) * len(self.datapipes)
else:
- self.length = -1
- return len(self)
+ raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
def reset(self) -> None:
self.buffer = []
@@ -553,7 +540,6 @@
def __getstate__(self):
state = (
self.datapipes,
- self.length,
self._valid_iterator_id,
self._number_of_samples_yielded,
)
@@ -564,7 +550,6 @@
def __setstate__(self, state):
(
self.datapipes,
- self.length,
self._valid_iterator_id,
self._number_of_samples_yielded,
) = state
@@ -591,7 +576,6 @@
[(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
"""
datapipes: Tuple[IterDataPipe]
- length: Optional[int]
def __init__(self, *datapipes: IterDataPipe):
if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
@@ -599,7 +583,6 @@
"for `ZipIterDataPipe`.")
super().__init__()
self.datapipes = datapipes # type: ignore[assignment]
- self.length = None
def __iter__(self) -> Iterator[Tuple[T_co]]:
iterators = [iter(datapipe) for datapipe in self.datapipes]
@@ -607,12 +590,7 @@
yield data
def __len__(self) -> int:
- if self.length is not None:
- if self.length == -1:
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
- return self.length
if all(isinstance(dp, Sized) for dp in self.datapipes):
- self.length = min(len(dp) for dp in self.datapipes)
+ return min(len(dp) for dp in self.datapipes)
else:
- self.length = -1
- return len(self)
+ raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index 58d1509..61c7f1e 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -109,7 +109,6 @@
datapipe: IterDataPipe
batch_size: int
drop_last: bool
- length: Optional[int]
def __init__(self,
datapipe: IterDataPipe,
@@ -122,7 +121,6 @@
self.datapipe = datapipe
self.batch_size = batch_size
self.drop_last = drop_last
- self.length = None
self.wrapper_class = wrapper_class
def __iter__(self) -> Iterator[DataChunk]:
@@ -137,15 +135,13 @@
yield self.wrapper_class(batch)
def __len__(self) -> int:
- if self.length is not None:
- return self.length
if isinstance(self.datapipe, Sized):
if self.drop_last:
- self.length = len(self.datapipe) // self.batch_size
+ return len(self.datapipe) // self.batch_size
else:
- self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
- return self.length
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+ else:
+ raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
@functional_datapipe('unbatch')
diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py
index aa3dac0..9bf1880 100644
--- a/torch/utils/data/datapipes/map/combining.py
+++ b/torch/utils/data/datapipes/map/combining.py
@@ -30,7 +30,6 @@
[0, 1, 2, 0, 1, 2]
"""
datapipes: Tuple[MapDataPipe]
- length: int
def __init__(self, *datapipes: MapDataPipe):
if len(datapipes) == 0:
@@ -40,7 +39,6 @@
if not all(isinstance(dp, Sized) for dp in datapipes):
raise TypeError("Expected all inputs to be `Sized`")
self.datapipes = datapipes # type: ignore[assignment]
- self.length = -1
def __getitem__(self, index) -> T_co:
offset = 0
@@ -52,9 +50,7 @@
raise IndexError("Index {} is out of range.".format(index))
def __len__(self) -> int:
- if self.length == -1:
- self.length = sum(len(dp) for dp in self.datapipes)
- return self.length
+ return sum(len(dp) for dp in self.datapipes)
@functional_datapipe('zip')
@@ -76,7 +72,6 @@
[(0, 10), (1, 11), (2, 12)]
"""
datapipes: Tuple[MapDataPipe[T_co], ...]
- length: int
def __init__(self, *datapipes: MapDataPipe[T_co]) -> None:
if len(datapipes) == 0:
@@ -86,7 +81,6 @@
if not all(isinstance(dp, Sized) for dp in datapipes):
raise TypeError("Expected all inputs to be `Sized`")
self.datapipes = datapipes
- self.length = -1
def __getitem__(self, index) -> Tuple[T_co, ...]:
res = []
@@ -98,6 +92,4 @@
return tuple(res)
def __len__(self) -> int:
- if self.length == -1:
- self.length = min(len(dp) for dp in self.datapipes)
- return self.length
+ return min(len(dp) for dp in self.datapipes)
diff --git a/torch/utils/data/datapipes/map/grouping.py b/torch/utils/data/datapipes/map/grouping.py
index 443c088..da3cf56 100644
--- a/torch/utils/data/datapipes/map/grouping.py
+++ b/torch/utils/data/datapipes/map/grouping.py
@@ -1,6 +1,6 @@
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe, DataChunk
-from typing import List, Optional, Sized, TypeVar
+from typing import List, Sized, TypeVar
__all__ = ["BatcherMapDataPipe", ]
@@ -30,7 +30,6 @@
datapipe: MapDataPipe
batch_size: int
drop_last: bool
- length: Optional[int]
def __init__(self,
datapipe: MapDataPipe[T],
@@ -43,7 +42,6 @@
self.datapipe = datapipe
self.batch_size = batch_size
self.drop_last = drop_last
- self.length = None
self.wrapper_class = wrapper_class
def __getitem__(self, index) -> DataChunk:
@@ -60,12 +58,10 @@
raise IndexError(f"Index {index} is out of bound.") from e
def __len__(self) -> int:
- if self.length is not None:
- return self.length
if isinstance(self.datapipe, Sized):
if self.drop_last:
- self.length = len(self.datapipe) // self.batch_size
+ return len(self.datapipe) // self.batch_size
else:
- self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
- return self.length
- raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+ return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+ else:
+ raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))