[DataPipe] Optimize Grouper from N^2 to N (#68647)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68647

Fixes #68539

When all data from source datapipe depletes, there is no need to yield the biggest group in the buffer.

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D32562646

Pulled By: ejguan

fbshipit-source-id: ce91763656bc457e9c7d0af5861a5606c89965d5
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index 37dcc62..6323361 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -315,7 +315,7 @@
                 if result_to_yield is not None:
                     yield self.wrapper_class(result_to_yield)
 
-        while buffer_size:
-            (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
-            if result_to_yield is not None:
-                yield self.wrapper_class(result_to_yield)
+        for key in tuple(buffer_elements.keys()):
+            res = buffer_elements.pop(key)
+            buffer_size -= len(res)
+            yield self.wrapper_class(res)