[DataPipe] adding description, __len__, tests for mux() (#64224)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64224

cc VitalyFedyunin ejguan

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D30651551

Pulled By: NivekT

fbshipit-source-id: f8af98ba71a592900b992a8077432062ec57bb48
diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 4e37f41..24d0ce2 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -354,6 +354,15 @@
         n = n1.mux(n2, n3)
         self.assertEqual(list(range(10)), list(n))
 
+        # Test Case: Uneven DataPipes
+        source_numbers = list(range(0, 10)) + [10, 12]
+        numbers_dp = IDP(source_numbers)
+        n1, n2 = numbers_dp.demux(2, lambda x: x % 2)
+        self.assertEqual([0, 2, 4, 6, 8, 10, 12], list(n1))
+        self.assertEqual([1, 3, 5, 7, 9], list(n2))
+        n = n1.mux(n2)
+        self.assertEqual(source_numbers, list(n))
+
 
 class FileLoggerSimpleHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
     def __init__(self, *args, logfile=None, **kwargs):
@@ -1221,6 +1230,39 @@
                 map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum()
             )
 
+    def test_mux_datapipe(self):
+
+        # Test Case: Elements are yielded one at a time from each DataPipe, until they are all exhausted
+        input_dp1 = IDP(range(4))
+        input_dp2 = IDP(range(4, 8))
+        input_dp3 = IDP(range(8, 12))
+        output_dp = input_dp1.mux(input_dp2, input_dp3)
+        expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
+        self.assertEqual(len(expected_output), len(output_dp))
+        self.assertEqual(expected_output, list(output_dp))
+
+        # Test Case: Uneven input Data Pipes
+        input_dp1 = IDP([1, 2, 3, 4])
+        input_dp2 = IDP([10])
+        input_dp3 = IDP([100, 200, 300])
+        output_dp = input_dp1.mux(input_dp2, input_dp3)
+        expected_output = [1, 10, 100, 2, 200, 3, 300, 4]
+        self.assertEqual(len(expected_output), len(output_dp))
+        self.assertEqual(expected_output, list(output_dp))
+
+        # Test Case: Empty Data Pipe
+        input_dp1 = IDP([0, 1, 2, 3])
+        input_dp2 = IDP([])
+        output_dp = input_dp1.mux(input_dp2)
+        self.assertEqual(len(input_dp1), len(output_dp))
+        self.assertEqual(list(input_dp1), list(output_dp))
+
+        # Test Case: raises TypeError when __len__ is called and an input doesn't have __len__
+        input_dp1 = IDP(range(10))
+        input_dp_no_len = IDP_NoLen(range(10))
+        output_dp = input_dp1.mux(input_dp_no_len)
+        with self.assertRaises(TypeError):
+            len(output_dp)
 
 # Metaclass conflict for Python 3.6
 # Multiple inheritance with NamedTuple is not supported for Python 3.9
diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py
index a837c5b..ed1256f 100644
--- a/torch/utils/data/datapipes/iter/combining.py
+++ b/torch/utils/data/datapipes/iter/combining.py
@@ -1,7 +1,7 @@
 import warnings
 
 from torch.utils.data import IterDataPipe, functional_datapipe
-from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVar, Deque
+from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque
 from collections import deque
 
 T_co = TypeVar('T_co', covariant=True)
@@ -261,24 +261,41 @@
 
 @functional_datapipe('mux')
 class MultiplexerIterDataPipe(IterDataPipe):
+    r""" :class:`MultiplexerIterDataPipe`.
 
+        Iterable DataPipe that yields one element at a time from each input Iterable DataPipe
+        (i.e. one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
+        and so on). It skips over DataPipes that are exhausted, and ends when all input DataPipes are exhausted.
+
+        Args:
+            datapipes: Iterable DataPipes that will take turn to yield their elements, until they are all exhausted
+    """
     def __init__(self, *datapipes):
         self.datapipes = datapipes
+        self.length: Optional[int] = None
 
     def __iter__(self):
         iterators = [iter(x) for x in self.datapipes]
-        finished = {}
-        had_more = True
-        while had_more:
-            had_more = False
+        finished: Set[int] = set()
+        while len(finished) < len(iterators):
             for i in range(len(iterators)):
                 if i not in finished:
                     try:
-                        value = iterators[i].__next__()
-                        had_more = True
+                        value = next(iterators[i])
                         yield value
                     except StopIteration:
-                        finished[i] = 1
+                        finished.add(i)
+
+    def __len__(self):
+        if self.length is not None:
+            if self.length == -1:
+                raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            return self.length
+        if all(isinstance(dp, Sized) for dp in self.datapipes):
+            self.length = sum(len(dp) for dp in self.datapipes)
+        else:
+            self.length = -1
+        return len(self)
 
 
 @functional_datapipe('zip')