| import http.server |
| import itertools |
| import os |
| import os.path |
| import pickle |
| import random |
| import socketserver |
| import sys |
| import tarfile |
| import tempfile |
| import threading |
| import time |
| import unittest |
| import warnings |
| import zipfile |
| from functools import partial |
| from typing import ( |
| Any, |
| Awaitable, |
| Dict, |
| Generic, |
| Iterator, |
| List, |
| NamedTuple, |
| Optional, |
| Set, |
| Tuple, |
| Type, |
| TypeVar, |
| Union, |
| ) |
| from unittest import skipIf |
| |
| import numpy as np |
| |
| import torch |
| import torch.nn as nn |
| import torch.utils.data.backward_compatibility |
| import torch.utils.data.datapipes as dp |
| import torch.utils.data.graph |
| import torch.utils.data.sharding |
| from torch.testing._internal.common_utils import TestCase, run_tests |
| from torch.utils.data import ( |
| DataLoader, |
| IterDataPipe, |
| MapDataPipe, |
| RandomSampler, |
| argument_validation, |
| runtime_validation, |
| runtime_validation_disabled, |
| ) |
| from torch.utils.data.datapipes.utils.decoder import ( |
| basichandlers as decoder_basichandlers, |
| ) |
| |
| try: |
| import torchvision.transforms |
| HAS_TORCHVISION = True |
| except ImportError: |
| HAS_TORCHVISION = False |
| skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision") |
| |
| try: |
| import dill |
    # XXX: By default, dill writes to the standard library Pickler's dispatch
    # table to inject its own logic. This globally affects pickling behavior
    # for any user who transitively depends on this module!
    # Undo this extension to avoid altering the pickler's behavior globally.
| dill.extend(use_dill=False) |
| HAS_DILL = True |
| except ImportError: |
| HAS_DILL = False |
| skipIfNoDill = skipIf(not HAS_DILL, "no dill") |
| |
| T_co = TypeVar('T_co', covariant=True) |
| |
| |
| def create_temp_dir_and_files(): |
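    """Create a temp dir with three files (.txt, .byte, .empty) and a sub-dir
    with two more (.txt, .byte).

    Returns [(dir, txt, byte, empty), (sub_dir, sub_txt, sub_byte)], keeping the
    TemporaryDirectory handles alive so tearDown() can clean them up.
    """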
| # The temp dir and files within it will be released and deleted in tearDown(). |
    # The `noqa: P201` suppresses the flake8 warning about not releasing the dir handle within this function.
| temp_dir = tempfile.TemporaryDirectory() # noqa: P201 |
| temp_dir_path = temp_dir.name |
| with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.txt') as f: |
| temp_file1_name = f.name |
| with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.byte') as f: |
| temp_file2_name = f.name |
| with tempfile.NamedTemporaryFile(dir=temp_dir_path, delete=False, suffix='.empty') as f: |
| temp_file3_name = f.name |
| |
| with open(temp_file1_name, 'w') as f1: |
| f1.write('0123456789abcdef') |
| with open(temp_file2_name, 'wb') as f2: |
| f2.write(b"0123456789abcdef") |
| |
| temp_sub_dir = tempfile.TemporaryDirectory(dir=temp_dir_path) # noqa: P201 |
| temp_sub_dir_path = temp_sub_dir.name |
| with tempfile.NamedTemporaryFile(dir=temp_sub_dir_path, delete=False, suffix='.txt') as f: |
| temp_sub_file1_name = f.name |
| with tempfile.NamedTemporaryFile(dir=temp_sub_dir_path, delete=False, suffix='.byte') as f: |
| temp_sub_file2_name = f.name |
| |
| with open(temp_sub_file1_name, 'w') as f1: |
| f1.write('0123456789abcdef') |
| with open(temp_sub_file2_name, 'wb') as f2: |
| f2.write(b"0123456789abcdef") |
| |
| return [(temp_dir, temp_file1_name, temp_file2_name, temp_file3_name), |
| (temp_sub_dir, temp_sub_file1_name, temp_sub_file2_name)] |
| |
| |
| class TestIterableDataPipeBasic(TestCase): |
| |
| def setUp(self): |
| ret = create_temp_dir_and_files() |
| self.temp_dir = ret[0][0] |
| self.temp_files = ret[0][1:] |
| self.temp_sub_dir = ret[1][0] |
| self.temp_sub_files = ret[1][1:] |
| |
| def tearDown(self): |
| try: |
| self.temp_sub_dir.cleanup() |
| self.temp_dir.cleanup() |
| except Exception as e: |
| warnings.warn("TestIterableDatasetBasic was not able to cleanup temp dir due to {}".format(str(e))) |
| |
| def test_listdirfiles_iterable_datapipe(self): |
| temp_dir = self.temp_dir.name |
| datapipe = dp.iter.ListDirFiles(temp_dir, '') |
| |
| count = 0 |
| for pathname in datapipe: |
| count = count + 1 |
| self.assertTrue(pathname in self.temp_files) |
| self.assertEqual(count, len(self.temp_files)) |
| |
| count = 0 |
| datapipe = dp.iter.ListDirFiles(temp_dir, '', recursive=True) |
| for pathname in datapipe: |
| count = count + 1 |
| self.assertTrue((pathname in self.temp_files) or (pathname in self.temp_sub_files)) |
        self.assertEqual(count, len(self.temp_files) + len(self.temp_sub_files))

| def test_loadfilesfromdisk_iterable_datapipe(self): |
        # Test importing the DataPipe classes directly
| from torch.utils.data.datapipes.iter import ( |
| ListDirFiles, |
| LoadFilesFromDisk, |
| ) |
| |
| temp_dir = self.temp_dir.name |
| datapipe1 = ListDirFiles(temp_dir, '') |
| datapipe2 = LoadFilesFromDisk(datapipe1) |
| |
| count = 0 |
| for rec in datapipe2: |
| count = count + 1 |
| self.assertTrue(rec[0] in self.temp_files) |
| with open(rec[0], 'rb') as f: |
| self.assertEqual(rec[1].read(), f.read()) |
| rec[1].close() |
| self.assertEqual(count, len(self.temp_files)) |
| |
| # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate |
| def test_readfilesfromtar_iterable_datapipe(self): |
| temp_dir = self.temp_dir.name |
| temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar") |
| with tarfile.open(temp_tarfile_pathname, "w:gz") as tar: |
| tar.add(self.temp_files[0]) |
| tar.add(self.temp_files[1]) |
| tar.add(self.temp_files[2]) |
| datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar') |
| datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) |
| datapipe3 = dp.iter.ReadFilesFromTar(datapipe2) |
| # read extracted files before reaching the end of the tarfile |
| for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files): |
| self.assertTrue(rec is not None and temp_file is not None) |
| self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file)) |
| with open(temp_file, 'rb') as f: |
| self.assertEqual(rec[1].read(), f.read()) |
| rec[1].close() |
| # read extracted files after reaching the end of the tarfile |
| data_refs = list(datapipe3) |
| self.assertEqual(len(data_refs), len(self.temp_files)) |
| for data_ref, temp_file in zip(data_refs, self.temp_files): |
| self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file)) |
| with open(temp_file, 'rb') as f: |
| self.assertEqual(data_ref[1].read(), f.read()) |
| data_ref[1].close() |
| |
| # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate |
| def test_readfilesfromzip_iterable_datapipe(self): |
| temp_dir = self.temp_dir.name |
| temp_zipfile_pathname = os.path.join(temp_dir, "test_zip.zip") |
| with zipfile.ZipFile(temp_zipfile_pathname, 'w') as myzip: |
| myzip.write(self.temp_files[0]) |
| myzip.write(self.temp_files[1]) |
| myzip.write(self.temp_files[2]) |
| datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.zip') |
| datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) |
| datapipe3 = dp.iter.ReadFilesFromZip(datapipe2) |
| # read extracted files before reaching the end of the zipfile |
| for rec, temp_file in itertools.zip_longest(datapipe3, self.temp_files): |
| self.assertTrue(rec is not None and temp_file is not None) |
| self.assertEqual(os.path.basename(rec[0]), os.path.basename(temp_file)) |
| with open(temp_file, 'rb') as f: |
| self.assertEqual(rec[1].read(), f.read()) |
| rec[1].close() |
        # read extracted files after reaching the end of the zipfile
| data_refs = list(datapipe3) |
| self.assertEqual(len(data_refs), len(self.temp_files)) |
| for data_ref, temp_file in zip(data_refs, self.temp_files): |
| self.assertEqual(os.path.basename(data_ref[0]), os.path.basename(temp_file)) |
| with open(temp_file, 'rb') as f: |
| self.assertEqual(data_ref[1].read(), f.read()) |
            data_ref[1].close()

| def test_routeddecoder_iterable_datapipe(self): |
| temp_dir = self.temp_dir.name |
| temp_pngfile_pathname = os.path.join(temp_dir, "test_png.png") |
| png_data = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single) |
| np.save(temp_pngfile_pathname, png_data) |
| datapipe1 = dp.iter.ListDirFiles(temp_dir, ['*.png', '*.txt']) |
| datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) |
| |
| def _png_decoder(extension, data): |
| if extension != 'png': |
| return None |
| return np.load(data) |
| |
        def _helper(prior_dp, cur_dp, channel_first=False):
            # Byte streams are not closed before decoding
            for inp in prior_dp:
                self.assertFalse(inp[1].closed)
            for inp, rec in zip(prior_dp, cur_dp):
| ext = os.path.splitext(rec[0])[1] |
| if ext == '.png': |
| expected = np.array([[[1., 0., 0.], [1., 0., 0.]], [[1., 0., 0.], [1., 0., 0.]]], dtype=np.single) |
| if channel_first: |
| expected = expected.transpose(2, 0, 1) |
| self.assertEqual(rec[1], expected) |
| else: |
| with open(rec[0], 'rb') as f: |
| self.assertEqual(rec[1], f.read().decode('utf-8')) |
| # Corresponding byte stream is closed by Decoder |
| self.assertTrue(inp[1].closed) |
| |
| cached = list(datapipe2) |
| datapipe3 = dp.iter.RoutedDecoder(cached, _png_decoder) |
| datapipe3.add_handler(decoder_basichandlers) |
| _helper(cached, datapipe3) |
| |
| cached = list(datapipe2) |
| datapipe4 = dp.iter.RoutedDecoder(cached, decoder_basichandlers) |
| datapipe4.add_handler(_png_decoder) |
| _helper(cached, datapipe4, channel_first=True) |
| |
| # TODO(VitalyFedyunin): Generates unclosed buffer warning, need to investigate |
| def test_groupbykey_iterable_datapipe(self): |
| temp_dir = self.temp_dir.name |
| temp_tarfile_pathname = os.path.join(temp_dir, "test_tar.tar") |
| file_list = [ |
| "a.png", "b.png", "c.json", "a.json", "c.png", "b.json", "d.png", |
| "d.json", "e.png", "f.json", "g.png", "f.png", "g.json", "e.json", |
| "h.txt", "h.json"] |
| with tarfile.open(temp_tarfile_pathname, "w:gz") as tar: |
| for file_name in file_list: |
| file_pathname = os.path.join(temp_dir, file_name) |
| with open(file_pathname, 'w') as f: |
| f.write('12345abcde') |
| tar.add(file_pathname) |
| |
| datapipe1 = dp.iter.ListDirFiles(temp_dir, '*.tar') |
| datapipe2 = dp.iter.LoadFilesFromDisk(datapipe1) |
| datapipe3 = dp.iter.ReadFilesFromTar(datapipe2) |
| datapipe4 = dp.iter.GroupByKey(datapipe3, group_size=2) |
| |
| expected_result = [("a.png", "a.json"), ("c.png", "c.json"), ("b.png", "b.json"), ("d.png", "d.json"), ( |
| "f.png", "f.json"), ("g.png", "g.json"), ("e.png", "e.json"), ("h.json", "h.txt")] |
| |
| count = 0 |
| for rec, expected in zip(datapipe4, expected_result): |
| count = count + 1 |
| self.assertEqual(os.path.basename(rec[0][0]), expected[0]) |
| self.assertEqual(os.path.basename(rec[1][0]), expected[1]) |
| for i in [0, 1]: |
| self.assertEqual(rec[i][1].read(), b'12345abcde') |
| rec[i][1].close() |
| self.assertEqual(count, 8) |
| |
| def test_demux_mux_datapipe(self): |
| numbers = NumbersDataset(10) |
| n1, n2 = numbers.demux(2, lambda x: x % 2) |
| self.assertEqual([0, 2, 4, 6, 8], list(n1)) |
| self.assertEqual([1, 3, 5, 7, 9], list(n2)) |
| |
| numbers = NumbersDataset(10) |
| n1, n2, n3 = numbers.demux(3, lambda x: x % 3) |
| n = n1.mux(n2, n3) |
| self.assertEqual(list(range(10)), list(n)) |
| |
| |
| class FileLoggerSimpleHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): |
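    r"""SimpleHTTPRequestHandler that appends its access log to ``logfile`` when
    one is given; otherwise request logging is suppressed."""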
| def __init__(self, *args, logfile=None, **kwargs): |
| self.__loggerHandle = None |
| if logfile is not None: |
| self.__loggerHandle = open(logfile, 'a+') |
| super().__init__(*args, **kwargs) |
| |
| def log_message(self, format, *args): |
| if self.__loggerHandle is not None: |
| self.__loggerHandle.write("%s - - [%s] %s\n" % |
| (self.address_string(), |
| self.log_date_time_string(), |
| format % args)) |
| return |
| |
| def finish(self): |
| if self.__loggerHandle is not None: |
| self.__loggerHandle.close() |
| super().finish() |
| |
| |
def setUpLocalServerInThread():
    # Serve the current working directory over HTTP on an ephemeral port
    # (port 0) from a background thread.
    Handler = partial(FileLoggerSimpleHTTPRequestHandler, logfile=None)
    socketserver.TCPServer.allow_reuse_address = True

    server = socketserver.TCPServer(("", 0), Handler)
    server_addr = "{host}:{port}".format(host=server.server_address[0], port=server.server_address[1])
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.start()

    # Wait a bit for the server to come up
    time.sleep(3)

    return (server_thread, server_addr, server)
| |
| |
| def create_temp_files_for_serving(tmp_dir, file_count, file_size, |
| file_url_template): |
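    """Write `file_count` files of `file_size` random bytes into `tmp_dir`, and
    record one URL per file (rendered from `file_url_template`) in a 'urls_list'
    manifest consumed by the HTTP reader tests."""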
| furl_local_file = os.path.join(tmp_dir, "urls_list") |
| with open(furl_local_file, 'w') as fsum: |
        for i in range(file_count):
| f = os.path.join(tmp_dir, "webfile_test_{num}.data".format(num=i)) |
| |
| write_chunk = 1024 * 1024 * 16 |
| rmn_size = file_size |
| while rmn_size > 0: |
| with open(f, 'ab+') as fout: |
| fout.write(os.urandom(min(rmn_size, write_chunk))) |
| rmn_size = rmn_size - min(rmn_size, write_chunk) |
| |
| fsum.write(file_url_template.format(num=i)) |
| |
| |
| class TestIterableDataPipeHttp(TestCase): |
| __server_thread: threading.Thread |
| __server_addr: str |
| __server: socketserver.TCPServer |
| |
| @classmethod |
| def setUpClass(cls): |
| try: |
| (cls.__server_thread, cls.__server_addr, |
| cls.__server) = setUpLocalServerInThread() |
| except Exception as e: |
| warnings.warn("TestIterableDataPipeHttp could\ |
| not set up due to {0}".format(str(e))) |
| |
| @classmethod |
| def tearDownClass(cls): |
| try: |
| cls.__server.shutdown() |
| cls.__server_thread.join(timeout=15) |
| except Exception as e: |
| warnings.warn("TestIterableDataPipeHttp could\ |
| not tear down (clean up temp directory or terminate\ |
| local server) due to {0}".format(str(e))) |
| |
| def _http_test_base(self, test_file_size, test_file_count, timeout=None, |
| chunk=None): |
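        """Serve `test_file_count` files of `test_file_size` random bytes via the
        local HTTP server and check that the datapipe chain reads them back."""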
| |
| def _get_data_from_tuple_fn(data, *args, **kwargs): |
| return data[args[0]] |
| |
| with tempfile.TemporaryDirectory(dir=os.getcwd()) as tmpdir: |
| # create tmp dir and files for test |
| base_tmp_dir = os.path.basename(os.path.normpath(tmpdir)) |
| file_url_template = ("http://{server_addr}/{tmp_dir}/" |
| "/webfile_test_{num}.data\n")\ |
| .format(server_addr=self.__server_addr, tmp_dir=base_tmp_dir, |
| num='{num}') |
| create_temp_files_for_serving(tmpdir, test_file_count, |
| test_file_size, file_url_template) |
| |
| datapipe_dir_f = dp.iter.ListDirFiles(tmpdir, '*_list') |
| datapipe_f_lines = dp.iter.ReadLinesFromFile(datapipe_dir_f) |
| datapipe_line_url: IterDataPipe[str] = \ |
| dp.iter.Map(datapipe_f_lines, _get_data_from_tuple_fn, (1,)) |
| datapipe_http = dp.iter.HttpReader(datapipe_line_url, |
| timeout=timeout) |
| datapipe_tob = dp.iter.ToBytes(datapipe_http, chunk=chunk) |
| |
| for (url, data) in datapipe_tob: |
| self.assertGreater(len(url), 0) |
                self.assertRegex(url, r'^http://.+\d+\.data$')
| if chunk is not None: |
| self.assertEqual(len(data), chunk) |
| else: |
| self.assertEqual(len(data), test_file_size) |
| |
| @unittest.skip("Stress test on large amount of files skipped\ |
| due to the CI timing constraint.") |
| def test_stress_http_reader_iterable_datapipes(self): |
| test_file_size = 10 |
| # STATS: It takes about 5 hours to stress test 16 * 1024 * 1024 |
| # files locally |
| test_file_count = 1024 |
| self._http_test_base(test_file_size, test_file_count) |
| |
| @unittest.skip("Test on the very large file skipped\ |
| due to the CI timing constraint.") |
| def test_large_files_http_reader_iterable_datapipes(self): |
| # STATS: It takes about 11 mins to test a large file of 64GB locally |
| test_file_size = 1024 * 1024 * 128 |
| test_file_count = 1 |
| timeout = 30 |
| chunk = 1024 * 1024 * 8 |
| self._http_test_base(test_file_size, test_file_count, timeout=timeout, |
| chunk=chunk) |
| |
| |
| class IDP_NoLen(IterDataPipe): |
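    r"""IterDataPipe wrapper without ``__len__``, used to exercise the
    "instance doesn't have valid length" error paths."""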
| def __init__(self, input_dp): |
| super().__init__() |
| self.input_dp = input_dp |
| |
    def __iter__(self):
        yield from self.input_dp
| |
| |
| class IDP(IterDataPipe): |
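    r"""Minimal IterDataPipe wrapper that forwards the length of its input."""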
| def __init__(self, input_dp): |
| super().__init__() |
| self.input_dp = input_dp |
| self.length = len(input_dp) |
| |
    def __iter__(self):
        yield from self.input_dp
| |
| def __len__(self): |
| return self.length |
| |
| |
| class MDP(MapDataPipe): |
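    r"""Minimal MapDataPipe wrapper that forwards indexing and the length of its input."""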
| def __init__(self, input_dp): |
| super().__init__() |
| self.input_dp = input_dp |
| self.length = len(input_dp) |
| |
| def __getitem__(self, index): |
| return self.input_dp[index] |
| |
| def __len__(self) -> int: |
| return self.length |
| |
| |
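# Module-level (and therefore picklable) stand-ins used as Map/Collate/Filter
# callbacks; lambdas are used in the tests below where pickling should fail.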
| def _fake_fn(data, *args, **kwargs): |
| return data |
| |
| |
| def _fake_filter_fn(data, *args, **kwargs): |
| return data >= 5 |
| |
| |
| def _worker_init_fn(worker_id): |
| random.seed(123) |
| |
| |
| class TestFunctionalIterDataPipe(TestCase): |
| |
    # TODO(VitalyFedyunin): If dill is installed, this test fails
| def _test_picklable(self): |
| arr = range(10) |
| picklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [ |
| (dp.iter.Map, IDP(arr), (), {}), |
| (dp.iter.Map, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}), |
| (dp.iter.Collate, IDP(arr), (), {}), |
| (dp.iter.Collate, IDP(arr), (_fake_fn, (0, ), {'test': True}), {}), |
| (dp.iter.Filter, IDP(arr), (_fake_filter_fn, (0, ), {'test': True}), {}), |
| ] |
| for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: |
| p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] |
| |
| unpicklable_datapipes: List[Tuple[Type[IterDataPipe], IterDataPipe, Tuple, Dict[str, Any]]] = [ |
| (dp.iter.Map, IDP(arr), (lambda x: x, ), {}), |
| (dp.iter.Collate, IDP(arr), (lambda x: x, ), {}), |
| (dp.iter.Filter, IDP(arr), (lambda x: x >= 5, ), {}), |
| ] |
| for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: |
| with warnings.catch_warnings(record=True) as wa: |
| datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] |
| self.assertEqual(len(wa), 1) |
| self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle") |
| with self.assertRaises(AttributeError): |
| p = pickle.dumps(datapipe) |
| |
| def test_concat_datapipe(self): |
| input_dp1 = IDP(range(10)) |
| input_dp2 = IDP(range(5)) |
| |
| with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): |
| dp.iter.Concat() |
| |
| with self.assertRaisesRegex(TypeError, r"Expected all inputs to be `IterDataPipe`"): |
| dp.iter.Concat(input_dp1, ()) # type: ignore[arg-type] |
| |
| concat_dp = input_dp1.concat(input_dp2) |
| self.assertEqual(len(concat_dp), 15) |
| self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) |
| |
| # Test Reset |
| self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) |
| |
| input_dp_nl = IDP_NoLen(range(5)) |
| |
| concat_dp = input_dp1.concat(input_dp_nl) |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(concat_dp) |
| |
| self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) |
| |
| def test_map_datapipe(self): |
| input_dp = IDP(range(10)) |
| |
| def fn(item, dtype=torch.float, *, sum=False): |
| data = torch.tensor(item, dtype=dtype) |
| return data if not sum else data.sum() |
| |
| map_dp = input_dp.map(fn) |
| self.assertEqual(len(input_dp), len(map_dp)) |
| for x, y in zip(map_dp, input_dp): |
| self.assertEqual(x, torch.tensor(y, dtype=torch.float)) |
| |
| map_dp = input_dp.map(fn=fn, fn_args=(torch.int, ), fn_kwargs={'sum': True}) |
| self.assertEqual(len(input_dp), len(map_dp)) |
| for x, y in zip(map_dp, input_dp): |
| self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) |
| |
        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
| self.assertEqual(len(input_dp), len(map_dp)) |
| for x, y in zip(map_dp, input_dp): |
| self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) |
| |
| input_dp_nl = IDP_NoLen(range(10)) |
| map_dp_nl = input_dp_nl.map() |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(map_dp_nl) |
| for x, y in zip(map_dp_nl, input_dp_nl): |
| self.assertEqual(x, torch.tensor(y, dtype=torch.float)) |
| |
    # TODO(VitalyFedyunin): If dill is installed, this test fails
| def _test_map_datapipe_nested_level(self): |
| |
| input_dp = IDP([list(range(10)) for _ in range(3)]) |
| |
| def fn(item, *, dtype=torch.float): |
| return torch.tensor(item, dtype=dtype) |
| |
| with warnings.catch_warnings(record=True) as wa: |
| map_dp = input_dp.map(lambda ls: ls * 2, nesting_level=0) |
| self.assertEqual(len(wa), 1) |
| self.assertRegex(str(wa[0].message), r"^Lambda function is not supported for pickle") |
| self.assertEqual(len(input_dp), len(map_dp)) |
| for x, y in zip(map_dp, input_dp): |
| self.assertEqual(x, y * 2) |
| |
| map_dp = input_dp.map(fn, nesting_level=1) |
| self.assertEqual(len(input_dp), len(map_dp)) |
| for x, y in zip(map_dp, input_dp): |
| self.assertEqual(len(x), len(y)) |
| for a, b in zip(x, y): |
| self.assertEqual(a, torch.tensor(b, dtype=torch.float)) |
| |
| map_dp = input_dp.map(fn, nesting_level=-1) |
| self.assertEqual(len(input_dp), len(map_dp)) |
| for x, y in zip(map_dp, input_dp): |
| self.assertEqual(len(x), len(y)) |
| for a, b in zip(x, y): |
| self.assertEqual(a, torch.tensor(b, dtype=torch.float)) |
| |
| map_dp = input_dp.map(fn, nesting_level=4) |
| with self.assertRaises(IndexError): |
| list(map_dp) |
| |
| with self.assertRaises(ValueError): |
| input_dp.map(fn, nesting_level=-2) |
| |
| def test_collate_datapipe(self): |
| arrs = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] |
| input_dp = IDP(arrs) |
| |
| def _collate_fn(batch): |
| return torch.tensor(sum(batch), dtype=torch.float) |
| |
| collate_dp = input_dp.collate(collate_fn=_collate_fn) |
| self.assertEqual(len(input_dp), len(collate_dp)) |
| for x, y in zip(collate_dp, input_dp): |
| self.assertEqual(x, torch.tensor(sum(y), dtype=torch.float)) |
| |
| input_dp_nl = IDP_NoLen(arrs) |
| collate_dp_nl = input_dp_nl.collate() |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(collate_dp_nl) |
| for x, y in zip(collate_dp_nl, input_dp_nl): |
| self.assertEqual(x, torch.tensor(y)) |
| |
| def test_batch_datapipe(self): |
| arrs = list(range(10)) |
| input_dp = IDP(arrs) |
| with self.assertRaises(AssertionError): |
| input_dp.batch(batch_size=0) |
| |
| # Default not drop the last batch |
| bs = 3 |
| batch_dp = input_dp.batch(batch_size=bs) |
| self.assertEqual(len(batch_dp), 4) |
| for i, batch in enumerate(batch_dp): |
| self.assertEqual(len(batch), 1 if i == 3 else bs) |
| self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)]) |
| |
| # Drop the last batch |
| bs = 4 |
| batch_dp = input_dp.batch(batch_size=bs, drop_last=True) |
| self.assertEqual(len(batch_dp), 2) |
| for i, batch in enumerate(batch_dp): |
| self.assertEqual(len(batch), bs) |
| self.assertEqual(batch, arrs[i * bs: i * bs + len(batch)]) |
| |
| input_dp_nl = IDP_NoLen(range(10)) |
| batch_dp_nl = input_dp_nl.batch(batch_size=2) |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(batch_dp_nl) |
| |
| def test_unbatch_datapipe(self): |
| |
| target_length = 6 |
| prebatch_dp = IDP(range(target_length)) |
| |
| input_dp = prebatch_dp.batch(3) |
| unbatch_dp = input_dp.unbatch() |
| self.assertEqual(len(list(unbatch_dp)), target_length) |
| for i, res in zip(prebatch_dp, unbatch_dp): |
| self.assertEqual(i, res) |
| |
| input_dp = IDP([[0, 1, 2], [3, 4, 5]]) |
| unbatch_dp = input_dp.unbatch() |
| self.assertEqual(len(list(unbatch_dp)), target_length) |
| for i, res in zip(prebatch_dp, unbatch_dp): |
| self.assertEqual(i, res) |
| |
| input_dp = IDP([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) |
| |
| unbatch_dp = input_dp.unbatch() |
| expected_dp = [[0, 1], [2, 3], [4, 5], [6, 7]] |
| self.assertEqual(len(list(unbatch_dp)), 4) |
| for i, res in zip(expected_dp, unbatch_dp): |
| self.assertEqual(i, res) |
| |
| unbatch_dp = input_dp.unbatch(unbatch_level=2) |
| expected_dp2 = [0, 1, 2, 3, 4, 5, 6, 7] |
| self.assertEqual(len(list(unbatch_dp)), 8) |
| for i, res in zip(expected_dp2, unbatch_dp): |
| self.assertEqual(i, res) |
| |
| unbatch_dp = input_dp.unbatch(unbatch_level=-1) |
| self.assertEqual(len(list(unbatch_dp)), 8) |
| for i, res in zip(expected_dp2, unbatch_dp): |
| self.assertEqual(i, res) |
| |
| input_dp = IDP([[0, 1, 2], [3, 4, 5]]) |
| with self.assertRaises(ValueError): |
| unbatch_dp = input_dp.unbatch(unbatch_level=-2) |
            for _ in unbatch_dp:
                pass

        with self.assertRaises(IndexError):
            unbatch_dp = input_dp.unbatch(unbatch_level=5)
            for _ in unbatch_dp:
                pass

| def test_bucket_batch_datapipe(self): |
| input_dp = IDP(range(20)) |
| with self.assertRaises(AssertionError): |
| input_dp.bucket_batch(batch_size=0) |
| |
| input_dp_nl = IDP_NoLen(range(20)) |
| bucket_dp_nl = input_dp_nl.bucket_batch(batch_size=7) |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(bucket_dp_nl) |
| |
| # Test Bucket Batch without sort_key |
| def _helper(**kwargs): |
| arrs = list(range(100)) |
| random.shuffle(arrs) |
| input_dp = IDP(arrs) |
| bucket_dp = input_dp.bucket_batch(**kwargs) |
| if kwargs["sort_key"] is None: |
| # BatchDataset as reference |
| ref_dp = input_dp.batch(batch_size=kwargs['batch_size'], drop_last=kwargs['drop_last']) |
| for batch, rbatch in zip(bucket_dp, ref_dp): |
| self.assertEqual(batch, rbatch) |
| else: |
| bucket_size = bucket_dp.bucket_size |
| bucket_num = (len(input_dp) - 1) // bucket_size + 1 |
| it = iter(bucket_dp) |
| for i in range(bucket_num): |
| ref = sorted(arrs[i * bucket_size: (i + 1) * bucket_size]) |
| bucket: List = [] |
| while len(bucket) < len(ref): |
| try: |
| batch = next(it) |
| bucket += batch |
| # If drop last, stop in advance |
| except StopIteration: |
| break |
| if len(bucket) != len(ref): |
| ref = ref[:len(bucket)] |
| # Sorted bucket |
| self.assertEqual(bucket, ref) |
| |
| _helper(batch_size=7, drop_last=False, sort_key=None) |
| _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=None) |
| |
| # Test Bucket Batch with sort_key |
| def _sort_fn(data): |
| return data |
| |
| _helper(batch_size=7, drop_last=False, bucket_size_mul=5, sort_key=_sort_fn) |
| _helper(batch_size=7, drop_last=True, bucket_size_mul=5, sort_key=_sort_fn) |
| |
| def test_filter_datapipe(self): |
| input_ds = IDP(range(10)) |
| |
| def _filter_fn(data, val, clip=False): |
| if clip: |
| return data >= val |
| return True |
| |
| filter_dp = input_ds.filter(filter_fn=_filter_fn, fn_args=(5, )) |
| for data, exp in zip(filter_dp, range(10)): |
| self.assertEqual(data, exp) |
| |
| filter_dp = input_ds.filter(filter_fn=_filter_fn, fn_kwargs={'val': 5, 'clip': True}) |
| for data, exp in zip(filter_dp, range(5, 10)): |
| self.assertEqual(data, exp) |
| |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(filter_dp) |
| |
| def _non_bool_fn(data): |
| return 1 |
| |
| filter_dp = input_ds.filter(filter_fn=_non_bool_fn) |
| with self.assertRaises(ValueError): |
| temp = list(filter_dp) |
| |
| def test_filter_datapipe_nested_list(self): |
| |
| input_ds = IDP(range(10)).batch(5) |
| |
| def _filter_fn(data, val): |
| return data >= val |
| |
| filter_dp = input_ds.filter(nesting_level=-1, filter_fn=_filter_fn, fn_kwargs={'val': 5}) |
| expected_dp1 = [[5, 6, 7, 8, 9]] |
| self.assertEqual(len(list(filter_dp)), len(expected_dp1)) |
| for data, exp in zip(filter_dp, expected_dp1): |
| self.assertEqual(data, exp) |
| |
| filter_dp = input_ds.filter(nesting_level=-1, drop_empty_batches=False, filter_fn=_filter_fn, fn_kwargs={'val': 5}) |
| expected_dp2: List[List[int]] = [[], [5, 6, 7, 8, 9]] |
| self.assertEqual(len(list(filter_dp)), len(expected_dp2)) |
| for data, exp in zip(filter_dp, expected_dp2): |
| self.assertEqual(data, exp) |
| |
| with self.assertRaises(IndexError): |
| filter_dp = input_ds.filter(nesting_level=5, filter_fn=_filter_fn, fn_kwargs={'val': 5}) |
| temp = list(filter_dp) |
| |
| input_ds = IDP(range(10)).batch(3) |
| |
| filter_dp = input_ds.filter(lambda ls: len(ls) >= 3) |
| expected_dp3: List[List[int]] = [[0, 1, 2], [3, 4, 5], [6, 7, 8]] |
| self.assertEqual(len(list(filter_dp)), len(expected_dp3)) |
| for data, exp in zip(filter_dp, expected_dp3): |
| self.assertEqual(data, exp) |
| |
| input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]]) |
| filter_dp = input_ds.filter(lambda x: x > 3, nesting_level=-1) |
| expected_dp4 = [[[4, 5]], [[6, 7, 8]]] |
| self.assertEqual(len(list(filter_dp)), len(expected_dp4)) |
| for data2, exp2 in zip(filter_dp, expected_dp4): |
| self.assertEqual(data2, exp2) |
| |
| input_ds = IDP([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [1, 2, 3]]]) |
| filter_dp = input_ds.filter(lambda x: x > 7, nesting_level=-1) |
| expected_dp5 = [[[8]]] |
| self.assertEqual(len(list(filter_dp)), len(expected_dp5)) |
| for data2, exp2 in zip(filter_dp, expected_dp5): |
| self.assertEqual(data2, exp2) |
| |
| input_ds = IDP([[[0, 1], [3, 4]], [[6, 7, 8], [1, 2, 3]]]) |
| filter_dp = input_ds.filter(lambda ls: len(ls) >= 3, nesting_level=1) |
| expected_dp6 = [[[6, 7, 8], [1, 2, 3]]] |
| self.assertEqual(len(list(filter_dp)), len(expected_dp6)) |
| for data2, exp2 in zip(filter_dp, expected_dp6): |
            self.assertEqual(data2, exp2)

| def test_sampler_datapipe(self): |
| input_dp = IDP(range(10)) |
| # Default SequentialSampler |
| sampled_dp = dp.iter.Sampler(input_dp) # type: ignore[var-annotated] |
| self.assertEqual(len(sampled_dp), 10) |
| for i, x in enumerate(sampled_dp): |
| self.assertEqual(x, i) |
| |
| # RandomSampler |
| random_sampled_dp = dp.iter.Sampler(input_dp, sampler=RandomSampler, sampler_kwargs={'replacement': True}) # type: ignore[var-annotated] # noqa: B950 |
| |
| # Requires `__len__` to build SamplerDataPipe |
| input_dp_nolen = IDP_NoLen(range(10)) |
| with self.assertRaises(AssertionError): |
| sampled_dp = dp.iter.Sampler(input_dp_nolen) |
| |
| def test_shuffle_datapipe(self): |
| exp = list(range(20)) |
| input_ds = IDP(exp) |
| |
| with self.assertRaises(AssertionError): |
| shuffle_dp = input_ds.shuffle(buffer_size=0) |
| |
| for bs in (5, 20, 25): |
| shuffle_dp = input_ds.shuffle(buffer_size=bs) |
| self.assertEqual(len(shuffle_dp), len(input_ds)) |
| |
| random.seed(123) |
| res = list(shuffle_dp) |
| self.assertEqual(sorted(res), exp) |
| |
| # Test Deterministic |
| for num_workers in (0, 1): |
| random.seed(123) |
| dl = DataLoader(shuffle_dp, num_workers=num_workers, worker_init_fn=_worker_init_fn) |
| dl_res = list(dl) |
| self.assertEqual(res, dl_res) |
| |
| shuffle_dp_nl = IDP_NoLen(range(20)).shuffle(buffer_size=5) |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(shuffle_dp_nl) |
| |
| @skipIfNoTorchVision |
| def test_transforms_datapipe(self): |
| torch.set_default_dtype(torch.float) |
| # A sequence of numpy random numbers representing 3-channel images |
| w = h = 32 |
| inputs = [np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) for i in range(10)] |
| tensor_inputs = [torch.tensor(x, dtype=torch.float).permute(2, 0, 1) / 255. for x in inputs] |
| |
| input_dp = IDP(inputs) |
| # Raise TypeError for python function |
| with self.assertRaisesRegex(TypeError, r"`transforms` are required to be"): |
| input_dp.legacy_transforms(_fake_fn) |
| |
| # transforms.Compose of several transforms |
| transforms = torchvision.transforms.Compose([ |
| torchvision.transforms.ToTensor(), |
| torchvision.transforms.Pad(1, fill=1, padding_mode='constant'), |
| ]) |
| tsfm_dp = input_dp.legacy_transforms(transforms) |
| self.assertEqual(len(tsfm_dp), len(input_dp)) |
| for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs): |
| self.assertEqual(tsfm_data[:, 1:(h + 1), 1:(w + 1)], input_data) |
| |
| # nn.Sequential of several transforms (required to be instances of nn.Module) |
| input_dp = IDP(tensor_inputs) |
| transforms = nn.Sequential( |
| torchvision.transforms.Pad(1, fill=1, padding_mode='constant'), |
| ) |
| tsfm_dp = input_dp.legacy_transforms(transforms) |
| self.assertEqual(len(tsfm_dp), len(input_dp)) |
| for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs): |
| self.assertEqual(tsfm_data[:, 1:(h + 1), 1:(w + 1)], input_data) |
| |
| # Single transform |
| input_dp = IDP_NoLen(inputs) # type: ignore[assignment] |
| transform = torchvision.transforms.ToTensor() |
| tsfm_dp = input_dp.legacy_transforms(transform) |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(tsfm_dp) |
| for tsfm_data, input_data in zip(tsfm_dp, tensor_inputs): |
| self.assertEqual(tsfm_data, input_data) |
| |
| def test_zip_datapipe(self): |
| with self.assertRaises(TypeError): |
| dp.iter.Zip(IDP(range(10)), list(range(10))) # type: ignore[arg-type] |
| |
| zipped_dp = dp.iter.Zip(IDP(range(10)), IDP_NoLen(range(5))) # type: ignore[var-annotated] |
| with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): |
| len(zipped_dp) |
| exp = list((i, i) for i in range(5)) |
| self.assertEqual(list(zipped_dp), exp) |
| |
| zipped_dp = dp.iter.Zip(IDP(range(10)), IDP(range(5))) |
| self.assertEqual(len(zipped_dp), 5) |
| self.assertEqual(list(zipped_dp), exp) |
| # Reset |
| self.assertEqual(list(zipped_dp), exp) |
| |
| |
| class TestFunctionalMapDataPipe(TestCase): |
    # TODO(VitalyFedyunin): If dill is installed, this test fails
| def _test_picklable(self): |
| arr = range(10) |
| picklable_datapipes: List[ |
| Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] |
| ] = [ |
| (dp.map.Map, MDP(arr), (), {}), |
| (dp.map.Map, MDP(arr), (_fake_fn, (0,), {'test': True}), {}), |
| ] |
| for dpipe, input_dp, dp_args, dp_kwargs in picklable_datapipes: |
| p = pickle.dumps(dpipe(input_dp, *dp_args, **dp_kwargs)) # type: ignore[call-arg] |
| |
| unpicklable_datapipes: List[ |
| Tuple[Type[MapDataPipe], MapDataPipe, Tuple, Dict[str, Any]] |
| ] = [ |
| (dp.map.Map, MDP(arr), (lambda x: x,), {}), |
| ] |
| for dpipe, input_dp, dp_args, dp_kwargs in unpicklable_datapipes: |
| with warnings.catch_warnings(record=True) as wa: |
| datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] |
| self.assertEqual(len(wa), 1) |
| self.assertRegex( |
| str(wa[0].message), r"^Lambda function is not supported for pickle" |
| ) |
| with self.assertRaises(AttributeError): |
| p = pickle.dumps(datapipe) |
| |
| def test_concat_datapipe(self): |
| input_dp1 = MDP(range(10)) |
| input_dp2 = MDP(range(5)) |
| |
| with self.assertRaisesRegex(ValueError, r"Expected at least one DataPipe"): |
| dp.map.Concat() |
| |
| with self.assertRaisesRegex(TypeError, r"Expected all inputs to be `MapDataPipe`"): |
| dp.map.Concat(input_dp1, ()) # type: ignore[arg-type] |
| |
| concat_dp = input_dp1.concat(input_dp2) |
| self.assertEqual(len(concat_dp), 15) |
| for index in range(15): |
| self.assertEqual(concat_dp[index], (list(range(10)) + list(range(5)))[index]) |
| self.assertEqual(list(concat_dp), list(range(10)) + list(range(5))) |
| |
| def test_map_datapipe(self): |
| arr = range(10) |
| input_dp = MDP(arr) |
| |
| def fn(item, dtype=torch.float, *, sum=False): |
| data = torch.tensor(item, dtype=dtype) |
| return data if not sum else data.sum() |
| |
| map_dp = input_dp.map(fn) |
| self.assertEqual(len(input_dp), len(map_dp)) |
| for index in arr: |
| self.assertEqual( |
| map_dp[index], torch.tensor(input_dp[index], dtype=torch.float) |
| ) |
| |
| map_dp = input_dp.map(fn=fn, fn_args=(torch.int,), fn_kwargs={'sum': True}) |
| self.assertEqual(len(input_dp), len(map_dp)) |
| for index in arr: |
| self.assertEqual( |
| map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum() |
| ) |
| |
        map_dp = input_dp.map(partial(fn, dtype=torch.int, sum=True))
| self.assertEqual(len(input_dp), len(map_dp)) |
| for index in arr: |
| self.assertEqual( |
| map_dp[index], torch.tensor(input_dp[index], dtype=torch.int).sum() |
| ) |
| |
| |
# Combining Generic with NamedTuple causes a metaclass conflict on Python 3.6,
# and multiple inheritance with NamedTuple is not supported on Python 3.9+.
| _generic_namedtuple_allowed = sys.version_info >= (3, 7) and sys.version_info < (3, 9) |
| if _generic_namedtuple_allowed: |
| class InvalidData(Generic[T_co], NamedTuple): |
| name: str |
| data: T_co |
| |
| |
| class TestTyping(TestCase): |
| def test_subtype(self): |
| from torch.utils.data._typing import issubtype |
| |
| basic_type = (int, str, bool, float, complex, |
| list, tuple, dict, set, T_co) |
| for t in basic_type: |
| self.assertTrue(issubtype(t, t)) |
| self.assertTrue(issubtype(t, Any)) |
| if t == T_co: |
| self.assertTrue(issubtype(Any, t)) |
| else: |
| self.assertFalse(issubtype(Any, t)) |
| for t1, t2 in itertools.product(basic_type, basic_type): |
| if t1 == t2 or t2 == T_co: |
| self.assertTrue(issubtype(t1, t2)) |
| else: |
| self.assertFalse(issubtype(t1, t2)) |
| |
| T = TypeVar('T', int, str) |
| S = TypeVar('S', bool, Union[str, int], Tuple[int, T]) # type: ignore[valid-type] |
| types = ((int, Optional[int]), |
| (List, Union[int, list]), |
| (Tuple[int, str], S), |
| (Tuple[int, str], tuple), |
| (T, S), |
| (S, T_co), |
| (T, Union[S, Set])) |
| for sub, par in types: |
| self.assertTrue(issubtype(sub, par)) |
| self.assertFalse(issubtype(par, sub)) |
| |
| subscriptable_types = { |
| List: 1, |
| Tuple: 2, # use 2 parameters |
| Set: 1, |
| Dict: 2, |
| } |
| for subscript_type, n in subscriptable_types.items(): |
| for ts in itertools.combinations(types, n): |
| subs, pars = zip(*ts) |
| sub = subscript_type[subs] # type: ignore[index] |
| par = subscript_type[pars] # type: ignore[index] |
| self.assertTrue(issubtype(sub, par)) |
| self.assertFalse(issubtype(par, sub)) |
| # Non-recursive check |
| self.assertTrue(issubtype(par, sub, recursive=False)) |
| |
| def test_issubinstance(self): |
| from torch.utils.data._typing import issubinstance |
| |
| basic_data = (1, '1', True, 1., complex(1., 0.)) |
| basic_type = (int, str, bool, float, complex) |
| S = TypeVar('S', bool, Union[str, int]) |
| for d in basic_data: |
| self.assertTrue(issubinstance(d, Any)) |
| self.assertTrue(issubinstance(d, T_co)) |
| if type(d) in (bool, int, str): |
| self.assertTrue(issubinstance(d, S)) |
| else: |
| self.assertFalse(issubinstance(d, S)) |
| for t in basic_type: |
| if type(d) == t: |
| self.assertTrue(issubinstance(d, t)) |
| else: |
| self.assertFalse(issubinstance(d, t)) |
| # list/set |
        dt = (([1, '1', 2], List), ({1, '1', 2}, Set))
| for d, t in dt: |
| self.assertTrue(issubinstance(d, t)) |
| self.assertTrue(issubinstance(d, t[T_co])) # type: ignore[index] |
| self.assertFalse(issubinstance(d, t[int])) # type: ignore[index] |
| |
| # dict |
        d = {'1': 1, '2': 2.}
| self.assertTrue(issubinstance(d, Dict)) |
| self.assertTrue(issubinstance(d, Dict[str, T_co])) |
| self.assertFalse(issubinstance(d, Dict[str, int])) |
| |
| # tuple |
| d = (1, '1', 2) |
| self.assertTrue(issubinstance(d, Tuple)) |
| self.assertTrue(issubinstance(d, Tuple[int, str, T_co])) |
| self.assertFalse(issubinstance(d, Tuple[int, Any])) |
| self.assertFalse(issubinstance(d, Tuple[int, int, int])) |
| |
| # Static checking annotation |
| def test_compile_time(self): |
| with self.assertRaisesRegex(TypeError, r"Expected 'Iterator' as the return"): |
| class InvalidDP1(IterDataPipe[int]): |
| def __iter__(self) -> str: # type: ignore[misc, override] |
| yield 0 |
| |
| with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"): |
| class InvalidDP2(IterDataPipe[Tuple]): |
| def __iter__(self) -> Iterator[int]: # type: ignore[override] |
| yield 0 |
| |
| with self.assertRaisesRegex(TypeError, r"Expected return type of '__iter__'"): |
| class InvalidDP3(IterDataPipe[Tuple[int, str]]): |
| def __iter__(self) -> Iterator[tuple]: # type: ignore[override] |
| yield (0, ) |
| |
| if _generic_namedtuple_allowed: |
| with self.assertRaisesRegex(TypeError, r"is not supported by Python typing"): |
| class InvalidDP4(IterDataPipe["InvalidData[int]"]): # type: ignore[type-arg, misc] |
| pass |
| |
| class DP1(IterDataPipe[Tuple[int, str]]): |
| def __init__(self, length): |
| self.length = length |
| |
| def __iter__(self) -> Iterator[Tuple[int, str]]: |
| for d in range(self.length): |
| yield d, str(d) |
| |
| self.assertTrue(issubclass(DP1, IterDataPipe)) |
| dp1 = DP1(10) |
| self.assertTrue(DP1.type.issubtype(dp1.type) and dp1.type.issubtype(DP1.type)) |
| dp2 = DP1(5) |
| self.assertEqual(dp1.type, dp2.type) |
| |
| with self.assertRaisesRegex(TypeError, r"is not a generic class"): |
| class InvalidDP5(DP1[tuple]): # type: ignore[type-arg] |
| def __iter__(self) -> Iterator[tuple]: # type: ignore[override] |
| yield (0, ) |
| |
| class DP2(IterDataPipe[T_co]): |
| def __iter__(self) -> Iterator[T_co]: |
| for d in range(10): |
| yield d # type: ignore[misc] |
| |
| self.assertTrue(issubclass(DP2, IterDataPipe)) |
| dp1 = DP2() # type: ignore[assignment] |
| self.assertTrue(DP2.type.issubtype(dp1.type) and dp1.type.issubtype(DP2.type)) |
| dp2 = DP2() # type: ignore[assignment] |
| self.assertEqual(dp1.type, dp2.type) |
| |
| class DP3(IterDataPipe[Tuple[T_co, str]]): |
| r""" DataPipe without fixed type with __init__ function""" |
| def __init__(self, datasource): |
| self.datasource = datasource |
| |
| def __iter__(self) -> Iterator[Tuple[T_co, str]]: |
| for d in self.datasource: |
| yield d, str(d) |
| |
| self.assertTrue(issubclass(DP3, IterDataPipe)) |
| dp1 = DP3(range(10)) # type: ignore[assignment] |
| self.assertTrue(DP3.type.issubtype(dp1.type) and dp1.type.issubtype(DP3.type)) |
| dp2 = DP3(5) # type: ignore[assignment] |
| self.assertEqual(dp1.type, dp2.type) |
| |
| class DP4(IterDataPipe[tuple]): |
| r""" DataPipe without __iter__ annotation""" |
| def __iter__(self): |
| raise NotImplementedError |
| |
| self.assertTrue(issubclass(DP4, IterDataPipe)) |
| dp = DP4() |
| self.assertTrue(dp.type.param == tuple) |
| |
| class DP5(IterDataPipe): |
| r""" DataPipe without type annotation""" |
| def __iter__(self) -> Iterator[str]: |
| raise NotImplementedError |
| |
| self.assertTrue(issubclass(DP5, IterDataPipe)) |
| dp = DP5() # type: ignore[assignment] |
| from torch.utils.data._typing import issubtype |
| self.assertTrue(issubtype(dp.type.param, Any) and issubtype(Any, dp.type.param)) |
| |
| class DP6(IterDataPipe[int]): |
| r""" DataPipe with plain Iterator""" |
| def __iter__(self) -> Iterator: |
| raise NotImplementedError |
| |
| self.assertTrue(issubclass(DP6, IterDataPipe)) |
| dp = DP6() # type: ignore[assignment] |
| self.assertTrue(dp.type.param == int) |
| |
| class DP7(IterDataPipe[Awaitable[T_co]]): |
| r""" DataPipe with abstract base class""" |
| |
        self.assertTrue(issubclass(DP7, IterDataPipe))
| self.assertTrue(DP7.type.param == Awaitable[T_co]) |
| |
| class DP8(DP7[str]): |
| r""" DataPipe subclass from a DataPipe with abc type""" |
| |
| self.assertTrue(issubclass(DP8, IterDataPipe)) |
        self.assertTrue(DP8.type.param == Awaitable[str])

| def test_construct_time(self): |
| class DP0(IterDataPipe[Tuple]): |
| @argument_validation |
| def __init__(self, dp: IterDataPipe): |
| self.dp = dp |
| |
| def __iter__(self) -> Iterator[Tuple]: |
| for d in self.dp: |
| yield d, str(d) |
| |
| class DP1(IterDataPipe[int]): |
| @argument_validation |
| def __init__(self, dp: IterDataPipe[Tuple[int, str]]): |
| self.dp = dp |
| |
| def __iter__(self) -> Iterator[int]: |
| for a, b in self.dp: |
| yield a |
| |
| # Non-DataPipe input with DataPipe hint |
| datasource = [(1, '1'), (2, '2'), (3, '3')] |
| with self.assertRaisesRegex(TypeError, r"Expected argument 'dp' as a IterDataPipe"): |
| dp = DP0(datasource) |
| |
| dp = DP0(IDP(range(10))) |
| with self.assertRaisesRegex(TypeError, r"Expected type of argument 'dp' as a subtype"): |
| dp = DP1(dp) |
| |
| def test_runtime(self): |
| class DP(IterDataPipe[Tuple[int, T_co]]): |
| def __init__(self, datasource): |
| self.ds = datasource |
| |
| @runtime_validation |
| def __iter__(self) -> Iterator[Tuple[int, T_co]]: |
| for d in self.ds: |
| yield d |
| |
| dss = ([(1, '1'), (2, '2')], |
| [(1, 1), (2, '2')]) |
| for ds in dss: |
| dp = DP(ds) # type: ignore[var-annotated] |
| self.assertEqual(list(dp), ds) |
| # Reset __iter__ |
| self.assertEqual(list(dp), ds) |
| |
| dss = ([(1, 1), ('2', 2)], # type: ignore[assignment, list-item] |
| [[1, '1'], [2, '2']], # type: ignore[list-item] |
| [1, '1', 2, '2']) |
| for ds in dss: |
| dp = DP(ds) |
| with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"): |
| list(dp) |
| |
| with runtime_validation_disabled(): |
| self.assertEqual(list(dp), ds) |
| with runtime_validation_disabled(): |
| self.assertEqual(list(dp), ds) |
| |
| with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"): |
                list(dp)

| def test_reinforce(self): |
        T = TypeVar('T', int, str)

| class DP(IterDataPipe[T]): |
| def __init__(self, ds): |
| self.ds = ds |
| |
| @runtime_validation |
| def __iter__(self) -> Iterator[T]: |
| for d in self.ds: |
| yield d |
| |
| ds = list(range(10)) |
| # Valid type reinforcement |
| dp = DP(ds).reinforce_type(int) |
        self.assertEqual(dp.type.param, int)
| self.assertEqual(list(dp), ds) |
| |
| # Invalid type |
| with self.assertRaisesRegex(TypeError, r"'expected_type' must be a type"): |
| dp = DP(ds).reinforce_type(1) |
| |
| # Type is not subtype |
| with self.assertRaisesRegex(TypeError, r"Expected 'expected_type' as subtype of"): |
| dp = DP(ds).reinforce_type(float) |
| |
| # Invalid data at runtime |
| dp = DP(ds).reinforce_type(str) |
| with self.assertRaisesRegex(RuntimeError, r"Expected an instance as subtype"): |
| list(dp) |
| |
| # Context Manager to disable the runtime validation |
| with runtime_validation_disabled(): |
            self.assertEqual(list(dp), ds)


| class NumbersDataset(IterDataPipe): |
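    r"""IterDataPipe yielding ``0 .. size - 1``; used by the demux/mux, graph,
    and sharding tests."""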
| def __init__(self, size=10): |
| self.size = size |
| |
| def __iter__(self): |
| for i in range(self.size): |
| yield i |
| |
| |
| class TestGraph(TestCase): |
| @skipIfNoDill |
| def test_simple_traverse(self): |
| numbers_dp = NumbersDataset(size=50) |
| mapped_dp = numbers_dp.map(lambda x: x * 10) |
| graph = torch.utils.data.graph.traverse(mapped_dp) |
        expected: Dict[Any, Any] = {mapped_dp: {numbers_dp: {}}}
| self.assertEqual(expected, graph) |
| |
    # TODO(VitalyFedyunin): This test is incorrect because of the 'buffer' nature
    # of the fake fork implementation; update fork first, then fix this test.
| @skipIfNoDill |
| def test_traverse_forked(self): |
| numbers_dp = NumbersDataset(size=50) |
| dp0, dp1, dp2 = numbers_dp.fork(3) |
| dp0_upd = dp0.map(lambda x: x * 10) |
| dp1_upd = dp1.filter(lambda x: x % 3 == 1) |
| combined_dp = dp0_upd.mux(dp1_upd, dp2) |
| graph = torch.utils.data.graph.traverse(combined_dp) |
| expected = {combined_dp: {dp0_upd: {dp0: {}}, dp1_upd: {dp1: {}}, dp2: {}}} |
| self.assertEqual(expected, graph) |
| |
| |
| class TestSharding(TestCase): |
| def _get_pipeline(self): |
| numbers_dp = NumbersDataset(size=10) |
| dp0, dp1 = numbers_dp.fork(2) |
| dp0_upd = dp0.map(lambda x: x * 10) |
| dp1_upd = dp1.filter(lambda x: x % 3 == 1) |
| combined_dp = dp0_upd.mux(dp1_upd) |
| return combined_dp |
| |
| @skipIfNoDill |
| def test_simple_sharding(self): |
| sharded_dp = self._get_pipeline().sharding_filter() |
| torch.utils.data.sharding.apply_sharding(sharded_dp, 3, 1) |
| items = list(sharded_dp) |
| self.assertEqual([1, 20, 40, 70], items) |
| |
| all_items = list(self._get_pipeline()) |
| items = [] |
| for i in range(3): |
| sharded_dp = self._get_pipeline().sharding_filter() |
| torch.utils.data.sharding.apply_sharding(sharded_dp, 3, i) |
| items += list(sharded_dp) |
| |
| self.assertEqual(sorted(all_items), sorted(items)) |
| |
| @skipIfNoDill |
| def test_old_dataloader(self): |
| dp = self._get_pipeline() |
| expected = list(dp) |
| |
| dp = self._get_pipeline().sharding_filter() |
| dl = DataLoader(dp, batch_size=1, shuffle=False, num_workers=2, |
| worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn) |
| items = [] |
| for i in dl: |
| items.append(i) |
| |
| self.assertEqual(sorted(expected), sorted(items)) |
| |
| if __name__ == '__main__': |
| run_tests() |