Fix: Make `__len__` of datapipes dynamic (#88302)

Fixes #88074

Several datapipes have their lengths cached on being executed for the first time. However, source datapipes might change in length (most prominently, whenever `apply_sharding` is called). The behaviour is counter-intuitive because we do not expect `__len__` to have side-effects.

This PR makes `__len__` dynamically computed.

Changes:
- Add note to the `datapipes` README that `__len__` should be dynamic and why.
- Remove caching of length computations in `ConcaterIterDataPipe`, `MultiplexerIterDataPipe`, `ZipperIterDataPipe`, `BatcherIterDataPipe`, `ConcaterMapDataPipe`, and `BatcherMapDataPipe`.
- This required removal of the `length` attribute in setstate/getstate of `MultiplexerIterDataPipe`. I am unsure whether to remove this completely and risk breaking saved checkpoints (as I did) or whether to just ignore the `length` of the loaded `state`.
- This also means the classes above no longer have a `length` attribute. I have found no uses of this, though.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88302
Approved by: https://github.com/NivekT
diff --git a/torch/utils/data/datapipes/README.md b/torch/utils/data/datapipes/README.md
index 39a0a93..76dfd0e 100644
--- a/torch/utils/data/datapipes/README.md
+++ b/torch/utils/data/datapipes/README.md
@@ -38,6 +38,8 @@
 
 ### Length
 In the most common cases, as the example of `MapperIterDataPipe` above, the `__len__` method of DataPipe should return the length of source DataPipe.
+Take care that `__len__` must be computed dynamically, because the length of source data-pipes might change after initialization (for example if sharding is applied).
+
 ```py
 class MapperIterDataPipe(IterDataPipe):
     ...
diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py
index 088fcbf..178f043 100644
--- a/torch/utils/data/datapipes/iter/combining.py
+++ b/torch/utils/data/datapipes/iter/combining.py
@@ -39,7 +39,6 @@
         [0, 1, 2, 0, 1, 2, 3, 4]
     """
     datapipes: Tuple[IterDataPipe]
-    length: Optional[int]
 
     def __init__(self, *datapipes: IterDataPipe):
         if len(datapipes) == 0:
@@ -47,7 +46,6 @@
         if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
             raise TypeError("Expected all inputs to be `IterDataPipe`")
         self.datapipes = datapipes  # type: ignore[assignment]
-        self.length = None
 
     def __iter__(self) -> Iterator:
         for dp in self.datapipes:
@@ -55,15 +53,10 @@
                 yield data
 
     def __len__(self) -> int:
-        if self.length is not None:
-            if self.length == -1:
-                raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
-            return self.length
         if all(isinstance(dp, Sized) for dp in self.datapipes):
-            self.length = sum(len(dp) for dp in self.datapipes)
+            return sum(len(dp) for dp in self.datapipes)
         else:
-            self.length = -1
-        return len(self)
+            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
 
 
 @functional_datapipe('fork')
@@ -519,7 +512,6 @@
     """
     def __init__(self, *datapipes):
         self.datapipes = datapipes
-        self.length: Optional[int] = None
         self.buffer: List = []  # Store values to be yielded only when every iterator provides one
 
     def __iter__(self):
@@ -537,15 +529,10 @@
             self.buffer.clear()
 
     def __len__(self):
-        if self.length is not None:
-            if self.length == -1:
-                raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
-            return self.length
         if all(isinstance(dp, Sized) for dp in self.datapipes):
-            self.length = min(len(dp) for dp in self.datapipes) * len(self.datapipes)
+            return min(len(dp) for dp in self.datapipes) * len(self.datapipes)
         else:
-            self.length = -1
-        return len(self)
+            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
 
     def reset(self) -> None:
         self.buffer = []
@@ -553,7 +540,6 @@
     def __getstate__(self):
         state = (
             self.datapipes,
-            self.length,
             self._valid_iterator_id,
             self._number_of_samples_yielded,
         )
@@ -564,7 +550,6 @@
     def __setstate__(self, state):
         (
             self.datapipes,
-            self.length,
             self._valid_iterator_id,
             self._number_of_samples_yielded,
         ) = state
@@ -591,7 +576,6 @@
         [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
     """
     datapipes: Tuple[IterDataPipe]
-    length: Optional[int]
 
     def __init__(self, *datapipes: IterDataPipe):
         if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
@@ -599,7 +583,6 @@
                             "for `ZipIterDataPipe`.")
         super().__init__()
         self.datapipes = datapipes  # type: ignore[assignment]
-        self.length = None
 
     def __iter__(self) -> Iterator[Tuple[T_co]]:
         iterators = [iter(datapipe) for datapipe in self.datapipes]
@@ -607,12 +590,7 @@
             yield data
 
     def __len__(self) -> int:
-        if self.length is not None:
-            if self.length == -1:
-                raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
-            return self.length
         if all(isinstance(dp, Sized) for dp in self.datapipes):
-            self.length = min(len(dp) for dp in self.datapipes)
+            return min(len(dp) for dp in self.datapipes)
         else:
-            self.length = -1
-        return len(self)
+            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index 58d1509..61c7f1e 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -109,7 +109,6 @@
     datapipe: IterDataPipe
     batch_size: int
     drop_last: bool
-    length: Optional[int]
 
     def __init__(self,
                  datapipe: IterDataPipe,
@@ -122,7 +121,6 @@
         self.datapipe = datapipe
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.length = None
         self.wrapper_class = wrapper_class
 
     def __iter__(self) -> Iterator[DataChunk]:
@@ -137,15 +135,13 @@
                 yield self.wrapper_class(batch)
 
     def __len__(self) -> int:
-        if self.length is not None:
-            return self.length
         if isinstance(self.datapipe, Sized):
             if self.drop_last:
-                self.length = len(self.datapipe) // self.batch_size
+                return len(self.datapipe) // self.batch_size
             else:
-                self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
-            return self.length
-        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+        else:
+            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
 
 
 @functional_datapipe('unbatch')
diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py
index aa3dac0..9bf1880 100644
--- a/torch/utils/data/datapipes/map/combining.py
+++ b/torch/utils/data/datapipes/map/combining.py
@@ -30,7 +30,6 @@
         [0, 1, 2, 0, 1, 2]
     """
     datapipes: Tuple[MapDataPipe]
-    length: int
 
     def __init__(self, *datapipes: MapDataPipe):
         if len(datapipes) == 0:
@@ -40,7 +39,6 @@
         if not all(isinstance(dp, Sized) for dp in datapipes):
             raise TypeError("Expected all inputs to be `Sized`")
         self.datapipes = datapipes  # type: ignore[assignment]
-        self.length = -1
 
     def __getitem__(self, index) -> T_co:
         offset = 0
@@ -52,9 +50,7 @@
         raise IndexError("Index {} is out of range.".format(index))
 
     def __len__(self) -> int:
-        if self.length == -1:
-            self.length = sum(len(dp) for dp in self.datapipes)
-        return self.length
+        return sum(len(dp) for dp in self.datapipes)
 
 
 @functional_datapipe('zip')
@@ -76,7 +72,6 @@
         [(0, 10), (1, 11), (2, 12)]
     """
     datapipes: Tuple[MapDataPipe[T_co], ...]
-    length: int
 
     def __init__(self, *datapipes: MapDataPipe[T_co]) -> None:
         if len(datapipes) == 0:
@@ -86,7 +81,6 @@
         if not all(isinstance(dp, Sized) for dp in datapipes):
             raise TypeError("Expected all inputs to be `Sized`")
         self.datapipes = datapipes
-        self.length = -1
 
     def __getitem__(self, index) -> Tuple[T_co, ...]:
         res = []
@@ -98,6 +92,4 @@
         return tuple(res)
 
     def __len__(self) -> int:
-        if self.length == -1:
-            self.length = min(len(dp) for dp in self.datapipes)
-        return self.length
+        return min(len(dp) for dp in self.datapipes)
diff --git a/torch/utils/data/datapipes/map/grouping.py b/torch/utils/data/datapipes/map/grouping.py
index 443c088..da3cf56 100644
--- a/torch/utils/data/datapipes/map/grouping.py
+++ b/torch/utils/data/datapipes/map/grouping.py
@@ -1,6 +1,6 @@
 from torch.utils.data.datapipes._decorator import functional_datapipe
 from torch.utils.data.datapipes.datapipe import MapDataPipe, DataChunk
-from typing import List, Optional, Sized, TypeVar
+from typing import List, Sized, TypeVar
 
 __all__ = ["BatcherMapDataPipe", ]
 
@@ -30,7 +30,6 @@
     datapipe: MapDataPipe
     batch_size: int
     drop_last: bool
-    length: Optional[int]
 
     def __init__(self,
                  datapipe: MapDataPipe[T],
@@ -43,7 +42,6 @@
         self.datapipe = datapipe
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.length = None
         self.wrapper_class = wrapper_class
 
     def __getitem__(self, index) -> DataChunk:
@@ -60,12 +58,10 @@
                 raise IndexError(f"Index {index} is out of bound.") from e
 
     def __len__(self) -> int:
-        if self.length is not None:
-            return self.length
         if isinstance(self.datapipe, Sized):
             if self.drop_last:
-                self.length = len(self.datapipe) // self.batch_size
+                return len(self.datapipe) // self.batch_size
             else:
-                self.length = (len(self.datapipe) + self.batch_size - 1) // self.batch_size
-            return self.length
-        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
+        else:
+            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))