import math
import sys
import errno
import os
import ctypes
import signal
import torch
import time
import traceback
import unittest
from torch import multiprocessing
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate, ExceptionWrapper
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
from common_nn import TEST_CUDA


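# Timeout used when joining worker processes/threads in these tests. Windows
# gets a longer budget, presumably because process spawn/teardown is slower there.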
JOIN_TIMEOUT = 17.0 if IS_WINDOWS else 4.5


class TestDatasetRandomSplit(TestCase):
    def test_lengths_must_equal_dataset_size(self):
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 2)
        self.assertEqual(len(splits[1]), 4)

    def test_splits_are_mutually_exclusive(self):
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)


class TestTensorDataset(TestCase):

    def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)

    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])

    def test_getitem_1d(self):
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])


class TestConcatDataset(TestCase):

    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

    def test_concat_two_non_singletons(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_raises_index_error(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_add_dataset(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        result = d1 + d2 + d3
        self.assertEqual(21, len(result))
        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())


# Process that stores the first exception encountered while running its
# target in .exception, so the parent can inspect worker failures.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(multiprocessing.Process):

    def __init__(self, *args, **kwargs):
        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
        self._pconn, self._cconn = multiprocessing.Pipe()
        self._exception = None

    def run(self):
        # Disable stderr printing at the OS level so that workers don't print
        # to stderr. We can't use sys.stderr.close(), otherwise a Python
        # `raise` would fail with "ValueError: I/O operation on closed file".
        os.close(sys.stderr.fileno())
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    @property
    def exception(self):
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means os.kill found no live process with the given pid, i.e., it
    # has already exited; pass ignore_ESRCH=True to tolerate that case.
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise


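# Dataset that deliberately leaves __getitem__ unimplemented, so indexing
# falls through to the Dataset base class and raises NotImplementedError
# (which is exactly what _test_error below counts).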
class ErrorDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size


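# Dataset whose __getitem__ reads the string at address 0 via ctypes,
# reliably segfaulting the worker process that touches it.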
class SegfaultDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return ctypes.string_at(0)

    def __len__(self):
        return self.size


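# Dataset whose __getitem__ sleeps for sleep_sec seconds, used to make the
# DataLoader hit its timeout.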
class SleepDataset(Dataset):

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec

    def __getitem__(self, idx):
        time.sleep(self.sleep_sec)
        return idx

    def __len__(self):
        return self.size


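# Dataset that returns the current torch seed, letting tests observe
# per-worker seeding behavior.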
class SeedDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return torch.initial_seed()

    def __len__(self):
        return self.size


# Inspired by https://stackoverflow.com/a/26703365
# Uses a shared counter and a semaphore as a barrier, ensuring that each
# worker processes at least one sample before any of them may return.
class SynchronizedSeedDataset(Dataset):

    def __init__(self, size, num_workers):
        assert size >= num_workers
        self.count = multiprocessing.Value('i', 0, lock=True)
        self.barrier = multiprocessing.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def __getitem__(self, idx):
        with self.count.get_lock():
            self.count.value += 1
            if self.count.value == self.num_workers:
                self.barrier.release()
        self.barrier.acquire()
        self.barrier.release()
        return torch.initial_seed()

    def __len__(self):
        return self.size


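# The helpers below run inside an ErrorTrackingProcess so that a hang or a
# crash is reported back to the parent test instead of taking it down.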
def _test_timeout():
    dataset = SleepDataset(10, 10)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
    _ = next(iter(dataloader))


def _test_segfault():
    dataset = SegfaultDataset(10)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
    _ = next(iter(dataloader))


# Custom worker init function: every worker sets the same manual seed, which
# test_worker_init_fn then checks for.
def init_fn(worker_id):
    torch.manual_seed(12345)


class TestDataLoader(TestCase):

    def setUp(self):
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)

    def _test_sequential(self, loader):
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size])
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        found_data = {i: 0 for i in range(self.data.size(0))}
        found_labels = {i: 0 for i in range(self.labels.size(0))}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))

    def test_growing_dataset(self):
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = DataLoader(dataset, shuffle=False)
        dataloader_shuffle = DataLoader(dataset, shuffle=True)
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_sequential_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_multiple_dataloaders(self):
        loader1_it = iter(DataLoader(self.dataset, num_workers=1))
        loader2_it = iter(DataLoader(self.dataset, num_workers=2))
        next(loader1_it)
        next(loader1_it)
        next(loader2_it)
        next(loader2_it)
        next(loader1_it)
        next(loader2_it)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skip("temporarily disable until flaky failures are fixed")
    def test_segfault(self):
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r'access violation reading ')
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
        finally:
            p.terminate()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_timeout(self):
        p = ErrorTrackingProcess(target=_test_timeout)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            self.assertIsInstance(p.exception, RuntimeError)
            self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
        finally:
            p.terminate()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_worker_seed(self):
        num_workers = 6
        dataset = SynchronizedSeedDataset(num_workers, num_workers)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)
        seeds = set()
        for batch in dataloader:
            seeds.add(batch[0])
        self.assertEqual(len(seeds), num_workers)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_worker_init_fn(self):
        dataset = SeedDataset(4)
        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
                                worker_init_fn=init_fn)
        for batch in dataloader:
            self.assertEqual(12345, batch[0])
            self.assertEqual(12345, batch[1])

    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_sequential_batch_workers(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))

    def _test_batch_sampler(self, **kwargs):
        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
        batches = []
        for i in range(0, 100, 5):
            batches.append(tuple(range(i, i + 2)))
            batches.append(tuple(range(i + 2, i + 5)))

        dl = DataLoader(self.dataset, batch_sampler=batches, **kwargs)
        self.assertEqual(len(dl), 40)
        for i, (input, _target) in enumerate(dl):
            if i % 2 == 0:
                offset = i * 5 // 2
                self.assertEqual(len(input), 2)
                self.assertEqual(input, self.data[offset:offset + 2])
            else:
                offset = i * 5 // 2
                self.assertEqual(len(input), 3)
                self.assertEqual(input, self.data[offset:offset + 3])

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_batch_sampler(self):
        self._test_batch_sampler()
        self._test_batch_sampler(num_workers=4)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))

    def test_error(self):
        self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_error_workers(self):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_partial_workers(self):
        "check that workers exit even if the iterator is not exhausted"
        loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=True))
        workers = loader.workers
        worker_manager_thread = loader.worker_manager_thread
        for i, sample in enumerate(loader):
            if i == 3:
                break
        del loader
        for w in workers:
            w.join(JOIN_TIMEOUT)
            self.assertFalse(w.is_alive(), 'subprocess not terminated')
            self.assertEqual(w.exitcode, 0)
        worker_manager_thread.join(JOIN_TIMEOUT)
        self.assertFalse(worker_manager_thread.is_alive())

    def test_len(self):
        def check_len(dl, expected):
            self.assertEqual(len(dl), expected)
            n = 0
            for sample in dl:
                n += 1
            self.assertEqual(n, expected)
        check_len(self.dataset, 100)
        check_len(DataLoader(self.dataset, batch_size=2), 50)
        check_len(DataLoader(self.dataset, batch_size=3), 34)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        default_collate(arr)

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: default_collate(arr))


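# Dataset of (character, index) pairs: default_collate leaves the strings in a
# sequence while batching the integer indices into a tensor.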
class StringDataset(Dataset):
    def __init__(self):
        self.s = '12345'

    def __len__(self):
        return len(self.s)

    def __getitem__(self, ndx):
        return (self.s[ndx], ndx)


class TestStringDataLoader(TestCase):
    def setUp(self):
        self.dataset = StringDataset()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for batch_ndx, (s, n) in enumerate(loader):
            self.assertIsInstance(s[0], str)
            self.assertTrue(n.is_pinned())


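# Dataset returning nested dicts, to exercise default_collate's recursive
# handling of mappings (and pin_memory's recursive pinning).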
class DictDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return {
            'a_tensor': torch.Tensor(4, 2).fill_(ndx),
            'another_dict': {
                'a_number': ndx,
            },
        }


class TestDictDataLoader(TestCase):
    def setUp(self):
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
        batch_size = loader.batch_size
        for i, sample in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
            self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})

            t = sample['a_tensor']
            self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
            self.assertTrue((t[0] == idx).all())
            self.assertTrue((t[1] == idx + 1).all())

            n = sample['another_dict']['a_number']
            self.assertEqual(n.size(), torch.Size([batch_size]))
            self.assertEqual(n[0], idx)
            self.assertEqual(n[1], idx + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for batch_ndx, sample in enumerate(loader):
            self.assertTrue(sample['a_tensor'].is_pinned())
            self.assertTrue(sample['another_dict']['a_number'].is_pinned())


if __name__ == '__main__':
    run_tests()