| import tempfile |
| import warnings |
| |
| from torch.testing._internal.common_utils import (TestCase, run_tests) |
| |
| import torch.utils.data.datapipes as dp |
| |
def create_temp_dir_and_files():
    """Build a throwaway directory tree for the datapipe tests.

    Layout: a top-level temporary directory holding a ``.txt`` file, a
    ``.byte`` file and a zero-length ``.empty`` file, plus one nested
    temporary directory holding a ``.txt`` and a ``.byte`` file.  Every
    non-empty file contains the 16 characters ``0123456789abcdef``.

    Returns:
        A two-element list.  The first element is the tuple
        ``(top_dir, txt_path, byte_path, empty_path)`` and the second is
        ``(sub_dir, txt_path, byte_path)``.  The directory entries are the
        ``tempfile.TemporaryDirectory`` objects themselves (not path
        strings), so the caller is responsible for calling ``cleanup()``.
    """

    def _reserve(parent, suffix):
        # delete=False keeps the (empty) file on disk after the handle closes.
        with tempfile.NamedTemporaryFile(dir=parent, delete=False, suffix=suffix) as handle:
            return handle.name

    def _fill(text_path, binary_path):
        # Same payload in both files: once through text mode, once as raw bytes.
        with open(text_path, 'w') as text_out:
            text_out.write('0123456789abcdef')
        with open(binary_path, 'wb') as raw_out:
            raw_out.write(b"0123456789abcdef")

    # The directory objects are handed back to the caller, which releases
    # them in tearDown(); `noqa: P201` silences the lint warning about not
    # releasing the directory handle inside this function.
    root = tempfile.TemporaryDirectory()  # noqa: P201
    root_txt = _reserve(root.name, '.txt')
    root_byte = _reserve(root.name, '.byte')
    root_empty = _reserve(root.name, '.empty')  # intentionally left empty
    _fill(root_txt, root_byte)

    nested = tempfile.TemporaryDirectory(dir=root.name)  # noqa: P201
    nested_txt = _reserve(nested.name, '.txt')
    nested_byte = _reserve(nested.name, '.byte')
    _fill(nested_txt, nested_byte)

    return [(root, root_txt, root_byte, root_empty),
            (nested, nested_txt, nested_byte)]
| |
class TestIterableDataPipeBasic(TestCase):
    """Basic tests for the file-listing and file-loading iterable datapipes.

    Every test runs against a fresh temporary directory tree produced by
    ``create_temp_dir_and_files()`` in ``setUp`` and removed in ``tearDown``.
    """

    def setUp(self):
        """Create the temporary tree and record its directories and files."""
        top_entry, sub_entry = create_temp_dir_and_files()
        # Element 0 of each entry is the TemporaryDirectory object; the
        # remaining elements are the paths of the files created inside it.
        self.temp_dir = top_entry[0]
        self.temp_files = top_entry[1:]
        self.temp_sub_dir = sub_entry[0]
        self.temp_sub_files = sub_entry[1:]

    def tearDown(self):
        """Remove the temporary tree on a best-effort basis.

        A cleanup failure is reported as a warning instead of failing the
        test.
        """
        # Innermost directory first.  Each directory is attempted on its own
        # so that a failure on the sub directory no longer skips the parent.
        for tmp in (self.temp_sub_dir, self.temp_dir):
            try:
                tmp.cleanup()
            except Exception as e:
                warnings.warn("TestIterableDataPipeBasic was not able to cleanup temp dir due to {}".format(str(e)))

    def test_listdirfiles_iterable_datapipe(self):
        """ListDirFiles yields the expected paths, flat and recursive."""
        temp_dir = self.temp_dir.name

        # Default (non-recursive) listing: only the top-level files.
        # NOTE(review): '' is the second positional argument, presumably the
        # filename mask -- confirm against ListDirFiles.
        datapipe = dp.iter.ListDirFiles(temp_dir, '')
        count = 0
        for pathname in datapipe:
            count += 1
            self.assertIn(pathname, self.temp_files)
        self.assertEqual(count, len(self.temp_files))

        # Recursive listing: top-level files plus those in the sub directory.
        expected = self.temp_files + self.temp_sub_files
        datapipe = dp.iter.ListDirFiles(temp_dir, '', recursive=True)
        count = 0
        for pathname in datapipe:
            count += 1
            self.assertIn(pathname, expected)
        self.assertEqual(count, len(expected))

    def test_loadfilesfromdisk_iterable_datapipe(self):
        """LoadFilesFromDisk yields (path, stream) records matching disk."""
        # test import datapipe class directly
        from torch.utils.data.datapipes.iter import ListDirFiles, LoadFilesFromDisk

        temp_dir = self.temp_dir.name
        datapipe1 = ListDirFiles(temp_dir, '')
        datapipe2 = LoadFilesFromDisk(datapipe1)

        count = 0
        for rec in datapipe2:
            count += 1
            self.assertIn(rec[0], self.temp_files)
            # Read the reference copy through a context manager so the file
            # handle is closed deterministically instead of being leaked.
            with open(rec[0], 'rb') as reference:
                expected_bytes = reference.read()
            self.assertTrue(rec[1].read() == expected_bytes)
        self.assertEqual(count, len(self.temp_files))
| |
# Script entry point: hand control to run_tests() when executed directly.
if __name__ == "__main__":
    run_tests()