| import tempfile |
| import warnings |
| |
| import torch |
| from torch.testing._internal.common_utils import (TestCase, run_tests) |
| from torch.utils.data import IterableDataset, RandomSampler |
| from torch.utils.data.datasets import \ |
| (CollateIterableDataset, BatchIterableDataset, ListDirFilesIterableDataset, |
| LoadFilesFromDiskIterableDataset, SamplerIterableDataset) |
| |
| |
def create_temp_dir_and_files(num_files=3):
    """Create a temporary directory populated with empty temporary files.

    The directory handle is returned without being cleaned up: the caller owns
    it and is responsible for calling ``cleanup()`` (the tests do this in
    ``tearDown()``).

    Args:
        num_files: number of empty files to create inside the directory.
            Defaults to 3, which is what the existing tests expect.

    Returns:
        A tuple ``(temp_dir, path_1, ..., path_n)`` where ``temp_dir`` is the
        ``tempfile.TemporaryDirectory`` object and each ``path_i`` is the path
        of a file created inside ``temp_dir.name``.
    """
    # The directory handle is intentionally not released here (the caller cleans
    # it up), so silence flake8's "use a context manager" warning for this line.
    temp_dir = tempfile.TemporaryDirectory()  # noqa: P201
    temp_dir_path = temp_dir.name
    file_names = []
    for _ in range(num_files):
        # delete=False keeps the file on disk after its handle is closed; close the
        # handle right away instead of relying on garbage collection to do it.
        with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False) as temp_file:
            file_names.append(temp_file.name)

    return (temp_dir,) + tuple(file_names)
| |
| |
class TestIterableDatasetBasic(TestCase):
    """Tests for the file-oriented iterable datasets.

    Each test gets a fresh temporary directory of empty files created by
    ``create_temp_dir_and_files()`` in ``setUp`` and removed in ``tearDown``.
    """

    def setUp(self):
        ret = create_temp_dir_and_files()
        self.temp_dir = ret[0]
        self.temp_files = ret[1:]

    def tearDown(self):
        # Cleanup is best-effort: a failure here should not mask the test result,
        # so it is reported as a warning rather than raised.
        try:
            self.temp_dir.cleanup()
        except Exception as e:
            warnings.warn("TestIterableDatasetBasic was not able to cleanup temp dir due to {}".format(str(e)))

    def test_listdirfiles_iterable_dataset(self):
        temp_dir = self.temp_dir.name
        dataset = ListDirFilesIterableDataset(temp_dir, '')
        count = 0
        for pathname in dataset:
            count += 1
            self.assertIn(pathname, self.temp_files)
        # Without this check the loop above passes vacuously on an empty result.
        self.assertEqual(count, len(self.temp_files))

    def test_loadfilesfromdisk_iterable_dataset(self):
        temp_dir = self.temp_dir.name
        dataset1 = ListDirFilesIterableDataset(temp_dir, '')
        dataset2 = LoadFilesFromDiskIterableDataset(dataset1)

        count = 0
        for rec in dataset2:
            count += 1
            pathname, stream = rec[0], rec[1]
            self.assertIn(pathname, self.temp_files)
            # Close both the reference handle and the dataset's stream so no file
            # descriptors are left open when tearDown() removes the directory.
            try:
                with open(pathname, 'rb') as f:
                    self.assertEqual(stream.read(), f.read())
            finally:
                stream.close()
        self.assertEqual(count, len(self.temp_files))
| |
| |
class IterDatasetWithoutLen(IterableDataset):
    """Iterable dataset wrapper that deliberately defines no ``__len__``.

    The tests use it to exercise wrapping datasets whose source length is
    unknown.
    """

    def __init__(self, ds):
        super().__init__()
        # The wrapped iterable; it is re-iterated on every __iter__ call.
        self.ds = ds

    def __iter__(self):
        yield from self.ds
| |
| |
class IterDatasetWithLen(IterableDataset):
    """Iterable dataset wrapper that also reports the length of its source.

    The length is captured once, at construction time, via ``len(ds)``.
    """

    def __init__(self, ds):
        super().__init__()
        self.ds = ds
        # Snapshot the size up front; the source must therefore support len().
        self.length = len(ds)

    def __len__(self):
        return self.length

    def __iter__(self):
        yield from self.ds
| |
| |
class TestFunctionalIterableDataset(TestCase):
    """Tests for the wrapping iterable datasets: collate, batch and sampler."""

    def test_collate_dataset(self):
        arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        ds_len = IterDatasetWithLen(arrs)
        ds_nolen = IterDatasetWithoutLen(arrs)

        def _collate_fn(batch):
            return torch.tensor(sum(batch), dtype=torch.float)

        # Custom collate_fn: each input list becomes a float tensor of its sum.
        collate_ds = CollateIterableDataset(ds_len, collate_fn=_collate_fn)
        self.assertEqual(len(ds_len), len(collate_ds))
        # A comprehension (not list()) is used for materializing: list() consults
        # __len__, which raises NotImplementedError for the length-less case below.
        collated = [x for x in collate_ds]
        # Check the count so missing trailing items are not silently ignored.
        self.assertEqual(len(collated), len(arrs))
        for x, y in zip(collated, ds_len):
            self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float))

        # Default collate_fn on a source without `__len__`.
        collate_ds_nolen = CollateIterableDataset(ds_nolen)  # type: ignore
        with self.assertRaises(NotImplementedError):
            len(collate_ds_nolen)
        collated_nolen = [x for x in collate_ds_nolen]
        self.assertEqual(len(collated_nolen), len(arrs))
        for x, y in zip(collated_nolen, ds_nolen):
            self.assertEqual(x, torch.tensor(y))

    def test_batch_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        with self.assertRaises(AssertionError):
            BatchIterableDataset(ds, batch_size=0)

        # By default the last, incomplete batch is kept: 10 = 3 * 3 + 1.
        batch_ds1 = BatchIterableDataset(ds, batch_size=3)
        self.assertEqual(len(batch_ds1), 4)
        batches = [batch for batch in batch_ds1]
        self.assertEqual(len(batches), 4)
        value = 0
        for i, batch in enumerate(batches):
            # Only the final batch is short.
            self.assertEqual(len(batch), 1 if i == 3 else 3)
            for x in batch:
                self.assertEqual(x, value)
                value += 1
        # Every source element appeared exactly once, in order.
        self.assertEqual(value, 10)

        # With drop_last the trailing incomplete batch ([9]) is discarded.
        batch_ds2 = BatchIterableDataset(ds, batch_size=3, drop_last=True)
        self.assertEqual(len(batch_ds2), 3)
        value = 0
        for batch in batch_ds2:
            self.assertEqual(len(batch), 3)
            for x in batch:
                self.assertEqual(x, value)
                value += 1
        self.assertEqual(value, 9)

        # When batch_size divides the length evenly, drop_last makes no difference.
        batch_ds3 = BatchIterableDataset(ds, batch_size=2)
        self.assertEqual(len(batch_ds3), 5)
        batch_ds4 = BatchIterableDataset(ds, batch_size=2, drop_last=True)
        self.assertEqual(len(batch_ds4), 5)

        ds_nolen = IterDatasetWithoutLen(arrs)
        batch_ds_nolen = BatchIterableDataset(ds_nolen, batch_size=5)
        with self.assertRaises(NotImplementedError):
            len(batch_ds_nolen)

    def test_sampler_dataset(self):
        arrs = range(10)
        ds = IterDatasetWithLen(arrs)
        # Default SequentialSampler
        sampled_ds = SamplerIterableDataset(ds)  # type: ignore
        self.assertEqual(len(sampled_ds), 10)
        samples = [x for x in sampled_ds]
        self.assertEqual(len(samples), 10)
        for i, x in enumerate(samples):
            self.assertEqual(x, i)

        # RandomSampler with replacement: draws len(ds) samples by default, and
        # since arrs == range(10) every draw must fall inside that range.
        random_sampled_ds = SamplerIterableDataset(ds, sampler=RandomSampler, replacement=True)  # type: ignore
        self.assertEqual(len(random_sampled_ds), 10)
        random_samples = [x for x in random_sampled_ds]
        self.assertEqual(len(random_samples), 10)
        for x in random_samples:
            self.assertIn(x, arrs)

        # Requires `__len__` to build SamplerDataset
        ds_nolen = IterDatasetWithoutLen(arrs)
        with self.assertRaises(AssertionError):
            SamplerIterableDataset(ds_nolen)
| |
| |
if __name__ == "__main__":
    # Executed directly rather than imported: hand off to the test runner
    # imported from torch.testing._internal.common_utils.
    run_tests()