# Source blob: e1ca129ec1f75f97518c53e9dd1ea6f08185da0e
import math
import sys
import errno
import os
import ctypes
import signal
import torch
import gc
import time
import traceback
import unittest
import subprocess
import itertools
from torch import multiprocessing as mp
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC, NO_MULTIPROCESSING_SPAWN,
skipIfRocm, load_tests)
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
# We cannot import TEST_CUDA from common_cuda here, because if we do that,
# the TEST_CUDNN line from common_cuda will be executed multiple times
# as well during the execution of this test suite, and it will cause
# CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()
if not NO_MULTIPROCESSING_SPAWN:
    # Rebind `mp` to an explicit multiprocessing context rather than setting
    # the global start method: some tests / third-party libraries set the
    # start method on import, and setting it again triggers RuntimeError.
    mp = mp.get_context(method='spawn')

# Seconds to wait when joining child processes; Windows and PPC get a
# longer allowance than other platforms.
JOIN_TIMEOUT = 17.0 if (IS_WINDOWS or IS_PPC) else 8.5
class TestDatasetRandomSplit(TestCase):
    """Checks for ``random_split``: length validation, sizes and coverage."""

    def test_lengths_must_equal_datset_size(self):
        # Requested lengths sum to 3 while the dataset holds 4 items.
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        first_part = splits[0]
        second_part = splits[1]
        self.assertEqual(len(first_part), 2)
        self.assertEqual(len(second_part), 4)

    def test_splits_are_mutually_exclusive(self):
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        # Gather every element from both parts; together they must be exactly
        # the original items, each appearing once.
        collected = list(splits[0]) + list(splits[1])
        collected.sort()
        data.sort()
        self.assertListEqual(data, collected)
class TestTensorDataset(TestCase):
    """Checks ``TensorDataset`` length and per-index item access."""

    def test_len(self):
        features = torch.randn(15, 10, 2, 3, 4, 5)
        targets = torch.randperm(15)
        self.assertEqual(len(TensorDataset(features, targets)), 15)

    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for row in range(15):
            self.assertEqual(t[row], source[row][0])
            self.assertEqual(l[row], source[row][1])

    def test_getitem_1d(self):
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for row in range(15):
            self.assertEqual(t[row], source[row][0])
            self.assertEqual(l[row], source[row][1])

    def test_single_tensor(self):
        t = torch.randn(5, 10)
        source = TensorDataset(t)
        self.assertEqual(len(source), 5)
        for row in range(5):
            self.assertEqual(t[row], source[row][0])

    def test_many_tensors(self):
        tensors = (
            torch.randn(5, 10, 2, 3, 4, 5),
            torch.randn(5, 10),
            torch.randn(5, 10, 2, 5),
            torch.randn(5, 10, 3, 7),
        )
        source = TensorDataset(*tensors)
        self.assertEqual(len(source), 5)
        # Item `row` of the dataset must expose row `row` of every tensor, in
        # the order the tensors were passed in.
        for row in range(5):
            for position, tensor in enumerate(tensors):
                self.assertEqual(tensor[row], source[row][position])
class TestConcatDataset(TestCase):
    """Checks ``ConcatDataset`` indexing, length and the ``+`` operator."""

    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        for index in (0, 1):
            self.assertEqual(index, result[index])

    def test_concat_two_non_singletons(self):
        lower = [0, 1, 2, 3, 4]
        upper = [5, 6, 7, 8, 9]
        result = ConcatDataset([lower, upper])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        lower = [0, 1, 2, 3, 4]
        upper = [5, 6, 7, 8, 9]
        result = ConcatDataset([lower, [], upper])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_raises_index_error(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_add_dataset(self):
        parts = [
            TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
            for _ in range(3)
        ]
        result = parts[0] + parts[1] + parts[2]
        self.assertEqual(21, len(result))
        # Each 7-item part must start at its own offset in the concatenation.
        for position, part in enumerate(parts):
            offset = 7 * position
            self.assertEqual(0, (part[0][0] - result[offset][0]).abs().sum())
# Stores the first encountered exception in .exception.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):
    """Process whose outcome can be read back through ``.exception``.

    ``run`` sends either ``None`` (success) or an ``ExceptionWrapper`` built
    from the raised exception through a pipe; the ``exception`` property reads
    it on the other end.
    """

    def __init__(self, *args, **kwargs):
        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None

    def run(self):
        # Disable polluting stderr with errors that are supposed to happen.
        sys.stderr = open(os.devnull, "w")
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
        except Exception:
            # Report the failure through the pipe, then let it propagate.
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    @property
    def exception(self):
        """Return the reported exception rebuilt from its type and message, or None."""
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        wrapped = self._exception
        if wrapped is None:
            return None
        return wrapped.exc_type(wrapped.exc_msg)

    # ESRCH means that os.kill can't find a live process with that pid
    def send_signal(self, signum, ignore_ESRCH=False):
        """Send ``signum`` to this process; optionally ignore an ESRCH error."""
        try:
            os.kill(self.pid, signum)
        except OSError as err:
            if ignore_ESRCH and err.errno == errno.ESRCH:
                return
            raise
class ErrorDataset(Dataset):
    """Dataset that defines only its length and no ``__getitem__`` of its own."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size
class SegfaultDataset(Dataset):
    """Dataset whose item access reads a string from address 0 via ctypes."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Reading from a null pointer; used by the segfault test.
        return ctypes.string_at(0)
class SleepDataset(Dataset):
    """Dataset that sleeps ``sleep_sec`` seconds on the first item access only."""

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec
        self.sleeped = False

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if self.sleeped:
            return idx
        # First access on this instance: sleep, then remember that we did.
        time.sleep(self.sleep_sec)
        self.sleeped = True
        return idx
class SeedDataset(Dataset):
    """Dataset whose every item is ``torch.initial_seed()`` of the calling process."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # The index is ignored on purpose; only the RNG seed is reported.
        return torch.initial_seed()
# Inspired by https://stackoverflow.com/a/26703365
# This will ensure that each worker at least processes one data
class SynchronizedSeedDataset(Dataset):
    """Seed-reporting dataset whose item access blocks on a shared barrier.

    A shared counter tracks item accesses; the barrier semaphore is first
    released when the counter reaches ``num_workers``.
    """

    def __init__(self, size, num_workers):
        assert size >= num_workers
        self.count = mp.Value('i', 0, lock=True)
        self.barrier = mp.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        counter = self.count
        with counter.get_lock():
            counter.value += 1
            if counter.value == self.num_workers:
                # Counter reached num_workers: open the barrier.
                self.barrier.release()
        # Pass through the barrier, then re-open it for the next caller.
        self.barrier.acquire()
        self.barrier.release()
        return torch.initial_seed()
def _test_timeout():
    """Fetch one batch from a loader whose dataset sleeps 3s against a timeout of 1."""
    loader = DataLoader(SleepDataset(10, 3), batch_size=2, num_workers=2, timeout=1)
    next(iter(loader))
def _test_timeout_pin_memory():
    """Same as ``_test_timeout`` but with ``pin_memory=True`` on the loader."""
    loader = DataLoader(SleepDataset(10, 3), batch_size=2, num_workers=2,
                        timeout=1, pin_memory=True)
    next(iter(loader))
def disable_stderr(worker_id):
    r"""
    Silence stderr at the file-descriptor level inside a worker.

    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since the worker signal handler prints with low-level write(),
    redirecting ``sys.stderr`` alone is not enough; the fd itself is replaced
    via dup2.

    This is used as worker_init_fn for test_segfault.
    """
    # Flush library buffers that dup2 knows nothing about.
    sys.stderr.flush()
    # No with-block here: it would close the devnull fd when this function
    # returns.
    sink = open(os.devnull, 'w')
    os.dup2(sink.fileno(), sys.stderr.fileno())
def _test_segfault():
    """Fetch one batch from a loader over ``SegfaultDataset`` with worker stderr silenced."""
    loader = DataLoader(SegfaultDataset(10), batch_size=2, num_workers=2,
                        worker_init_fn=disable_stderr)
    next(iter(loader))
class TestProperExitDataset(object):
    """Dataset of ``size`` one-element tensors that fails once ``error_event`` is set."""

    def __init__(self, size, error_event):
        self.size = size
        self.error_event = error_event

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        event = self.error_event
        # A missing event (None) means item access never fails.
        if event is not None and event.is_set():
            raise RuntimeError('Worker error')
        return torch.tensor([idx])
# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
                      worker_pids, setup_event):
    """Iterate a DataLoader and end the process according to ``exit_method``.

    Worker pids are written into ``worker_pids`` and ``setup_event`` is set
    after the first batch has been received. With ``exit_method ==
    'worker_error'`` the same event is handed to the dataset, which raises
    once it is set.
    """
    num_workers = 2 if use_workers else 0

    if exit_method in ('worker_error', 'worker_kill'):
        assert use_workers is True

    error_event = setup_event if exit_method == 'worker_error' else None
    ds = TestProperExitDataset(10, error_event)

    loader = DataLoader(ds, batch_size=2, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)

    # Index of the iteration at which the requested exit is triggered.
    error_it = 4
    assert len(loader) > error_it

    it = iter(loader)
    if use_workers:
        for slot, worker in enumerate(it.workers):
            worker_pids[slot] = worker.pid

    def kill_pid(pid):
        # NOTE(review): `pid` is not used; both branches kill the current
        # process (os.getpid()), so 'worker_kill' also ends this loader
        # process instead of the worker. Confirm whether that is intended.
        if IS_WINDOWS:
            os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
        else:
            os.kill(os.getpid(), signal.SIGKILL)

    for i, _ in enumerate(it):
        if i == 0:
            if not hold_iter_reference:
                del it
            setup_event.set()

        if i == error_it:
            if exit_method == 'main_error':
                raise RuntimeError('Error')
            elif exit_method == 'main_kill':
                kill_pid(os.getpid())
            elif exit_method == 'worker_kill':
                kill_pid(worker_pids[0])

    if not hold_iter_reference:
        # Tries to trigger the __del__ clean-up rather than the automatic
        # exiting of daemonic children. Technically it should be automatically
        # triggered, but I don't want to rely on the implementation detail of
        # Python gc.
        gc.collect()
# test custom init function
def init_fn(worker_id):
    """Worker init hook that seeds torch with 12345, ignoring ``worker_id``."""
    fixed_seed = 12345
    torch.manual_seed(fixed_seed)
class TestDataLoader(TestCase):
    """DataLoader tests over a 100-item ``TensorDataset`` built in ``setUp``."""

    def setUp(self):
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)

    def _test_sequential(self, loader):
        """Assert that ``loader`` yields the dataset in order, batch by batch."""
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size])
        # `i` is the index of the last batch produced.
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        """Assert that ``loader`` yields every sample exactly once with its own label."""
        found_data = {i: 0 for i in range(self.data.size(0))}
        found_labels = {i: 0 for i in range(self.labels.size(0))}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                # Locate the sample in the original data by value.
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        """Assert that every batch of ``loader`` raises ``NotImplementedError``."""
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return

    def test_invalid_assign_after_init(self):
        dl = DataLoader(self.dataset)
        for attr in ('batch_size', 'sampler', 'drop_last'):
            def fn():
                setattr(dl, attr, {})

            self.assertRaises(ValueError, fn)

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))

    def test_growing_dataset(self):
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = DataLoader(dataset, shuffle=False)
        dataloader_shuffle = DataLoader(dataset, shuffle=True)
        # Append after the loaders were built; their length must follow.
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_sequential_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    def test_multiple_dataloaders(self):
        loader1_it = iter(DataLoader(self.dataset, num_workers=1))
        loader2_it = iter(DataLoader(self.dataset, num_workers=2))
        # Interleave fetches from two independent multi-process iterators.
        next(loader1_it)
        next(loader1_it)
        next(loader2_it)
        next(loader2_it)
        next(loader1_it)
        next(loader2_it)

    @unittest.skip("temporarily disable until flaky failures are fixed")
    def test_segfault(self):
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r'access violation reading ')
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
        finally:
            p.terminate()

    @skipIfRocm
    def test_timeout(self):
        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
            targets = (_test_timeout, _test_timeout_pin_memory)
        else:
            targets = (_test_timeout,)
        for target in targets:
            p = ErrorTrackingProcess(target=target)
            p.start()
            p.join(JOIN_TIMEOUT)
            try:
                self.assertFalse(p.is_alive())
                self.assertNotEqual(p.exitcode, 0)
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
            finally:
                p.terminate()

    def test_worker_seed(self):
        num_workers = 6
        dataset = SynchronizedSeedDataset(num_workers, num_workers)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)
        seeds = set()
        for batch in dataloader:
            # Store the Python number, not the 0-dim tensor: tensors hash by
            # identity, so a set of tensors would always hold one entry per
            # batch and the uniqueness check below would be vacuous.
            seeds.add(batch[0].item())
        self.assertEqual(len(seeds), num_workers)

    def test_worker_init_fn(self):
        dataset = SeedDataset(4)
        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
                                worker_init_fn=init_fn)
        for batch in dataloader:
            self.assertEqual(12345, batch[0])
            self.assertEqual(12345, batch[1])

    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    def test_seqential_batch_workers(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))

    def _test_batch_sampler(self, **kwargs):
        """Check a custom ``batch_sampler`` with alternating batch sizes 2 and 3."""
        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
        batches = []
        for i in range(0, 100, 5):
            batches.append(tuple(range(i, i + 2)))
            batches.append(tuple(range(i + 2, i + 5)))

        dl = DataLoader(self.dataset, batch_sampler=batches, **kwargs)
        self.assertEqual(len(dl), 40)
        for i, (input, _target) in enumerate(dl):
            # Even batches hold 2 items, odd batches 3; each pair spans 5.
            offset = i * 5 // 2
            expected_len = 2 if i % 2 == 0 else 3
            self.assertEqual(len(input), expected_len)
            self.assertEqual(input, self.data[offset:offset + expected_len])

    def test_RandomSampler(self):

        from collections import Counter
        from torch.utils.data import RandomSampler

        def sample_stat(sampler, num_samples):
            # Returns (number of indices drawn more than once, min, max).
            counts = Counter(sampler)
            count_repeated = sum(val > 1 for val in counts.values())
            return (count_repeated, min(counts.keys()), max(counts.keys()))

        # test sample with replacement
        n = len(self.dataset) + 1  # ensure at least one sample is drawn more than once
        sampler_with_replacement = RandomSampler(self.dataset, replacement=True, num_samples=n)
        count_repeated, minval, maxval = sample_stat(sampler_with_replacement, n)
        self.assertTrue(count_repeated > 0)
        self.assertTrue(minval >= 0)
        self.assertTrue(maxval < len(self.dataset))

        # test sample without replacement
        sampler_without_replacement = RandomSampler(self.dataset)
        count_repeated, minval, maxval = sample_stat(sampler_without_replacement, len(self.dataset))
        self.assertTrue(count_repeated == 0)
        self.assertTrue(minval == 0)
        self.assertTrue(maxval == len(self.dataset) - 1)

        # raise error when replacement=False and num_samples is not None
        self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=len(self.dataset)))
        self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=0))

    def test_duplicating_data_with_drop_last(self):

        from torch.utils.data.distributed import DistributedSampler

        num_processes = 4
        num_batches = 9
        data_set = torch.IntTensor(range(num_batches))
        scanned_data = torch.IntTensor([])
        for i in range(num_processes):
            s = DistributedSampler(data_set, num_processes, i)
            d_loader = DataLoader(data_set, batch_size=int(num_batches / num_processes), drop_last=True, sampler=s)
            for data in d_loader:
                scanned_data = torch.cat((scanned_data, data), 0)

        # No element may have been scanned twice across all ranks.
        self.assertEqual(scanned_data.size(), scanned_data.unique().size())

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN,
                     "Disabled for environments that "
                     "don't support multiprocessing with spawn start method")
    def test_batch_sampler(self):
        self._test_batch_sampler()
        self._test_batch_sampler(num_workers=4)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))

    def test_error(self):
        self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN,
                     "Disabled for environments that "
                     "don't support multiprocessing with spawn start method")
    def test_error_workers(self):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
    def test_partial_workers(self):
        r"""Check that workers exit even if the iterator is not exhausted."""
        if TEST_CUDA:
            pin_memory_configs = (True, False)
        else:
            pin_memory_configs = (False,)

        for pin_memory in pin_memory_configs:
            loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
            workers = loader.workers
            if pin_memory:
                pin_memory_thread = loader.pin_memory_thread
            for i, sample in enumerate(loader):
                if i == 10:
                    break
            # assertEqual instead of a bare assert, which is stripped under -O.
            self.assertEqual(i, 10)
            # Drop the iterator while batches are still pending.
            del loader
            for w in workers:
                w.join(JOIN_TIMEOUT)
                self.assertFalse(w.is_alive(), 'subprocess not terminated')
            if pin_memory:
                pin_memory_thread.join(JOIN_TIMEOUT)
                self.assertFalse(pin_memory_thread.is_alive())

    @staticmethod
    def _is_process_alive(pid, pname):
        """Return True if the OS process listing for ``pid`` mentions ``pname``."""
        # There is a chance of a terminated child process's pid being reused by a new unrelated process,
        # but since we are looping this check very frequently, we will know that the child process dies
        # before the new unrelated process starts.
        if IS_WINDOWS:
            command = 'tasklist | find "{}" /i'.format(pid)
        else:
            command = 'ps -p {} -o comm='.format(pid)
        p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
        # communicate() also waits for the command to finish.
        output, _ = p.communicate()
        return pname in output.decode('utf-8')

    @skipIfRocm
    def test_proper_exit(self):
        r"""There might be ConnectionResetError or leaked semaphore warning (due to dirty process exit), but they are all safe to ignore"""

        # TODO: test the case where the pin_memory_thread triggers an
        #       error/fatal signal. I haven't found out how to properly do that.

        # Array to store the worker pids.
        worker_pids = mp.Array('i', [-1 for _ in range(10)])

        def wait_pids(pids, timeout):
            r"""Wait for all process specified in pids to exit in given timeout."""
            exit_status = [False for _ in pids]
            start_time = time.time()
            pname = 'python'
            while True:
                for i, pid in enumerate(pids):
                    if not exit_status[i]:
                        if not TestDataLoader._is_process_alive(pid, pname):
                            exit_status[i] = True
                if all(exit_status):
                    break
                if time.time() - start_time > timeout:
                    break
                time.sleep(0.5)
            return exit_status

        for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
            # `hold_iter_reference` specifies whether we hold a reference to the
            # iterator. This is interesting because Python3 error traces holds a
            # reference to the frames, which hold references to all the local
            # variables including the iterator, and then the iterator dtor may
            # not be called before process end. It is important to see that the
            # processes still exit in both cases.

            if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN):
                # Can't use CUDA without spawn
                continue

            # `exit_method` controls the way the loader process ends.
            #   - `*_kill` means that `*` is killed by OS.
            #   - `*_error` means that `*` raises an error.
            #   - `None` means that no error happens.
            # In all cases, all processes should end properly.
            if use_workers:
                exit_methods = [None, 'main_error', 'main_kill', 'worker_kill', 'worker_error']
            else:
                exit_methods = [None, 'main_error', 'main_kill']

            for exit_method in exit_methods:

                # clear pids array first
                for i in range(len(worker_pids)):
                    worker_pids[i] = -1

                # Event that the loader process uses to signal testing process
                # that various things are setup, including that the worker pids
                # are specified in `worker_pids` array.
                setup_event = mp.Event()

                p = ErrorTrackingProcess(target=_test_proper_exit,
                                         args=(use_workers, pin_memory, exit_method,
                                               hold_iter_reference, worker_pids, setup_event))
                p.start()

                # Everything after start() runs under try/finally so that the
                # loader process is terminated even when setup times out.
                try:
                    # Wait for loader process to set everything up, i.e., filling
                    # worker pids in `worker_pids`.
                    setup_event.wait(timeout=JOIN_TIMEOUT)
                    self.assertTrue(setup_event.is_set(), 'loader process setup timed out')

                    pids = [pid for pid in worker_pids if pid > 0]

                    exit_status = wait_pids(pids, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
                    if not all(exit_status):
                        # pids are ints; str.join needs strings.
                        still_alive = ', '.join(
                            str(pid) for pid, exited in zip(pids, exit_status) if not exited)
                        self.fail('subprocess (pid(s) {}) not terminated'.format(still_alive))
                    p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
                    self.assertFalse(p.is_alive(), 'loader process not terminated')
                    if exit_method is None:
                        self.assertEqual(p.exitcode, 0)
                    else:
                        self.assertNotEqual(p.exitcode, 0)
                finally:
                    p.terminate()

    def test_len(self):
        def check_len(dl, expected):
            # len() and the number of items actually yielded must agree.
            self.assertEqual(len(dl), expected)
            n = 0
            for sample in dl:
                n += 1
            self.assertEqual(n, expected)

        check_len(self.dataset, 100)
        check_len(DataLoader(self.dataset, batch_size=2), 50)
        check_len(DataLoader(self.dataset, batch_size=3), 34)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        # numpy scalar type -> tensor type expected after collation
        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        default_collate(arr)

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: default_collate(arr))
class StringDataset(Dataset):
    """Dataset over the characters of '12345'; items are ``(char, index)`` pairs."""

    def __init__(self):
        self.s = '12345'

    def __getitem__(self, ndx):
        char = self.s[ndx]
        return (char, ndx)

    def __len__(self):
        return len(self.s)
class TestStringDataLoader(TestCase):
    """DataLoader checks on a dataset whose items mix a string and an index."""

    def setUp(self):
        self.dataset = StringDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for s, n in loader:
            # Strings must come through as str; the index tensor is pinned.
            self.assertIsInstance(s[0], str)
            self.assertTrue(n.is_pinned())
class DictDataset(Dataset):
    """Four-item dataset whose items are dicts holding a tensor and a nested dict."""

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        # A (4, 2) tensor filled with the index, plus the index itself nested
        # one level down.
        filled = torch.Tensor(4, 2).fill_(ndx)
        nested = {'a_number': ndx}
        return {'a_tensor': filled, 'another_dict': nested}
class TestDictDataLoader(TestCase):
    """DataLoader checks on a dataset whose items are (nested) dicts."""

    def setUp(self):
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
        batch_size = loader.batch_size
        for i, sample in enumerate(loader):
            idx = i * batch_size
            # The dict structure must survive collation at both levels.
            self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
            self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})

            tensor_batch = sample['a_tensor']
            self.assertEqual(tensor_batch.size(), torch.Size([batch_size, 4, 2]))
            self.assertTrue((tensor_batch[0] == idx).all())
            self.assertTrue((tensor_batch[1] == idx + 1).all())

            number_batch = sample['another_dict']['a_number']
            self.assertEqual(number_batch.size(), torch.Size([batch_size]))
            self.assertEqual(number_batch[0], idx)
            self.assertEqual(number_batch[1], idx + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for sample in loader:
            self.assertTrue(sample['a_tensor'].is_pinned())
            self.assertTrue(sample['another_dict']['a_number'].is_pinned())
class TestWorkerQueueDataset(Dataset):
    """Dataset whose items are tagged with the worker id recorded at init time."""

    def __init__(self, data):
        self.data = data
        self.worker_id = None

    def __len__(self):
        return len(self.data)

    def worker_init_fn(self, worker_id):
        # Remember the worker id so later item accesses can report it.
        self.worker_id = worker_id

    def __getitem__(self, item):
        value = self.data[item]
        return self.worker_id, value
class TestIndividualWorkerQueue(TestCase):
    """Checks that consecutive batches come from workers in round-robin order."""

    def setUp(self):
        self.dataset = TestWorkerQueueDataset(list(range(128)))

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        loader = DataLoader(
            self.dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
            worker_init_fn=self.dataset.worker_init_fn
        )
        expected_worker = 0
        for i, (worker_ids, sample) in enumerate(loader):
            first = i * batch_size
            # Every item of a batch must carry the same, expected worker id.
            self.assertEqual(worker_ids.tolist(), [expected_worker] * batch_size)
            self.assertEqual(sample.tolist(), list(range(first, first + batch_size)))
            # Advance to the next worker, wrapping around after the last one.
            expected_worker = (expected_worker + 1) % num_workers

    def test_ind_worker_queue(self):
        for batch_size, num_workers in itertools.product((8, 16, 32, 64), range(1, 6)):
            self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers)
# Script entry point: hand control to the shared test runner.
if __name__ == '__main__':
    run_tests()