[DataPipe] Optimize Grouper from N^2 to N (#68647)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68647
Fixes #68539
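Once the source datapipe is exhausted, there is no need to search the buffer for the biggest group before each yield. Every remaining group is now popped and yielded exactly once, so draining the buffer takes a single pass over its keys instead of a full scan per group.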
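The difference between the two drain strategies can be sketched in isolation. This is a minimal illustration, not the code in grouping.py: `drain_biggest_first`, `drain_in_order`, and `make_buffer` are hypothetical helpers, and the first one is only a simplified stand-in for the `_remove_biggest_key` loop (it omits any size checks the real method may perform).

```python
from collections import defaultdict


def drain_biggest_first(buffer_elements):
    """Old strategy (simplified): rescan all keys to find the largest group each time."""
    while buffer_elements:
        # One full scan over the remaining keys per yielded group: O(N) each, O(N^2) total.
        biggest_key = max(buffer_elements, key=lambda k: len(buffer_elements[k]))
        yield buffer_elements.pop(biggest_key)


def drain_in_order(buffer_elements):
    """New strategy: pop every remaining group once, in dict insertion order."""
    # Snapshot the keys so the dict can be mutated while iterating.
    for key in tuple(buffer_elements.keys()):
        yield buffer_elements.pop(key)


def make_buffer():
    """Hypothetical leftover buffer, grouped by the first character of each item."""
    buf = defaultdict(list)
    for item in ["b1", "a1", "a2", "c1"]:
        buf[item[0]].append(item)
    return buf


old = list(drain_biggest_first(make_buffer()))
new = list(drain_in_order(make_buffer()))
```

Both strategies emit the same groups; only the order may differ, since the new one follows insertion order rather than group size.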
Test Plan: Imported from OSS
Reviewed By: jbschlosser
Differential Revision: D32562646
Pulled By: ejguan
fbshipit-source-id: ce91763656bc457e9c7d0af5861a5606c89965d5
diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py
index 37dcc62..6323361 100644
--- a/torch/utils/data/datapipes/iter/grouping.py
+++ b/torch/utils/data/datapipes/iter/grouping.py
@@ -315,7 +315,7 @@
                 if result_to_yield is not None:
                     yield self.wrapper_class(result_to_yield)
-        while buffer_size:
-            (result_to_yield, buffer_size) = self._remove_biggest_key(buffer_elements, buffer_size)
-            if result_to_yield is not None:
-                yield self.wrapper_class(result_to_yield)
+        for key in tuple(buffer_elements.keys()):
+            res = buffer_elements.pop(key)
+            buffer_size -= len(res)
+            yield self.wrapper_class(res)