Remove default parameter of ShufflerIterDataPipe (#74370)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74370
Closes https://github.com/pytorch/data/issues/298. This PR:
- removes the `default` parameter of `ShufflerIterDataPipe`
- renames `set_shuffle_setting()` into `set_shuffle()`
- let `set_shuffle()` return `self`.
Test Plan: Imported from OSS
Reviewed By: george-qi
Differential Revision: D35073666
Pulled By: NicolasHug
fbshipit-source-id: 9847b037e70f44f36eaf4471f2c12fa8ec2ed73c
(cherry picked from commit b07ab646f308532886e8daddd57e937a53edb153)
diff --git a/test/test_datapipe.py b/test/test_datapipe.py
index 0b7f093..000b044 100644
--- a/test/test_datapipe.py
+++ b/test/test_datapipe.py
@@ -1407,6 +1407,10 @@
with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"):
len(shuffle_dp_nl)
+ # Test: deactivate shuffling via set_shuffle
+ unshuffled_dp = input_ds.shuffle().set_shuffle(False)
+ self.assertEqual(list(unshuffled_dp), list(input_ds))
+
def test_zip_iterdatapipe(self):
# Functional Test: raises TypeError when an input is not of type `IterDataPipe`
diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py
index ba92e82..9d03359 100644
--- a/torch/utils/data/datapipes/iter/combinatorics.py
+++ b/torch/utils/data/datapipes/iter/combinatorics.py
@@ -84,7 +84,6 @@
def __init__(self,
datapipe: IterDataPipe[T_co],
*,
- default: bool = True,
buffer_size: int = 10000,
unbatch_level: int = 0
) -> None:
@@ -95,7 +94,7 @@
else:
self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
self.buffer_size = buffer_size
- self._shuffle_enabled = default
+ self._enabled = True
@staticmethod
def buffer_replace(buffer, x):
@@ -104,11 +103,12 @@
buffer[idx] = x
return val
- def set_shuffle_settings(self, shuffle=True):
- self._shuffle_enabled = shuffle
+ def set_shuffle(self, shuffle=True):
+ self._enabled = shuffle
+ return self
def __iter__(self) -> Iterator[T_co]:
- if not self._shuffle_enabled:
+ if not self._enabled:
for x in self.datapipe:
yield x
else:
diff --git a/torch/utils/data/graph_settings.py b/torch/utils/data/graph_settings.py
index 940f30c..add87a5 100644
--- a/torch/utils/data/graph_settings.py
+++ b/torch/utils/data/graph_settings.py
@@ -1,4 +1,5 @@
import torch.utils.data.graph
+from torch.utils.data.datapipes.iter import Shuffler
def get_all_graph_pipes(graph):
@@ -31,5 +32,5 @@
graph = torch.utils.data.graph.traverse(datapipe, only_datapipe=True)
all_pipes = get_all_graph_pipes(graph)
for pipe in all_pipes:
- if hasattr(pipe, 'set_shuffle_settings'):
- pipe.set_shuffle_settings(shuffle)
+ if isinstance(pipe, Shuffler):
+ pipe.set_shuffle(shuffle)