blob: 23bc1ba2ec0bd825220971d63343d603ae46dd5c [file] [log] [blame]
import math
import sys
import errno
import os
import ctypes
import signal
import torch
import time
import traceback
import unittest
import subprocess
from torch import multiprocessing
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import default_collate, ExceptionWrapper, MANAGER_STATUS_CHECK_INTERVAL
from common import TestCase, run_tests, TEST_NUMPY, IS_WINDOWS
# TEST_CUDA is computed here rather than imported from common_nn: importing
# it would execute common_nn's TEST_CUDNN line repeatedly while this suite
# runs, which causes CUDA OOM errors on Windows.
TEST_CUDA = torch.cuda.is_available()

# test_manager_unclean_exit needs the 'spawn' start method, which Python 2.7
# does not allow, so it is only requested on Python 3.
if sys.version_info[0] == 3:
    try:
        multiprocessing.set_start_method('spawn')
    except RuntimeError:
        # Without this guard some tests complain that the context has
        # already been set.
        pass

# Timeout (seconds) passed to join() calls in the tests below.
if IS_WINDOWS:
    JOIN_TIMEOUT = 17.0
else:
    JOIN_TIMEOUT = 6.5
class TestDatasetRandomSplit(TestCase):
    """Checks on the subsets produced by ``random_split``."""

    def test_lengths_must_equal_datset_size(self):
        # Requested lengths add up to 3 while the dataset holds 4 items.
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        parts = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(parts), 2)
        for position, expected_len in enumerate((2, 4)):
            self.assertEqual(len(parts[position]), expected_len)

    def test_splits_are_mutually_exclusive(self):
        # Every original value must appear exactly once across both splits.
        data = [5, 2, 3, 4, 1, 6]
        parts = random_split(data, [2, 4])
        collected = list(parts[0]) + list(parts[1])
        self.assertListEqual(sorted(data), sorted(collected))
class TestTensorDataset(TestCase):
    """Checks length and indexing of ``TensorDataset``."""

    def test_len(self):
        dataset = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(dataset), 15)

    def test_getitem(self):
        features = torch.randn(15, 10, 2, 3, 4, 5)
        targets = torch.randn(15, 10)
        dataset = TensorDataset(features, targets)
        for row in range(15):
            item = dataset[row]
            self.assertEqual(features[row], item[0])
            self.assertEqual(targets[row], item[1])

    def test_getitem_1d(self):
        features = torch.randn(15)
        targets = torch.randn(15)
        dataset = TensorDataset(features, targets)
        for row in range(15):
            item = dataset[row]
            self.assertEqual(features[row], item[0])
            self.assertEqual(targets[row], item[1])

    def test_single_tensor(self):
        only = torch.randn(5, 10)
        dataset = TensorDataset(only)
        self.assertEqual(len(dataset), 5)
        for row in range(5):
            self.assertEqual(only[row], dataset[row][0])

    def test_many_tensors(self):
        # Tensors are created in the same order as they are passed in.
        tensors = (
            torch.randn(5, 10, 2, 3, 4, 5),
            torch.randn(5, 10),
            torch.randn(5, 10, 2, 5),
            torch.randn(5, 10, 3, 7),
        )
        dataset = TensorDataset(*tensors)
        self.assertEqual(len(dataset), 5)
        for row in range(5):
            item = dataset[row]
            for column, tensor in enumerate(tensors):
                self.assertEqual(tensor[row], item[column])
class TestConcatDataset(TestCase):
    """Checks ``ConcatDataset`` and dataset addition."""

    def test_concat_two_singletons(self):
        combined = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(combined))
        self.assertEqual(0, combined[0])
        self.assertEqual(1, combined[1])

    def test_concat_two_non_singletons(self):
        first_half = [0, 1, 2, 3, 4]
        second_half = [5, 6, 7, 8, 9]
        combined = ConcatDataset([first_half, second_half])
        self.assertEqual(10, len(combined))
        self.assertEqual(0, combined[0])
        self.assertEqual(5, combined[5])

    def test_concat_two_non_singletons_with_empty(self):
        # An empty dataset in the middle must not disturb the indexing.
        first_half = [0, 1, 2, 3, 4]
        second_half = [5, 6, 7, 8, 9]
        combined = ConcatDataset([first_half, [], second_half])
        self.assertEqual(10, len(combined))
        self.assertEqual(0, combined[0])
        self.assertEqual(5, combined[5])

    def test_concat_raises_index_error(self):
        combined = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # Only 10 items exist, so index 11 is out of range.
            combined[11]

    def test_add_dataset(self):
        parts = [
            TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
            for _ in range(3)
        ]
        combined = parts[0] + parts[1] + parts[2]
        self.assertEqual(21, len(combined))
        # The first sample of each part sits at offsets 0, 7 and 14.
        for position, part in enumerate(parts):
            difference = part[0][0] - combined[7 * position][0]
            self.assertEqual(0, difference.abs().sum())
# Process subclass that keeps the first exception raised by its target and
# exposes it through ``.exception``.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(multiprocessing.Process):

    def __init__(self, *args, **kwargs):
        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
        # The child writes to _cconn; the parent reads from _pconn.
        self._pconn, self._cconn = multiprocessing.Pipe()
        self._exception = None

    def run(self):
        # Silence stderr at the OS level so that workers do not print to it.
        # sys.stderr.close() cannot be used here: a later Python `raise`
        # would fail with "ValueError: I/O operation on closed file".
        os.close(sys.stderr.fileno())
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    @property
    def exception(self):
        # Pick up whatever the child reported, if anything is pending.
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        wrapped = self._exception
        if wrapped is not None:
            return wrapped.exc_type(wrapped.exc_msg)
        return None

    # ESRCH means that os.kill could not find a live process.
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as err:
            if ignore_ESRCH and err.errno == errno.ESRCH:
                return
            raise
class ErrorDataset(Dataset):
    """Dataset that only reports a length; ``__getitem__`` is not overridden."""

    def __len__(self):
        return self.size

    def __init__(self, size):
        self.size = size
class SegfaultDataset(Dataset):
    """Dataset whose item access reads a C string from address 0."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Reading from a null pointer on purpose.
        return ctypes.string_at(0)
class SleepDataset(Dataset):
    """Dataset that sleeps ``sleep_sec`` seconds before returning the index."""

    def __init__(self, size, sleep_sec):
        self.size, self.sleep_sec = size, sleep_sec

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        time.sleep(self.sleep_sec)
        return idx
class SeedDataset(Dataset):
    """Dataset whose every item is the current ``torch.initial_seed()``."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        seed = torch.initial_seed()
        return seed
# Inspired by https://stackoverflow.com/a/26703365
# The barrier below ensures that each worker processes at least one item.
class SynchronizedSeedDataset(Dataset):

    def __init__(self, size, num_workers):
        assert size >= num_workers
        self.size = size
        self.num_workers = num_workers
        self.count = multiprocessing.Value('i', 0, lock=True)
        self.barrier = multiprocessing.Semaphore(0)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        counter = self.count
        with counter.get_lock():
            counter.value += 1
            # The call that brings the count to num_workers opens the barrier.
            if counter.value == self.num_workers:
                self.barrier.release()
        # Pass through the barrier and immediately reopen it for the next
        # caller.
        self.barrier.acquire()
        self.barrier.release()
        return torch.initial_seed()
def _test_timeout():
    """Fetch one batch from a slow dataset using a 1-second loader timeout."""
    # Each item sleeps for 10 seconds, far longer than the loader's timeout.
    loader = DataLoader(SleepDataset(10, 10), batch_size=2, num_workers=2, timeout=1)
    next(iter(loader))
def _test_segfault():
    """Fetch one batch from ``SegfaultDataset`` through worker processes."""
    loader = DataLoader(SegfaultDataset(10), batch_size=2, num_workers=2)
    next(iter(loader))
# Custom worker init function used by the tests.
def init_fn(worker_id):
    """Seed torch with the fixed value 12345; ``worker_id`` is ignored."""
    fixed_seed = 12345
    torch.manual_seed(fixed_seed)
class TestDataLoader(TestCase):
    """Tests for ``DataLoader`` over a 100-sample ``TensorDataset``.

    ``setUp`` builds ``self.data`` (100 x 2 x 3 x 5), ``self.labels`` and the
    ``self.dataset`` wrapping both; the ``_test_*`` helpers are shared by the
    single-process and multi-worker variants of each test.
    """

    def setUp(self):
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)

    def _test_sequential(self, loader):
        # Batches must come back in dataset order, and the last batch index
        # must match the expected number of batches.
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size])
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        # Every dataset row must be produced exactly once, paired with its
        # own label.
        found_data = {i: 0 for i in range(self.data.size(0))}
        found_labels = {i: 0 for i in range(self.labels.size(0))}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                else:
                    # Without this, the code below would silently use the
                    # last index and report a misleading label mismatch.
                    self.fail('loader produced a sample that is not in the dataset')
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        # Every batch is expected to raise NotImplementedError; count them
        # until the iterator is exhausted.
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return

    def test_invalid_assign_after_init(self):
        dl = DataLoader(self.dataset)
        for attr in ('batch_size', 'sampler', 'drop_last'):
            # fn is invoked right away by assertRaises, so it sees the
            # current value of attr.
            def fn():
                setattr(dl, attr, {})

            self.assertRaises(ValueError, fn)

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))

    def test_growing_dataset(self):
        # The loaders must report the length of the dataset as it is now,
        # not as it was when they were constructed.
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = DataLoader(dataset, shuffle=False)
        dataloader_shuffle = DataLoader(dataset, shuffle=True)
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_sequential_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_multiple_dataloaders(self):
        # Interleave reads from two independent multi-worker iterators.
        loader1_it = iter(DataLoader(self.dataset, num_workers=1))
        loader2_it = iter(DataLoader(self.dataset, num_workers=2))
        next(loader1_it)
        next(loader1_it)
        next(loader2_it)
        next(loader2_it)
        next(loader1_it)
        next(loader2_it)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skip("temporarily disable until flaky failures are fixed")
    def test_segfault(self):
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r'access violation reading ')
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
        finally:
            p.terminate()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_timeout(self):
        p = ErrorTrackingProcess(target=_test_timeout)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            self.assertIsInstance(p.exception, RuntimeError)
            self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
        finally:
            p.terminate()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_worker_seed(self):
        # Each of the workers must report a distinct seed.
        num_workers = 6
        dataset = SynchronizedSeedDataset(num_workers, num_workers)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)
        seeds = set()
        for batch in dataloader:
            seeds.add(batch[0])
        self.assertEqual(len(seeds), num_workers)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_worker_init_fn(self):
        # init_fn seeds every worker with 12345, so every item must be 12345.
        dataset = SeedDataset(4)
        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
                                worker_init_fn=init_fn)
        for batch in dataloader:
            self.assertEqual(12345, batch[0])
            self.assertEqual(12345, batch[1])

    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_seqential_batch_workers(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))

    def _test_batch_sampler(self, **kwargs):
        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
        batches = []
        for i in range(0, 100, 5):
            batches.append(tuple(range(i, i + 2)))
            batches.append(tuple(range(i + 2, i + 5)))

        dl = DataLoader(self.dataset, batch_sampler=batches, **kwargs)
        self.assertEqual(len(dl), 40)
        for i, (input, _target) in enumerate(dl):
            # Even batches hold 2 samples and odd batches hold 3; the start
            # offset follows the same formula for both.
            offset = i * 5 // 2
            expected_len = 2 if i % 2 == 0 else 3
            self.assertEqual(len(input), expected_len)
            self.assertEqual(input, self.data[offset:offset + expected_len])

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_batch_sampler(self):
        self._test_batch_sampler()
        self._test_batch_sampler(num_workers=4)

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))

    def test_error(self):
        self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_error_workers(self):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_partial_workers(self):
        "check that workers exit even if the iterator is not exhausted"
        loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=True))
        workers = loader.workers
        worker_manager_thread = loader.worker_manager_thread
        for i, sample in enumerate(loader):
            if i == 3:
                break
        # Drop the iterator before checking that its workers have exited.
        del loader
        for w in workers:
            w.join(JOIN_TIMEOUT)
            self.assertFalse(w.is_alive(), 'subprocess not terminated')
            self.assertEqual(w.exitcode, 0)
        worker_manager_thread.join(JOIN_TIMEOUT)
        self.assertFalse(worker_manager_thread.is_alive())

    @staticmethod
    def _manager_process(dataset, worker_pids, manager_exit_event):
        # Runs in a child process: starts a multi-worker loader, publishes
        # the worker pids through worker_pids, reads a few batches, then
        # kills itself without any cleanup.
        loader = iter(DataLoader(dataset, batch_size=2, num_workers=4, pin_memory=True))
        for slot, worker in enumerate(loader.workers):
            worker_pids[slot] = int(worker.pid)

        for i, sample in enumerate(loader):
            if i == 3:
                break

        # Simulate a dirty exit of the manager process
        manager_exit_event.set()
        if IS_WINDOWS:
            os.system('taskkill /PID ' + str(os.getpid()) + ' /F')
        else:
            os.kill(os.getpid(), signal.SIGKILL)

    @staticmethod
    def _is_process_alive(pid, pname):
        # Returns True when pname appears in the output of the platform's
        # process-listing command for pid.
        #
        # There is a chance of a terminated child process's pid being reused by a new unrelated process,
        # but since we are looping this check very frequently, we will know that the child process dies
        # before the new unrelated process starts.
        if IS_WINDOWS:
            command = 'tasklist | find "{}" /i'.format(pid)
        else:
            command = 'ps -p {} -o comm='.format(pid)
        p = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
        # communicate() waits for the child to exit, so no extra wait() is
        # needed; stderr is not captured and comes back as None.
        output, _ = p.communicate()
        return pname in output.decode('utf-8')

    @unittest.skipIf(sys.version_info[0] == 2,
                     "spawn start method is not supported in Python 2, "
                     "but we need it for creating another process with CUDA")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_manager_unclean_exit(self):
        '''there might be ConnectionResetError or leaked semaphore warning (due to dirty process exit), but they are all safe to ignore'''
        worker_pids = multiprocessing.Array('i', [0] * 4)

        manager_exit_event = multiprocessing.Event()
        mp = multiprocessing.Process(target=TestDataLoader._manager_process,
                                     args=(self.dataset, worker_pids, manager_exit_event))
        mp.start()
        try:
            manager_exit_event.wait()

            exit_status = [False] * len(worker_pids)
            start_time = time.time()
            pname = 'python'
            while True:
                for i in range(len(worker_pids)):
                    pid = worker_pids[i]
                    if not exit_status[i]:
                        if not TestDataLoader._is_process_alive(pid, pname):
                            exit_status[i] = True
                if all(exit_status):
                    break
                else:
                    time.sleep(1)
                    self.assertFalse(time.time() - start_time > MANAGER_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT,
                                     'subprocess not terminated')
        finally:
            # Reap the killed manager process even when an assertion above
            # fails, so it is not left behind unjoined.
            mp.join(JOIN_TIMEOUT)

    def test_len(self):
        def check_len(dl, expected):
            # len() and the number of items actually yielded must agree.
            self.assertEqual(len(dl), expected)
            n = 0
            for sample in dl:
                n += 1
            self.assertEqual(n, expected)

        check_len(self.dataset, 100)
        check_len(DataLoader(self.dataset, batch_size=2), 50)
        check_len(DataLoader(self.dataset, batch_size=3), 34)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        # numpy scalar type -> tensor type expected for the collated batch
        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_colate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        default_collate(arr)

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: default_collate(arr))
class StringDataset(Dataset):
    """Dataset over the characters of ``'12345'``; items are ``(char, index)``."""

    def __init__(self):
        self.s = '12345'

    def __getitem__(self, ndx):
        character = self.s[ndx]
        return (character, ndx)

    def __len__(self):
        return len(self.s)
class TestStringDataLoader(TestCase):
    """Checks batching of a dataset that mixes strings and integers."""

    def setUp(self):
        self.dataset = StringDataset()

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        # Strings must stay strings, while the numeric part must be pinned.
        for strings, numbers in loader:
            self.assertIsInstance(strings[0], str)
            self.assertTrue(numbers.is_pinned())
class DictDataset(Dataset):
    """Four-item dataset whose items are nested dictionaries."""

    def __getitem__(self, ndx):
        filled = torch.Tensor(4, 2).fill_(ndx)
        nested = {'a_number': ndx}
        return {'a_tensor': filled, 'another_dict': nested}

    def __len__(self):
        return 4
class TestDictDataLoader(TestCase):
    """Checks collation and pinning of dictionary-valued samples."""

    def setUp(self):
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
        batch_size = loader.batch_size
        for batch_index, batch in enumerate(loader):
            first = batch_index * batch_size
            self.assertEqual(set(batch.keys()), {'a_tensor', 'another_dict'})
            self.assertEqual(set(batch['another_dict'].keys()), {'a_number'})

            # Tensor entries are stacked along a new leading batch dimension.
            tensor = batch['a_tensor']
            self.assertEqual(tensor.size(), torch.Size([batch_size, 4, 2]))
            self.assertTrue((tensor[0] == first).all())
            self.assertTrue((tensor[1] == first + 1).all())

            # Plain numbers inside the nested dict are collated as well.
            numbers = batch['another_dict']['a_number']
            self.assertEqual(numbers.size(), torch.Size([batch_size]))
            self.assertEqual(numbers[0], first)
            self.assertEqual(numbers[1], first + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for batch in loader:
            self.assertTrue(batch['a_tensor'].is_pinned())
            self.assertTrue(batch['another_dict']['a_number'].is_pinned())
class TestWorkerQueueDataset(Dataset):
    """Dataset that tags each item with the worker id stored by ``worker_init_fn``."""

    def __init__(self, data):
        self.worker_id = None
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        value = self.data[item]
        return (self.worker_id, value)

    def worker_init_fn(self, worker_id):
        # Remember the id so that __getitem__ can report it.
        self.worker_id = worker_id
class TestIndividualWorkerQueue(TestCase):
    """Checks that consecutive batches come from the workers in round-robin order."""

    def setUp(self):
        self.dataset = TestWorkerQueueDataset(list(range(128)))

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        loader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            worker_init_fn=self.dataset.worker_init_fn,
        )
        for batch_index, (worker_ids, sample) in enumerate(loader):
            # Batch k must come from worker k modulo num_workers.
            expected_worker = batch_index % num_workers
            start = batch_index * batch_size
            self.assertEqual(worker_ids.tolist(), [expected_worker] * batch_size)
            self.assertEqual(sample.tolist(), list(range(start, start + batch_size)))

    @unittest.skipIf(IS_WINDOWS, "FIXME: Intermittent CUDA out-of-memory error")
    def test_ind_worker_queue(self):
        for batch_size in (8, 16, 32, 64):
            for num_workers in range(1, 6):
                self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers)
# Script entry point: hand control to run_tests (imported from `common`).
if __name__ == '__main__':
    run_tests()