|  | # Owner(s): ["module: dataloader"] | 
|  |  | 
|  | import math | 
|  | import sys | 
|  | import errno | 
|  | import os | 
|  | import ctypes | 
|  | import faulthandler | 
|  | import torch | 
|  | import gc | 
|  | import time | 
|  | import signal | 
|  | import unittest | 
|  | import itertools | 
|  | import warnings | 
|  | import tempfile | 
|  | import torch.utils.data.datapipes as dp | 
|  | from torch import multiprocessing as mp | 
|  | from torch.utils.data import ( | 
|  | ChainDataset, | 
|  | ConcatDataset, | 
|  | DataLoader, | 
|  | Dataset, | 
|  | IterableDataset, | 
|  | IterDataPipe, | 
|  | Subset, | 
|  | TensorDataset, | 
|  | StackDataset, | 
|  | _utils | 
|  | ) | 
|  | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL | 
|  | from torch.utils.data.dataset import random_split | 
|  | from torch.utils.data.datapipes.iter import IterableWrapper | 
|  | from torch._utils import ExceptionWrapper | 
|  | from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_JETSON, | 
|  | IS_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, | 
|  | load_tests, TEST_WITH_ASAN, TEST_WITH_TSAN, IS_SANDCASTLE, | 
|  | IS_MACOS, TEST_CUDA) | 
|  |  | 
|  |  | 
# psutil is used by tests that inspect or kill worker processes (e.g. the
# `kill_pid` helper inside `_test_proper_exit`). In CI its absence is a hard
# error; elsewhere we only warn, and those tests will not run.
try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False
    err_msg = ("psutil not found. Some critical data loader tests relying on it "
               "(e.g., TestDataLoader.test_proper_exit) will not run.")
    if IS_CI:
        raise ImportError(err_msg) from None
    else:
        warnings.warn(err_msg)

# dill is optional; tests that need it are gated with `skipIfNoDill`.
try:
    import dill
    # XXX: By default, dill writes the Pickler dispatch table to inject its
    # own logic there. This globally affects the behavior of the standard library
    # pickler for any user who transitively depends on this module!
    # Undo this extension to avoid altering the behavior of the pickler globally.
    dill.extend(use_dill=False)
    HAS_DILL = True
except ImportError:
    HAS_DILL = False
skipIfNoDill = unittest.skipIf(not HAS_DILL, "no dill")


# NumPy is optional; tests that need it are gated with `skipIfNoNumpy`.
try:
    import numpy as np
    HAS_NUMPY = True
except ImportError:
    HAS_NUMPY = False
skipIfNoNumpy = unittest.skipIf(not HAS_NUMPY, "no NumPy")

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# Explicitly turn off expandable segments in the CUDA caching allocator for
# this test module.
if TEST_CUDA:
    torch.cuda.memory._set_allocator_settings('expandable_segments:False')

if not NO_MULTIPROCESSING_SPAWN:
    # We want to use `spawn` if able because some of our tests check that the
    # data loader terminates gracefully. To prevent hanging in the testing
    # process, such data loaders are run in a separate subprocess.
    #
    # We also want to test the `pin_memory=True` configuration, thus `spawn` is
    # required to launch such processes and they initialize the CUDA context.
    #
    # Mixing different start methods is a recipe for disaster (e.g., using a fork
    # `mp.Event` with a spawn `mp.Process` segfaults). So we set this globally
    # to avoid bugs.
    #
    # Get a multiprocessing context because some test / third party library will
    # set start_method when imported, and setting again triggers `RuntimeError`.
    # Note: this rebinds the module-level name `mp`, so every later `mp.Process`,
    # `mp.Event`, `mp.Pipe`, ... in this file comes from the spawn context.
    mp = mp.get_context(method='spawn')


# 60s of timeout?
# Yes, in environments where physical CPU resources are shared, e.g., CI, the
# time for an inter-process communication can be highly variable.  With 15~17s of
# timeout, we have observed flakiness in some CI builds (see
# pytorch/pytorch#14501, pytorch/pytorch#16608).  We follow the CPython
# multiprocessing setup and set the timeout to 60s here:
#
# https://github.com/python/cpython/blob/e8113f51a8bdf33188ee30a1c038a298329e7bfa/Lib/test/_test_multiprocessing.py#L73
JOIN_TIMEOUT = 60.0  # seconds


# `None` (DataLoader's default context) followed by every start method the
# platform supports; presumably used to parametrize tests over
# `multiprocessing_context` — confirm against the users further down the file.
supported_multiprocessing_contexts = [None] + list(torch.multiprocessing.get_all_start_methods())
|  |  | 
|  |  | 
|  | @unittest.skipIf( | 
|  | TEST_WITH_TSAN, | 
|  | "Fails with TSAN with the following error: starting new threads after multi-threaded " | 
|  | "fork is not supported. Dying (set die_after_fork=0 to override)") | 
|  | class TestDatasetRandomSplit(TestCase): | 
|  | def test_lengths_must_equal_dataset_size(self): | 
|  | with self.assertRaises(ValueError): | 
|  | random_split([1, 2, 3, 4], [1, 2]) | 
|  |  | 
|  | def test_splits_have_correct_size(self): | 
|  | splits = random_split([1, 2, 3, 4, 5, 6], [2, 4]) | 
|  | self.assertEqual(len(splits), 2) | 
|  | self.assertEqual(len(splits[0]), 2) | 
|  | self.assertEqual(len(splits[1]), 4) | 
|  |  | 
|  | splits = random_split([1, 2, 3, 4, 5, 6], [0.5, 0.5]) | 
|  | self.assertEqual(len(splits), 2) | 
|  | self.assertEqual(len(splits[0]), 3) | 
|  | self.assertEqual(len(splits[1]), 3) | 
|  |  | 
|  | # Odd size splits | 
|  | self.assertEqual( | 
|  | len(random_split(range(3), [0.5, 0.5], generator=torch.Generator().manual_seed(1))), | 
|  | 2 | 
|  | ) | 
|  |  | 
|  | # Odd sized round-robin splits | 
|  | splits = random_split(range(106), [0.1, 0.2, 0.3, 0.4], | 
|  | generator=torch.Generator().manual_seed(1)) | 
|  | self.assertEqual(len(splits[0]), 11) | 
|  | self.assertEqual(len(splits[1]), 22) | 
|  | self.assertEqual(len(splits[2]), 31) | 
|  | self.assertEqual(len(splits[3]), 42) | 
|  |  | 
|  |  | 
|  | def test_splits_are_mutually_exclusive(self): | 
|  | data = [5, 2, 3, 4, 1, 6] | 
|  | splits = random_split(data, [2, 4]) | 
|  | all_values = [] | 
|  | all_values.extend(list(splits[0])) | 
|  | all_values.extend(list(splits[1])) | 
|  | data.sort() | 
|  | all_values.sort() | 
|  | self.assertListEqual(data, all_values) | 
|  |  | 
|  | splits = random_split(data, [0.33, 0.67]) | 
|  | all_values = [] | 
|  | all_values.extend(list(splits[0])) | 
|  | all_values.extend(list(splits[1])) | 
|  | data.sort() | 
|  | all_values.sort() | 
|  | self.assertListEqual(data, all_values) | 
|  |  | 
|  | data = [1, 2, 3, 4] | 
|  | splits = random_split(data, [0.25, 0.75]) | 
|  | all_values = [] | 
|  | all_values.extend(list(splits[0])) | 
|  | all_values.extend(list(splits[1])) | 
|  | data.sort() | 
|  | all_values.sort() | 
|  | self.assertListEqual(data, all_values) | 
|  |  | 
|  | def test_splits_indexing_type(self): | 
|  | r"""Indices generated by random_split | 
|  | should be of integer type | 
|  | """ | 
|  | class CustomDataset: | 
|  | def __init__(self, test_object, custom_list): | 
|  | self.data = custom_list | 
|  | self.test_object = test_object | 
|  |  | 
|  | def __getitem__(self, key): | 
|  | self.test_object.assertEqual(type(key), int) | 
|  | return self.data[key] | 
|  |  | 
|  | def __len__(self): | 
|  | return len(self.data) | 
|  |  | 
|  | x = [1, 2, 3, 4, 5] | 
|  | dataset = CustomDataset(self, x) | 
|  | dataset = random_split(dataset, [5])[0] | 
|  | data_loader = DataLoader(dataset) | 
|  | for batch in data_loader: | 
|  | pass | 
|  |  | 
|  | # fractional splitting | 
|  | dataset = CustomDataset(self, x) | 
|  | dataset = random_split(dataset, [1.0])[0] | 
|  | data_loader = DataLoader(dataset) | 
|  | for batch in data_loader: | 
|  | pass | 
|  |  | 
|  | def test_splits_reproducibility(self): | 
|  | self.assertEqual( | 
|  | [list(x) for x in random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(1))], | 
|  | [[5, 6, 1], [2, 0, 8, 9, 3, 7, 4]], | 
|  | ) | 
|  | self.assertEqual( | 
|  | random_split(range(100), [60, 40], generator=torch.Generator().manual_seed(42)), | 
|  | random_split(range(100), [60, 40], generator=torch.Generator().manual_seed(42)), | 
|  | ) | 
|  | self.assertEqual( | 
|  | random_split(range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)), | 
|  | random_split(range(100), [0.5, 0.5], generator=torch.Generator().manual_seed(42)), | 
|  | ) | 
|  | self.assertEqual( | 
|  | random_split(range(100), [0.33, 0.33, 0.34], generator=torch.Generator().manual_seed(42)), | 
|  | random_split(range(100), [0.33, 0.33, 0.34], generator=torch.Generator().manual_seed(42)), | 
|  | ) | 
|  |  | 
|  | def test_incomplete_fractional_splits(self): | 
|  | with self.assertRaises(ValueError): | 
|  | # should raise since the sum of fractions is not 1 | 
|  | random_split([1, 2, 3, 4], [0.1]) | 
|  |  | 
|  | with self.assertRaises(ValueError): | 
|  | # should raise since fraction > 1 | 
|  | random_split([1, 2, 3, 4], [1.1]) | 
|  |  | 
|  | def test_splits_generator(self): | 
|  | # A random_split without a specific generator should affect the default one | 
|  | state = torch.get_rng_state() | 
|  | a = torch.rand(10) | 
|  | torch.set_rng_state(state) | 
|  | random_split(range(10), [5, 5]) | 
|  | b = torch.rand(10) | 
|  | self.assertNotEqual(a, b) | 
|  |  | 
|  | # A random_split with a specific generator should not affect the default one | 
|  | state = torch.get_rng_state() | 
|  | a = torch.rand(10) | 
|  | torch.set_rng_state(state) | 
|  | random_split(range(10), [5, 5], generator=torch.Generator().manual_seed(42)) | 
|  | b = torch.rand(10) | 
|  | self.assertEqual(a, b) | 
|  |  | 
|  | def test_slicing_of_subset_of_dataset(self): | 
|  | # Testing slicing a subset initialized with a dataset | 
|  | dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5])) | 
|  | subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4]) | 
|  | self.assertEqual(subset_of_dataset[:], dataset[:]) | 
|  | self.assertEqual(subset_of_dataset[1:2], dataset[1:2]) | 
|  | self.assertEqual(subset_of_dataset[0:-1:2], dataset[0:-1:2]) | 
|  | # Testing slicing of subset from random split | 
|  | subset1, subset2 = random_split(dataset, [3, 2]) | 
|  | self.assertEqual(subset1[:], dataset[subset1.indices[:]]) | 
|  | self.assertEqual(subset1[0:2], dataset[subset1.indices[0:2]]) | 
|  | self.assertEqual(subset1[0:-1:2], dataset[subset1.indices[0:-1:2]]) | 
|  |  | 
|  | def test_slicing_of_subset_of_subset(self): | 
|  | # Testing slicing a subset initialized with a subset | 
|  | dataset = TensorDataset(torch.tensor([1, 2, 3, 4, 5])) | 
|  | subset_of_dataset = Subset(dataset, [0, 1, 2, 3, 4]) | 
|  | subset_of_subset = Subset(subset_of_dataset, [0, 1, 2, 3, 4]) | 
|  | self.assertEqual(subset_of_subset[:], dataset[:]) | 
|  | self.assertEqual(subset_of_subset[0:2], dataset[0:2]) | 
|  | self.assertEqual(subset_of_subset[0:-1:2], dataset[0:-1:2]) | 
|  | # Testing slicing of subset of subset from random split | 
|  | subset1, subset2 = random_split(dataset, [4, 1]) | 
|  | subset_of_subset1, subset_of_subset2 = random_split(subset1, [3, 1]) | 
|  | idx = [subset1.indices[i] for i in subset_of_subset1.indices] | 
|  | self.assertEqual(subset_of_subset1[:], dataset[idx[:]]) | 
|  | self.assertEqual(subset_of_subset1[0:2], dataset[idx[0:2]]) | 
|  | self.assertEqual(subset_of_subset1[0:-1:2], dataset[idx[0:-1:2]]) | 
|  |  | 
|  |  | 
class CUDACountingDataset(Dataset):
    """Map-style dataset of length ``n`` whose i-th sample is ``i`` as a CUDA tensor."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # The sample is created directly on the GPU (requires CUDA at fetch time).
        sample = torch.as_tensor(i, device='cuda')
        return sample
|  |  | 
|  |  | 
class CountingDataset(Dataset):
    """Map-style dataset of length ``n`` where each sample is its own index."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Identity mapping: no bounds check, any index is returned unchanged.
        value = i
        return value
|  |  | 
|  |  | 
class CountingIterableDataset(IterableDataset):
    """Iterable-style dataset yielding ``0, 1, ..., n - 1`` and reporting length ``n``."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __len__(self):
        return self.n

    def __iter__(self):
        counts = range(self.n)
        return iter(counts)
|  |  | 
|  |  | 
|  | @unittest.skipIf( | 
|  | TEST_WITH_TSAN, | 
|  | "Fails with TSAN with the following error: starting new threads after multi-threaded " | 
|  | "fork is not supported. Dying (set die_after_fork=0 to override)") | 
|  | class TestTensorDataset(TestCase): | 
|  |  | 
|  | def test_len(self): | 
|  | source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15)) | 
|  | self.assertEqual(len(source), 15) | 
|  |  | 
|  | def test_getitem(self): | 
|  | t = torch.randn(15, 10, 2, 3, 4, 5) | 
|  | l = torch.randn(15, 10) | 
|  | source = TensorDataset(t, l) | 
|  | for i in range(15): | 
|  | self.assertEqual(t[i], source[i][0]) | 
|  | self.assertEqual(l[i], source[i][1]) | 
|  |  | 
|  | def test_getitem_1d(self): | 
|  | t = torch.randn(15) | 
|  | l = torch.randn(15) | 
|  | source = TensorDataset(t, l) | 
|  | for i in range(15): | 
|  | self.assertEqual(t[i], source[i][0]) | 
|  | self.assertEqual(l[i], source[i][1]) | 
|  |  | 
|  | def test_single_tensor(self): | 
|  | t = torch.randn(5, 10) | 
|  | source = TensorDataset(t) | 
|  | self.assertEqual(len(source), 5) | 
|  | for i in range(5): | 
|  | self.assertEqual(t[i], source[i][0]) | 
|  |  | 
|  | def test_many_tensors(self): | 
|  | t0 = torch.randn(5, 10, 2, 3, 4, 5) | 
|  | t1 = torch.randn(5, 10) | 
|  | t2 = torch.randn(5, 10, 2, 5) | 
|  | t3 = torch.randn(5, 10, 3, 7) | 
|  | source = TensorDataset(t0, t1, t2, t3) | 
|  | self.assertEqual(len(source), 5) | 
|  | for i in range(5): | 
|  | self.assertEqual(t0[i], source[i][0]) | 
|  | self.assertEqual(t1[i], source[i][1]) | 
|  | self.assertEqual(t2[i], source[i][2]) | 
|  | self.assertEqual(t3[i], source[i][3]) | 
|  |  | 
|  |  | 
|  | @unittest.skipIf( | 
|  | TEST_WITH_TSAN, | 
|  | "Fails with TSAN with the following error: starting new threads after multi-threaded " | 
|  | "fork is not supported. Dying (set die_after_fork=0 to override)") | 
|  | class TestStackDataset(TestCase): | 
|  |  | 
|  | def test_empty(self): | 
|  | with self.assertRaisesRegex(ValueError, "At least one dataset should be passed"): | 
|  | StackDataset() | 
|  |  | 
|  | def test_mixed(self): | 
|  | with self.assertRaisesRegex(ValueError, "Supported either"): | 
|  | StackDataset(TensorDataset(torch.randn(15, 10)), a=TensorDataset(torch.randn(10, 15))) | 
|  |  | 
|  | def test_size_mismatch(self): | 
|  | with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"): | 
|  | StackDataset(TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(10, 15))) | 
|  | with self.assertRaisesRegex(ValueError, "Size mismatch between datasets"): | 
|  | StackDataset(a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(10, 15))) | 
|  |  | 
|  | def test_len(self): | 
|  | source = StackDataset(TensorDataset(torch.randn(15, 10)), TensorDataset(torch.randn(15))) | 
|  | self.assertEqual(len(source), 15) | 
|  | source = StackDataset(TensorDataset(torch.randn(15, 10))) | 
|  | self.assertEqual(len(source), 15) | 
|  | source = StackDataset(a=TensorDataset(torch.randn(15, 10)), b=TensorDataset(torch.randn(15))) | 
|  | self.assertEqual(len(source), 15) | 
|  | source = StackDataset(a=TensorDataset(torch.randn(15, 10))) | 
|  | self.assertEqual(len(source), 15) | 
|  |  | 
|  | def test_single(self): | 
|  | t = TensorDataset(torch.randn(15, 10)) | 
|  | source = StackDataset(t) | 
|  | for i in range(15): | 
|  | self.assertEqual(t[i], source[i][0]) | 
|  | source = StackDataset(a=t) | 
|  | for i in range(15): | 
|  | self.assertEqual(t[i], source[i]['a']) | 
|  |  | 
|  | def test_getitem(self): | 
|  | t = TensorDataset(torch.randn(15, 10)) | 
|  | l = TensorDataset(torch.randn(15, 5, 4)) | 
|  | source = StackDataset(t, l) | 
|  | for i in range(15): | 
|  | self.assertEqual(t[i], source[i][0]) | 
|  | self.assertEqual(l[i], source[i][1]) | 
|  | source = StackDataset(a=t, b=l) | 
|  | for i in range(15): | 
|  | self.assertEqual(t[i], source[i]['a']) | 
|  | self.assertEqual(l[i], source[i]['b']) | 
|  |  | 
|  |  | 
|  | @unittest.skipIf( | 
|  | TEST_WITH_TSAN, | 
|  | "Fails with TSAN with the following error: starting new threads after multi-threaded " | 
|  | "fork is not supported. Dying (set die_after_fork=0 to override)") | 
|  | class TestConcatDataset(TestCase): | 
|  |  | 
|  | def test_concat_two_singletons(self): | 
|  | result = ConcatDataset([[0], [1]]) | 
|  | self.assertEqual(2, len(result)) | 
|  | self.assertEqual(0, result[0]) | 
|  | self.assertEqual(1, result[1]) | 
|  |  | 
|  | def test_concat_two_non_singletons(self): | 
|  | result = ConcatDataset([[0, 1, 2, 3, 4], | 
|  | [5, 6, 7, 8, 9]]) | 
|  | self.assertEqual(10, len(result)) | 
|  | self.assertEqual(0, result[0]) | 
|  | self.assertEqual(5, result[5]) | 
|  |  | 
|  | def test_concat_two_non_singletons_with_empty(self): | 
|  | # Adding an empty dataset somewhere is correctly handled | 
|  | result = ConcatDataset([[0, 1, 2, 3, 4], | 
|  | [], | 
|  | [5, 6, 7, 8, 9]]) | 
|  | self.assertEqual(10, len(result)) | 
|  | self.assertEqual(0, result[0]) | 
|  | self.assertEqual(5, result[5]) | 
|  |  | 
|  | def test_concat_raises_index_error(self): | 
|  | result = ConcatDataset([[0, 1, 2, 3, 4], | 
|  | [5, 6, 7, 8, 9]]) | 
|  | with self.assertRaises(IndexError): | 
|  | # this one goes to 11 | 
|  | result[11] | 
|  |  | 
|  | def test_add_dataset(self): | 
|  | d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) | 
|  | d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) | 
|  | d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) | 
|  | result = d1 + d2 + d3 | 
|  | self.assertEqual(21, len(result)) | 
|  | self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum()) | 
|  | self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum()) | 
|  | self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum()) | 
|  |  | 
|  | def test_iterable_dataset_err(self): | 
|  | d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7)) | 
|  | it1 = CountingIterableDataset(5) | 
|  | it2 = CountingIterableDataset(10) | 
|  |  | 
|  | with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"): | 
|  | ConcatDataset([d1, it2, it1]) | 
|  |  | 
|  | with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"): | 
|  | ConcatDataset([it2]) | 
|  |  | 
|  | with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"): | 
|  | ConcatDataset([it1, d1]) | 
|  |  | 
|  |  | 
|  | # takes in dummy var so this can also be used as a `worker_init_fn` | 
|  | def set_faulthander_if_available(_=None): | 
|  | faulthandler.enable(sys.__stderr__) | 
|  | if not IS_WINDOWS: | 
|  | # windows does not have faulthandler.register | 
|  | # chain=False prevents the default behavior of killing the process | 
|  | faulthandler.register(signal.SIGUSR1, file=sys.__stderr__, chain=False) | 
|  |  | 
|  |  | 
# Install the fault handler in this (main) process at import time, so fatal
# signals (and SIGUSR1 outside Windows) dump Python tracebacks to stderr.
set_faulthander_if_available()
|  |  | 
|  | # Process `pid` must have called `set_faulthander_if_available` | 
def print_traces_of_all_threads(pid):
    """Ask process ``pid`` to dump the tracebacks of all its threads to stderr.

    The target must have called ``set_faulthander_if_available``. Outside
    Windows this sends ``SIGUSR1`` (the process keeps running). On Windows
    there is no custom signal, so ``SIGSEGV`` is sent instead; the handler
    installed by ``faulthandler.enable()`` still prints the traces, but the
    target process is killed.

    Blocks for 5 seconds afterwards to give the target time to write.
    """
    sig = signal.SIGSEGV if IS_WINDOWS else signal.SIGUSR1
    os.kill(pid, sig)
    time.sleep(5)
|  |  | 
|  |  | 
|  | # The following `ErrorTrackingProcess` stores the first encountered exception in | 
|  | # its `.exception` attribute. | 
|  | # Inspired by https://stackoverflow.com/a/33599967 | 
class ErrorTrackingProcess(mp.Process):
    """``mp.Process`` that reports the outcome of its target back to the parent.

    In the child, ``run`` sends ``None`` over a pipe when the target returns
    normally, or an ``ExceptionWrapper`` of the raised exception (which is then
    re-raised in the child). The parent reads the result via ``exception``.
    """

    # Only keyword arguments are forwarded to ``mp.Process`` (no ``*args``).
    # Historically this was because Python 2 did not accept
    # ``def fn(x, *args, key=val, **kwargs)``.
    # By default (disable_stderr=True) the child's stderr is redirected to
    # os.devnull. Passing disable_stderr=False may produce a lot of unrelated
    # error output, but can be helpful for debugging.
    def __init__(self, disable_stderr=True, **kwargs):
        super().__init__(**kwargs)
        # _pconn is read by the parent, _cconn is written by the child.
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None
        self.disable_stderr = disable_stderr

    def run(self):
        set_faulthander_if_available()
        if self.disable_stderr:
            # Disable polluting stderr with errors that are supposed to happen.
            # dup2 copies the devnull fd onto stderr's fd, so closing `devnull`
            # at the end of the with-block keeps the redirection in place.
            with open(os.devnull, 'w') as devnull:
                os.dup2(devnull.fileno(), sys.stderr.fileno())
        try:
            super().run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    def print_traces_of_all_threads(self):
        assert self.is_alive(), "can only use print_traces_of_all_threads if the process is alive"
        assert not self.disable_stderr, "do not disable stderr if you use print_traces_of_all_threads"
        # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets
        # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill
        # the process. So let's poll the exception first
        _ = self.exception
        # Resolves to the module-level helper, not this method.
        print_traces_of_all_threads(self.pid)

    @property
    def exception(self):
        """Exception raised by the child's target, or ``None``.

        Non-blocking: if the child has sent a result it is received and cached.
        Returns ``None`` if nothing was received yet or the target succeeded;
        otherwise a new exception of the original type built from the wrapped
        message (the original traceback is not preserved).
        """
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill could not find a live process with this pid.
    # With ignore_ESRCH=True that case is ignored; any other OSError is re-raised.
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise
|  |  | 
|  |  | 
class ErrorDataset(Dataset):
    """Map-style dataset that reports ``size`` samples but defines no ``__getitem__``.

    Indexing falls through to the base ``Dataset.__getitem__``; presumably used
    to provoke an error in workers — confirm against the tests that use it.
    """

    def __init__(self, size):
        self.size = size

    def __len__(self):
        length = self.size
        return length
|  |  | 
|  |  | 
class SegfaultDataset(Dataset):
    """Map-style dataset of ``size`` samples whose every fetch crashes the process.

    ``__getitem__`` reads a C string at address 0 via ``ctypes``, which
    dereferences a null pointer.
    """

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        null_address = 0
        return ctypes.string_at(null_address)
|  |  | 
|  |  | 
class SleepDataset(Dataset):
    """Map-style dataset whose first fetch (per dataset object) sleeps ``sleep_sec``.

    Samples are their own index. Only the first ``__getitem__`` on a given
    object sleeps; later fetches return immediately.
    """

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec
        self.sleeped = False  # set to True once the one-time sleep has happened

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if self.sleeped:
            return idx
        time.sleep(self.sleep_sec)
        self.sleeped = True
        return idx
|  |  | 
|  |  | 
class SeedDataset(Dataset):
    """Map-style dataset whose every sample is ``torch.initial_seed()`` of the fetching process."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        seed = torch.initial_seed()
        return seed
|  |  | 
|  |  | 
class WorkerSpecificIterableDataset(IterableDataset):
    """Iterable dataset where worker ``k`` yields ``range(sizes_for_all_workers[k])``.

    Must be iterated inside a DataLoader worker (asserts that worker info
    exists). ``len`` is the total over all workers.
    """

    def __init__(self, sizes_for_all_workers):
        self.sizes_for_all_workers = sizes_for_all_workers

    def __len__(self):
        return sum(self.sizes_for_all_workers)

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        assert info is not None
        count = self.sizes_for_all_workers[info.id]
        return iter(range(count))
|  |  | 
|  |  | 
|  | # Inspired by https://stackoverflow.com/a/26703365 | 
# If all workers call `sync_once`, they will be blocked until all workers
# reach the call (i.e., acting like a barrier).
# This can be used to ensure that each worker processes at least one sample.
class SynchronizedDataset(Dataset):
    """Base map-style dataset providing a one-shot cross-process barrier.

    Subclasses implement ``__getitem__`` and call ``sync_once`` from it. The
    shared counter and semaphore are created from the module-level ``mp``.
    """

    def __init__(self, size, batch_size, num_workers):
        # Enough samples for every worker to get at least one full batch.
        assert size >= num_workers * batch_size
        self.count = mp.Value('i', 0, lock=True)
        self.barrier = mp.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def sync_once(self):
        # Each call increments the shared counter under its lock; the call that
        # brings it to exactly num_workers releases the semaphore once.
        with self.count.get_lock():
            self.count.value += 1
            if self.count.value == self.num_workers:
                self.barrier.release()
        # Turnstile: every caller waits for that release, then passes it on so
        # the next waiter (and any later caller) can proceed too.
        self.barrier.acquire()
        self.barrier.release()

    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        return self.size
|  |  | 
|  |  | 
class EmptyTensorDataset(torch.utils.data.Dataset):
    """Map-style dataset of ``len`` samples, each an empty 1-D float tensor.

    Parameter names ``len``/``any`` shadow builtins but are kept for
    compatibility; neither builtin is used in the methods that shadow them.
    """

    def __init__(self, len):
        self.len = len

    def __getitem__(self, any):
        # The index is ignored; every sample is a fresh 0-element tensor.
        return torch.empty(0)

    def __len__(self):
        return self.len
|  |  | 
|  |  | 
class SynchronizedSeedDataset(SynchronizedDataset):
    """Synchronized dataset whose samples are the fetching process's initial torch seed."""

    def __getitem__(self, idx):
        # Wait at the shared barrier before reporting the seed.
        self.sync_once()
        seed = torch.initial_seed()
        return seed
|  |  | 
|  |  | 
def _run_timeout_loader(persistent_workers, pin_memory):
    """Fetch one batch from a 2-worker loader whose first fetch outlasts its timeout.

    ``SleepDataset(10, 3)`` sleeps 3s on the first fetch of each dataset copy,
    while the loader's ``timeout`` is 1s, so ``next(...)`` should time out.
    Shared by ``_test_timeout`` and ``_test_timeout_pin_memory``, which only
    differ in ``pin_memory``.
    """
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1,
                            pin_memory=pin_memory,
                            persistent_workers=persistent_workers)
    _ = next(iter(dataloader))


def _test_timeout(persistent_workers):
    """Timeout scenario without pinned memory (DataLoader default)."""
    _run_timeout_loader(persistent_workers, pin_memory=False)


def _test_timeout_pin_memory(persistent_workers):
    """Timeout scenario with ``pin_memory=True``."""
    _run_timeout_loader(persistent_workers, pin_memory=True)
|  |  | 
|  |  | 
def _test_large_sampler_indices(persistent_workers):
    """Start a one-worker loader over 10M empty samples, then fail on the first batch.

    The first batch is checked to contain no elements, after which
    ``RuntimeError('My Error')`` is raised deliberately.

    See test_large_sampler_indices and
    https://github.com/pytorch/pytorch/issues/48666
    """
    loader = torch.utils.data.DataLoader(
        EmptyTensorDataset(10000000),
        batch_size=40960,
        persistent_workers=persistent_workers,
        num_workers=1,
    )

    batches = iter(loader)
    for batch in batches:
        assert batch.numel() == 0
        raise RuntimeError('My Error')
|  |  | 
|  |  | 
def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since worker signal handler prints with low-level write(),
    this has to be done on OS level via dup.

    This is used as worker_init_fn for test_segfault.
    """
    # dup2 only swaps the OS-level descriptor, so flush Python's own buffer first.
    sys.stderr.flush()
    # dup2 makes stderr's fd an independent copy of the devnull descriptor, so
    # closing our handle when the with-block exits keeps the redirection.
    with open(os.devnull, 'w') as sink:
        target_fd = sys.stderr.fileno()
        os.dup2(sink.fileno(), target_fd)
|  |  | 
|  |  | 
def _test_segfault():
    """Fetch one batch from a 2-worker loader whose dataset segfaults on every fetch."""
    loader = DataLoader(SegfaultDataset(10), batch_size=2, num_workers=2,
                        worker_init_fn=disable_stderr)
    next(iter(loader))
|  |  | 
|  |  | 
def _test_no_segfault():
    """Fetch one sample through a fork-context worker after ensuring >= 4 torch threads.

    Presumably checks that forking a worker from a multi-threaded parent does
    not crash — TODO confirm against the test that runs this.
    """
    dataset = [1, 2, 3]
    torch.set_num_threads(max(4, torch.get_num_threads()))
    fork_ctx = torch.multiprocessing.get_context(method='fork')
    loader = DataLoader(dataset, num_workers=1, worker_init_fn=disable_stderr,
                        multiprocessing_context=fork_ctx)
    next(iter(loader))
|  |  | 
|  |  | 
class TestProperExitDataset(Dataset):
    """Map-style dataset returning ``tensor([idx])``; can make the last worker fail.

    Once ``error_event`` (if given) is set, fetches in the worker with the
    highest id raise ``RuntimeError('Worker error')``.
    """

    def __init__(self, size, error_event):
        self.size = size
        self.error_event = error_event

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()
        should_fail = (
            self.error_event is not None
            and self.error_event.is_set()
            and info.id == info.num_workers - 1
        )
        if should_fail:
            raise RuntimeError('Worker error')
        return torch.tensor([idx])
|  |  | 
|  |  | 
class TestProperExitIterableDataset(IterableDataset):
    """Iterable dataset yielding ``size`` copies of ``tensor(-1000)``; can make the last worker fail.

    The object is its own iterator, so the ``remaining`` count is consumed
    across iterations of the same object. Once ``error_event`` (if given) is
    set, ``__next__`` in the worker with the highest id raises
    ``RuntimeError('Worker error')``.
    """

    def __init__(self, size, error_event):
        self.error_event = error_event
        self.size = size
        self.remaining = size

    def __len__(self):
        return self.size

    def __iter__(self):
        return self

    def __next__(self):
        info = torch.utils.data.get_worker_info()
        if (self.error_event is not None and self.error_event.is_set()
                and info.id == info.num_workers - 1):
            raise RuntimeError('Worker error')
        self.remaining -= 1
        if self.remaining >= 0:
            return torch.tensor(-1000)
        raise StopIteration
|  |  | 
|  |  | 
|  | # See TestDataLoader.test_proper_exit for usage | 
def _test_proper_exit(is_iterable_dataset, use_workers, pin_memory, exit_method,
                      hold_iter_reference, loader_setup_event, tester_setup_event,
                      persistent_workers):
    """Iterate a DataLoader and end the iteration in the way chosen by ``exit_method``.

    Presumably run in a separate subprocess by TestDataLoader.test_proper_exit,
    which checks that the loader and its workers shut down cleanly.

    Args:
        is_iterable_dataset: use TestProperExitIterableDataset (size 7) instead of
            TestProperExitDataset (size 12).
        use_workers: load with 2 worker processes if True, in-process otherwise.
        pin_memory, persistent_workers: forwarded to DataLoader.
        exit_method: 'loader_error' raises in this process, 'loader_kill' kills
            this process, 'worker_error' makes the last worker raise, and
            'worker_kill' kills the last worker. Any other value lets iteration
            run to completion.
        hold_iter_reference: if False, the names ``it`` and ``loader`` are
            deleted after the first batch so clean-up goes through ``__del__``/gc.
        loader_setup_event: set once the first batch has been received.
        tester_setup_event: waited on before continuing past the first batch.
    """
    num_workers = 2 if use_workers else 0

    if exit_method == 'worker_error' or exit_method == 'worker_kill':
        assert use_workers is True

    # When set, the dataset makes the last worker raise on its next fetch.
    if exit_method == 'worker_error':
        worker_error_event = mp.Event()
    else:
        worker_error_event = None

    if is_iterable_dataset:
        ds = TestProperExitIterableDataset(7, worker_error_event)
    else:
        ds = TestProperExitDataset(12, worker_error_event)

    loader = DataLoader(ds, batch_size=1, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory,
                        worker_init_fn=set_faulthander_if_available,
                        persistent_workers=persistent_workers)

    # Iteration index at which the selected exit method is triggered.
    error_it = 2

    # Ensure there is still data left to load when the exit is triggered.
    if use_workers:
        # 2 is the magical per-worker prefetch number...
        # FIXME: change this after the number becomes configurable.
        if is_iterable_dataset:
            assert len(ds) * num_workers > (error_it + 2 + 1)
        else:
            assert len(loader) > (error_it + 2 + 1) * num_workers
    else:
        if is_iterable_dataset:
            assert len(ds) > error_it + 1
        else:
            assert len(loader) > error_it + 1

    it = iter(loader)
    if use_workers:
        workers = it._workers

    def kill_pid(pid):
        # Hard-kill `pid` and wait (up to JOIN_TIMEOUT) until it is gone.
        psutil_p = psutil.Process(pid)
        psutil_p.kill()
        psutil_p.wait(JOIN_TIMEOUT)
        assert not psutil_p.is_running()

    for i, _ in enumerate(it):
        if i == 0:
            if not hold_iter_reference:
                # Only the names are dropped here; the running `enumerate`
                # still references the iterator.
                del it
                del loader
            # Handshake with the tester process.
            loader_setup_event.set()
            tester_setup_event.wait()
            # ensure that the workers are still alive
            if use_workers:
                for w in workers:
                    assert w.is_alive()
            if worker_error_event is not None:
                worker_error_event.set()

        if i == error_it:
            if exit_method == 'loader_error':
                raise RuntimeError('Loader error')
            elif exit_method == 'loader_kill':
                kill_pid(os.getpid())
            elif exit_method == 'worker_kill':
                kill_pid(workers[-1].pid)  # kill last worker

    if not hold_iter_reference:
        # Tries to trigger the __del__ clean-up rather than the automatic
        # exiting of daemonic children. Technically it should be automatically
        # triggered, but I don't want to rely on the implementation detail of
        # Python gc.
        gc.collect()
|  |  | 
|  |  | 
class TestWorkerInfoDataset(SynchronizedDataset):
    """Synchronized dataset returning the ``value`` attribute set on the worker's copy.

    ``value`` is assigned by ``_test_worker_info_init_fn`` inside each worker;
    the main-process object never has it, so indexing there raises
    ``AttributeError`` (after the barrier call).
    """

    def __getitem__(self, idx):
        self.sync_once()
        payload = self.value
        return torch.tensor(payload)
|  |  | 
|  |  | 
|  | # Should be used as worker_init_fn with TestWorkerInfoDataset. | 
|  | # See _test_get_worker_info below for usage. | 
def _test_worker_info_init_fn(worker_id):
    """``worker_init_fn`` that validates ``get_worker_info()`` inside a worker.

    Checks that the worker info agrees with ``worker_id`` and the worker's
    seed, that it carries a fresh ``TestWorkerInfoDataset`` copy, that
    ``WorkerInfo`` rejects attribute assignment, and that its repr lists its
    fields. Finally stores ``[worker_id, pid]`` on the worker's dataset copy.

    Raises:
        AssertionError: if any of the checks fails.
    """
    worker_info = torch.utils.data.get_worker_info()
    assert worker_id == worker_info.id, "worker_init_fn and worker_info should have consistent id"
    assert worker_id < worker_info.num_workers, "worker_init_fn and worker_info should have valid id"
    assert worker_info.seed == torch.initial_seed(), "worker_init_fn and worker_info should have consistent seed"
    dataset = worker_info.dataset
    assert isinstance(dataset, TestWorkerInfoDataset), "worker_info should have correct dataset copy"
    assert not hasattr(dataset, 'value'), "worker_info should have correct dataset copy"
    # WorkerInfo must be read-only: both overwriting an existing attribute and
    # adding a new one have to raise. The `else` branch makes the check fail if
    # an assignment unexpectedly succeeds (without it the check could never fail).
    for attr, new_value in (('id', 3999), ('a', 3)):
        try:
            setattr(worker_info, attr, new_value)
        except RuntimeError as e:
            assert str(e) == "Cannot assign attributes to WorkerInfo objects"
        else:
            raise AssertionError(
                f"WorkerInfo should be read-only, but assigning {attr!r} succeeded")
    for k in ['id', 'num_workers', 'seed', 'dataset']:
        assert f"{k}=" in repr(worker_info)
    dataset.value = [worker_id, os.getpid()]
|  |  | 
|  |  | 
def _test_get_worker_info():
    """Load TestWorkerInfoDataset with 2 workers and validate worker info.

    Meant to run as the target of a child process. Every item must carry the
    ``[worker_id, pid]`` pair written by ``_test_worker_info_init_fn``, and the
    pid must match the worker with that id. The main-process dataset must stay
    untouched.

    Raises:
        AssertionError: on a mismatch.
        RuntimeError: if reading the main-process dataset unexpectedly
            succeeds.
    """
    # There is no WorkerInfo outside of a worker process.
    assert torch.utils.data.get_worker_info() is None
    num_workers = 2
    batch_size = 2
    dataset = TestWorkerInfoDataset(6, batch_size, num_workers)
    loader = DataLoader(dataset, batch_size=batch_size,
                        num_workers=num_workers,
                        worker_init_fn=_test_worker_info_init_fn)
    loader_iter = iter(loader)
    batches = list(loader_iter)
    pid_by_worker_id = [worker.pid for worker in loader_iter._workers]
    for worker_id, worker_pid in torch.cat(batches, 0):
        assert worker_pid == pid_by_worker_id[worker_id]
    # Still no WorkerInfo in the main process after loading.
    assert torch.utils.data.get_worker_info() is None
    # Only the workers' dataset copies received ``value``.
    assert not hasattr(dataset, 'value')
    try:
        dataset[0]
    except AttributeError:
        pass
    else:
        raise RuntimeError('Expected AttributeError')
|  |  | 
|  |  | 
# worker_init_fn used by test_worker_init_fn: reseed every worker identically.
def init_fn(worker_id):
    """Seed torch's RNG with the fixed value 12345, ignoring ``worker_id``."""
    fixed_seed = 12345
    torch.manual_seed(fixed_seed)
|  |  | 
|  |  | 
# used with test_error_in_init
class ErrorIterableDataset(IterableDataset):
    """IterableDataset whose ``__iter__`` fails immediately with RuntimeError."""

    def __iter__(self):
        message = "Error in __iter__"
        raise RuntimeError(message)
|  |  | 
|  |  | 
# used with test_error_in_init
def error_worker_init_fn(_):
    """worker_init_fn that always fails, whatever worker id it receives."""
    message = "Error in worker_init_fn"
    raise RuntimeError(message)
|  |  | 
|  |  | 
class BulkLoadingDataset(Dataset):
    """Map-style dataset that fetches a whole batch of indices per call.

    ``__getitem__`` must receive a list or tuple of indices and returns them as
    a tensor, so ``ds[[3, 1]]`` is ``tensor([3, 1])``.
    """

    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, indices):
        # A sampler such as BulkLoadingSampler hands over index lists, not ints.
        assert isinstance(indices, (list, tuple))
        return torch.as_tensor(indices)
|  |  | 
|  |  | 
class BulkLoadingSampler(torch.utils.data.Sampler):
    """Yield a random permutation of ``dataset``'s indices in list chunks.

    Each yielded item is a Python list of at most ``batch_size`` indices; only
    the final chunk may be shorter. Meant to pair with BulkLoadingDataset,
    which fetches one whole chunk per ``__getitem__`` call.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __len__(self):
        # Number of chunks, counting a trailing partial chunk.
        return math.ceil(len(self.dataset) / self.batch_size)

    def __iter__(self):
        permutation = torch.randperm(len(self.dataset))
        yield from (chunk.tolist() for chunk in permutation.split(self.batch_size))
|  |  | 
|  |  | 
class TestMultiEpochDataset(IterableDataset):
    """IterableDataset in which every worker repeatedly yields its own id.

    Must be iterated inside a DataLoader worker: each worker yields
    ``length // num_workers`` copies of ``worker_info.id``. Iteration is lazy,
    so the in-worker assertion fires on the first ``next``, not on ``iter``.
    """

    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        assert info is not None
        yield from itertools.repeat(info.id, self.length // info.num_workers)
|  |  | 
|  |  | 
class CustomList(list):
    """Plain ``list`` subclass, used to check that collation keeps custom sequence types."""
|  |  | 
|  |  | 
class CustomDict(dict):
    """Plain ``dict`` subclass, used to check that collation keeps custom mapping types."""
|  |  | 
|  |  | 
def row_processor(row):
    """Return ``row`` with 1 added element-wise, as a NumPy array (datapipe map fn)."""
    incremented = np.add(row, 1)
    return incremented
|  |  | 
|  |  | 
def filter_len(row):
    """Datapipe filter predicate: keep only rows with exactly four elements."""
    wanted_length = 4
    return len(row) == wanted_length
|  |  | 
|  |  | 
|  | @unittest.skipIf( | 
|  | TEST_WITH_TSAN, | 
|  | "Fails with TSAN with the following error: starting new threads after multi-threaded " | 
|  | "fork is not supported. Dying (set die_after_fork=0 to override)") | 
|  | @unittest.skipIf( | 
|  | TEST_WITH_ASAN, | 
|  | "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223") | 
|  | class TestDataLoader(TestCase): | 
|  |  | 
def setUp(self):
    super().setUp()
    # 100 random samples; labels are a permutation of 0..49 repeated twice.
    data = torch.randn(100, 2, 3, 5)
    labels = torch.randperm(50).repeat(2)
    self.data, self.labels = data, labels
    self.dataset = TensorDataset(data, labels)
    # Subclasses may flip this to run every test with persistent workers.
    self.persistent_workers = False
|  |  | 
def _get_data_loader(self, dataset, **kwargs):
    """Build a DataLoader, defaulting persistent_workers to the test-case setting.

    ``persistent_workers`` comes from kwargs, falling back to
    ``self.persistent_workers``. It is forced to False when no workers are
    requested, since persistent workers need ``num_workers > 0``.
    """
    persistent = kwargs.get('persistent_workers', self.persistent_workers)
    no_workers = kwargs.get('num_workers', 0) == 0
    kwargs['persistent_workers'] = False if (persistent and no_workers) else persistent
    return DataLoader(dataset, **kwargs)
|  |  | 
def _test_sequential(self, loader):
    """Assert that ``loader`` yields self.data / self.labels in dataset order."""
    batch_size = loader.batch_size
    if batch_size is None:
        # Unbatched: each step is a single (sample, target) pair.
        for idx, (sample, target) in enumerate(loader):
            self.assertEqual(sample, self.data[idx])
            self.assertEqual(target, self.labels[idx])
        self.assertEqual(idx, len(self.dataset) - 1)
        return

    for step, (sample, target) in enumerate(loader):
        start, stop = step * batch_size, (step + 1) * batch_size
        self.assertEqual(sample, self.data[start:stop])
        self.assertEqual(target, self.labels[start:stop])
    self.assertEqual(step, math.floor((len(self.dataset) - 1) / batch_size))
|  |  | 
def _test_shuffle(self, loader):
    """Assert that ``loader`` yields every sample of self.dataset exactly once.

    Order is not checked. Each yielded sample is located in ``self.data``, and
    its target must equal the label at that index. Works for both unbatched
    (``batch_size=None``) and batched loaders. Fails if a sample is yielded
    twice or is not found in the dataset at all.
    """
    found_data = {i: 0 for i in range(self.data.size(0))}
    found_labels = {i: 0 for i in range(self.labels.size(0))}

    def record(sample, target):
        # Linear scan for the matching data point. The for/else fails loudly
        # on an unknown sample instead of silently reusing the last index.
        for data_point_idx, data_point in enumerate(self.data):
            if data_point.eq(sample).all():
                self.assertFalse(found_data[data_point_idx])
                found_data[data_point_idx] += 1
                break
        else:
            self.fail("loader yielded a sample that is not in the dataset")
        self.assertEqual(target, self.labels[data_point_idx])
        found_labels[data_point_idx] += 1

    batch_size = loader.batch_size
    if batch_size is None:
        for i, (sample, target) in enumerate(loader):
            record(sample, target)
            self.assertEqual(sum(found_data.values()), (i + 1))
            self.assertEqual(sum(found_labels.values()), (i + 1))
        self.assertEqual(i, (len(self.dataset) - 1))
    else:
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                record(sample, target)
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
|  |  | 
def _test_error(self, loader):
    """Drain ``loader``, expecting every batch to raise NotImplementedError.

    The number of errors must equal the number of batches,
    ``ceil(len(dataset) / batch_size)``.
    """
    loader_iter = iter(loader)
    num_errors = 0
    while True:
        try:
            next(loader_iter)
        except NotImplementedError:
            num_errors += 1
        except StopIteration:
            break
    self.assertEqual(num_errors,
                     math.ceil(float(len(loader.dataset)) / loader.batch_size))
|  |  | 
def test_error_in_init(self):
    # Errors from a dataset's __iter__ (in-process or in a worker) must reach
    # the caller.
    for workers in (0, 2):
        failing_loader = self._get_data_loader(ErrorIterableDataset(), num_workers=workers)
        with self.assertRaisesRegex(RuntimeError, 'Error in __iter__'):
            list(iter(failing_loader))

    # Errors from worker_init_fn must reach the caller as well.
    failing_loader = self._get_data_loader(
        self.dataset, num_workers=2, worker_init_fn=error_worker_init_fn)
    with self.assertRaisesRegex(RuntimeError, 'Error in worker_init_fn'):
        list(iter(failing_loader))
|  |  | 
def test_typing(self):
    # Subscripting Dataset / DataLoader with a generic element type must not
    # raise TypeError.
    from typing import List

    class SomeDatasetClass(Dataset[List[torch.Tensor]]):
        """Only needs to be definable."""

    def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]:
        """Only its annotation matters; the body is never run."""
|  |  | 
@unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
@unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
def test_fd_limit_exceeded(self):
    """Hitting the open-files limit must produce an actionable error message.

    A child interpreter lowers RLIMIT_NOFILE to 100, then keeps every tensor
    received from a forked single-worker DataLoader (200 batches) alive. If
    that raises a RuntimeError, its message must mention both ``ulimit -n`` and
    ``set_sharing_strategy``. If no error is raised, the child still exits 0.
    See NOTE [ DataLoader on Linux and open files limit ].
    """
    import subprocess
    subprocess.check_output([sys.executable, '-c', """\
import torch
import resource
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

try:
    keep_fds_alive = []
    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
                               num_workers=1):
        random_t.max(dim=0)
        keep_fds_alive.append(random_t)
except RuntimeError as e:
    assert "ulimit -n" in str(e)
    assert "set_sharing_strategy" in str(e)
"""])
|  |  | 
def test_invalid_assign_after_init(self):
    # These attributes cannot be reassigned once the DataLoader exists.
    loader = self._get_data_loader(self.dataset)
    frozen_attrs = ('batch_size', 'sampler', 'batch_sampler', 'drop_last', 'dataset')
    for attr in frozen_attrs:
        with self.assertRaises(ValueError):
            setattr(loader, attr, {})
|  |  | 
def test_sequential_nonbatch(self):
    unbatched_loader = self._get_data_loader(self.dataset, batch_size=None)
    self._test_sequential(unbatched_loader)
|  |  | 
def test_sequential_batch(self):
    # Default batch size first, then an explicit batch size of 2.
    for extra_kwargs in ({}, {'batch_size': 2}):
        self._test_sequential(self._get_data_loader(self.dataset, **extra_kwargs))
|  |  | 
def test_bulk_loading_nobatch(self):
    """With batch_size=None, a sampler yielding index lists drives bulk fetches.

    Auto-collation must be off, and each loaded item is the dataset's result
    for one index list, so there is one item per sampler chunk. Pinning must
    follow ``pin_memory``, and the items together must contain every index
    exactly once.
    """
    n = 35
    bs = 4
    ds = BulkLoadingDataset(n)
    sampler = BulkLoadingSampler(ds, batch_size=bs)

    for num_workers in [0, 4]:
        dl = self._get_data_loader(ds, num_workers=num_workers, batch_size=None, sampler=sampler, pin_memory=TEST_CUDA)
        self.assertFalse(dl._auto_collation)
        samples = list(dl)
        self.assertEqual(len(samples), len(sampler))
        self.assertEqual(samples[0].is_pinned(), TEST_CUDA)
        self.assertEqual(sorted(torch.cat(samples, 0).tolist()), list(range(n)))
|  |  | 
def test_growing_dataset(self):
    # len(loader) must reflect items appended to the dataset after construction.
    dataset = [torch.ones(4) for _ in range(4)]
    loaders = [self._get_data_loader(dataset, shuffle=flag) for flag in (False, True)]
    dataset.append(torch.ones(4))
    for loader in loaders:
        self.assertEqual(len(loader), 5)
|  |  | 
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_sequential_pin_memory(self):
    # With pin_memory=True, both tensors of every batch come back pinned.
    pinned_loader = self._get_data_loader(self.dataset, batch_size=2, pin_memory=True)
    for batch in pinned_loader:
        inputs, targets = batch
        for tensor in (inputs, targets):
            self.assertTrue(tensor.is_pinned())
|  |  | 
@unittest.skipIf(IS_JETSON, "Not working on Jetson")
def test_multiple_dataloaders(self):
    # Two multi-worker iterators alive at once, consumed in interleaved order.
    for multiprocessing_context in supported_multiprocessing_contexts:
        first_it = iter(self._get_data_loader(self.dataset, num_workers=1))
        second_it = iter(self._get_data_loader(
            self.dataset, num_workers=2, multiprocessing_context=multiprocessing_context))
        next(first_it)
        next(first_it)
        next(second_it)
        next(second_it)
        next(first_it)
        next(second_it)
        del first_it, second_it
|  |  | 
def test_segfault(self):
    """A crashing DataLoader worker must surface as an error in the child process.

    ``_test_segfault`` (defined elsewhere; presumably makes a worker crash) runs
    in an ErrorTrackingProcess. That process must finish within JOIN_TIMEOUT
    with a non-zero exit code and a recorded exception. On Windows this is an
    access-violation OSError; elsewhere it is the DataLoader's "worker ... is
    killed by signal" RuntimeError.
    """
    p = ErrorTrackingProcess(target=_test_segfault)
    p.start()
    p.join(JOIN_TIMEOUT)
    try:
        # The child must have terminated (not hung) and failed.
        self.assertFalse(p.is_alive())
        self.assertNotEqual(p.exitcode, 0)
        if IS_WINDOWS:
            self.assertIsInstance(p.exception, OSError)
            self.assertRegex(str(p.exception), r'access violation reading ')
        else:
            self.assertIsInstance(p.exception, RuntimeError)
            self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
    finally:
        # Always reap the child, even if an assertion above failed.
        p.terminate()
|  |  | 
# Tests whether a child process forked by the DataLoader segfaults because the
# parent had more than 3 threads after at least one set_num_threads call.
# After forking, set_num_threads(1) in the child has to handle data structures
# inherited from the parent's Caffe2 thread pool, which led to a segfault.
# Reference: https://github.com/pytorch/pytorch/issues/54752
@unittest.skipIf(IS_WINDOWS, "Needs fork")
def test_no_segfault(self):
    """``_test_no_segfault`` must finish without any recorded exception.

    Any exception recorded by the child is checked to be the "worker killed
    by signal" RuntimeError and then reported as a test failure.
    """
    p = ErrorTrackingProcess(target=_test_no_segfault)
    p.start()
    p.join(JOIN_TIMEOUT)
    try:
        self.assertFalse(p.is_alive())
        if p.exception:
            self.assertIsInstance(p.exception, RuntimeError)
            self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
            self.fail("Segfault occurred in worker process after fork")
    finally:
        # Always reap the child, even if an assertion above failed.
        p.terminate()
|  |  | 
def test_timeout(self):
    """A DataLoader timeout must fail the child with "timed out after N seconds".

    Each target runs in its own ErrorTrackingProcess, receiving
    ``self.persistent_workers`` as its only argument. The pin-memory variant
    is included only when CUDA is available and spawn is supported.
    """
    if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
        # This test runs in a subprocess, which can only initialize CUDA with spawn.
        # _test_timeout_pin_memory with pin_memory=True initializes CUDA when the iterator is
        # constructed.
        targets = (_test_timeout, _test_timeout_pin_memory)
    else:
        targets = (_test_timeout,)
    for target in targets:
        p = ErrorTrackingProcess(target=target, args=(self.persistent_workers,))
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            self.assertIsInstance(p.exception, RuntimeError)
            self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
        finally:
            # Always reap the child, even if an assertion above failed.
            p.terminate()
|  |  | 
def test_large_sampler_indices(self):
    # Test that the data loader exits cleanly when the process errors while
    #   1. holding a reference to the iterator, and
    #   2. using a sampler that yields big elements, so the _index_queues putters block.
    #
    # More context: https://github.com/pytorch/pytorch/issues/48666

    p = ErrorTrackingProcess(target=_test_large_sampler_indices, args=(self.persistent_workers,))
    p.start()
    p.join(JOIN_TIMEOUT)
    try:
        # The child must exit (not hang) with the RuntimeError it raised itself.
        self.assertFalse(p.is_alive())
        self.assertNotEqual(p.exitcode, 0)
        self.assertIsInstance(p.exception, RuntimeError)
        self.assertRegex(str(p.exception), r'My Error')
    finally:
        # Always reap the child, even if an assertion above failed.
        p.terminate()
|  |  | 
def test_invalid_ctor_args_combinations(self):
    """Invalid DataLoader constructor argument combinations must be rejected.

    Each case is ``(exception type, message regex, dataset, kwargs)``.
    Constructing a loader from it must raise that exception with a message
    matching the regex.
    """
    map_dataset = self.dataset
    valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1]
    sampler = torch.utils.data.SequentialSampler(map_dataset)
    batch_sampler = torch.utils.data.BatchSampler(sampler, 3, False)
    iterable_dataset = CountingIterableDataset(20)
    iterable_sampler = torch.utils.data.SequentialSampler(iterable_dataset)
    iterable_batch_sampler = torch.utils.data.BatchSampler(
        torch.utils.data.SequentialSampler(iterable_dataset), 3, False)

    shuffle_msg = "sampler option is mutually exclusive with shuffle"
    batch_sampler_msg = "batch_sampler option is mutually exclusive with"
    iter_shuffle_msg = "DataLoader with IterableDataset: expected unspecified shuffle"
    iter_sampler_msg = "DataLoader with IterableDataset: expected unspecified sampler"
    iter_batch_sampler_msg = "DataLoader with IterableDataset: expected unspecified batch_sampler"

    cases = [
        # general
        (ValueError, "num_workers option should be non-negative", map_dataset, dict(num_workers=-1)),
        (ValueError, "timeout option should be non-negative", map_dataset, dict(timeout=-1)),
        # disable auto-batching
        (ValueError, "batch_size=None option disables auto-batching and is mutually exclusive",
         map_dataset, dict(batch_size=None, drop_last=True)),
        (ValueError, r"multi-process loading \(num_workers > 0\), but got",
         map_dataset, dict(num_workers=0, multiprocessing_context=valid_ctx)),
        (ValueError, "should specify a valid start method in",
         map_dataset, dict(num_workers=1, multiprocessing_context='bad')),
        (TypeError, "multiprocessing_context option should be a valid context ",
         map_dataset, dict(num_workers=1, multiprocessing_context=object())),
        # map-style
        (ValueError, shuffle_msg, map_dataset, dict(batch_size=11, sampler=sampler, shuffle=True)),
        (ValueError, shuffle_msg, map_dataset,
         dict(batch_sampler=batch_sampler, sampler=sampler, shuffle=True)),
        (ValueError, shuffle_msg, map_dataset,
         dict(batch_sampler=batch_sampler, sampler=sampler, shuffle=3)),
        (ValueError, batch_sampler_msg, map_dataset, dict(batch_size=11, batch_sampler=batch_sampler)),
        (ValueError, batch_sampler_msg, map_dataset, dict(shuffle=True, batch_sampler=batch_sampler)),
        (ValueError, batch_sampler_msg, map_dataset, dict(drop_last=True, batch_sampler=batch_sampler)),
        (ValueError, batch_sampler_msg, map_dataset, dict(drop_last=3, batch_sampler=batch_sampler)),
        # iterable-style
        (ValueError, iter_shuffle_msg, iterable_dataset, dict(shuffle=True)),
        (ValueError, iter_shuffle_msg, iterable_dataset, dict(shuffle=3)),
        (ValueError, iter_sampler_msg, iterable_dataset, dict(sampler=iterable_sampler)),
        (ValueError, iter_sampler_msg, iterable_dataset, dict(sampler=3)),
        (ValueError, iter_batch_sampler_msg, iterable_dataset, dict(batch_sampler=iterable_batch_sampler)),
        (ValueError, iter_batch_sampler_msg, iterable_dataset, dict(batch_sampler=3)),
    ]
    for exc_type, pattern, dataset, kwargs in cases:
        with self.assertRaisesRegex(exc_type, pattern):
            self._get_data_loader(dataset, **kwargs)
|  |  | 
def test_builtin_collection_conversion(self):
    """list()/tuple() over a DataLoader materializes the expected items.

    Covers map-style and iterable-style datasets, with and without
    auto-batching, for 0 and 1 workers.
    """
    for coll_ty, num_workers in itertools.product((list, tuple), (0, 1)):
        # CountingIterableDataset is not split across workers, so the batched
        # equality below only holds with at most one worker.
        assert num_workers in [0, 1], "invalid test"
        expected_items = coll_ty(range(20))
        expected_batches = coll_ty(torch.tensor([i, i + 1]) for i in range(0, 20, 2))
        for dataset_cls in (CountingDataset, CountingIterableDataset):
            dataset = dataset_cls(20)
            # no auto-batching
            fetched = coll_ty(self._get_data_loader(dataset, batch_size=None, num_workers=num_workers))
            self.assertEqual(fetched, expected_items)
            # auto-batching
            fetched = coll_ty(self._get_data_loader(dataset, batch_size=2, num_workers=num_workers))
            self.assertEqual(fetched, expected_batches)
|  |  | 
def test_iterable_style_dataset(self):
    """Exercise DataLoader over IterableDatasets in every batching/worker mode.

    There are three sections; each runs single-process, then with 3 workers
    for every prefetch_factor in [2, 3, 4]:
      * no auto-batching (``batch_size=None``): ints are not turned into tensors;
      * auto-batching (``batch_size=7``);
      * auto-batching with ``drop_last=True``.
    The no-auto-batching section also checks the length warning described in
    NOTE [ IterableDataset and __len__ ]. After each multi-worker section, the
    workers of the last iterator are joined and must exit with code 0.
    """
    num_workers = 3
    # Number of samples WorkerSpecificIterableDataset yields from each worker.
    sizes_for_all_workers = [0, 4, 20]
    assert len(sizes_for_all_workers) == num_workers, 'invalid test case'

    # [no auto-batching] single process loading
    dataset = CountingIterableDataset(20)
    dataloader = self._get_data_loader(dataset, batch_size=None)
    fetched = list(dataloader)
    self.assertEqual(len(fetched), 20)
    for i, d in enumerate(fetched):
        # non-batched should not convert ints into tensors
        self.assertIsInstance(d, int)
        self.assertEqual(d, i)
    # DataLoader should match len of the iterable-style dataset (if implemented)
    self.assertEqual(len(dataloader), len(dataset))

    # [no auto-batching] multiprocessing loading
    expected = sorted(sum((list(range(s)) for s in sizes_for_all_workers), []))
    for prefetch_factor in [2, 3, 4]:
        dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
        dataloader = self._get_data_loader(dataset, num_workers=num_workers, batch_size=None,
                                           worker_init_fn=set_faulthander_if_available,
                                           prefetch_factor=prefetch_factor)
        dataloader_iter = iter(dataloader)
        fetched = sorted(dataloader_iter)
        # zip() below stops at the shorter input, so pin the count first.
        self.assertEqual(len(fetched), len(expected))
        for a, b in zip(fetched, expected):
            # non-batched should not convert ints into tensors
            self.assertIsInstance(a, int)
            self.assertEqual(a, b)
        # DataLoader should match len of the iterable-style dataset (if implemented)
        self.assertEqual(len(dataloader), len(dataset))
        # When loading more than len(dataset) data, after accessing len(dataloader),
        # we should get a warning. See NOTE [ IterableDataset and __len__ ].
        dataset = CountingIterableDataset(20)
        dataloader = self._get_data_loader(dataset, num_workers=num_workers,
                                           worker_init_fn=set_faulthander_if_available,
                                           prefetch_factor=prefetch_factor)
        it = iter(dataloader)
        for _ in range(40):
            self.assertNotWarn(lambda: next(it), "Should not warn before accessing len(dataloader)")
        self.assertEqual(len(dataloader), len(dataset))
        self.assertEqual(len(dataloader), 20)
        it = iter(dataloader)
        for _ in range(20):
            self.assertNotWarn(lambda: next(it), "Should not warn before exceeding length")
        for _ in range(3):
            with self.assertWarnsRegex(
                UserWarning,
                r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
                    msg="Should always warn after exceeding length"):
                next(it)
    # [no auto-batching] test that workers exit gracefully
    workers = dataloader_iter._workers
    del dataloader_iter
    del dataloader
    self._check_iterable_workers_exit_gracefully(workers)

    # [auto-batching] single process loading
    dataset = CountingIterableDataset(20)
    fetched = list(self._get_data_loader(dataset, batch_size=7))
    self.assertEqual(len(fetched), 3)
    self.assertEqual(fetched[0].tolist(), list(range(7)))
    self.assertEqual(fetched[1].tolist(), list(range(7, 14)))
    self.assertEqual(fetched[2].tolist(), list(range(14, 20)))

    # [auto-batching] multiprocessing loading
    for prefetch_factor in [2, 3, 4]:
        dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
        # worker 0 should return 0 batches
        # worker 1 should return 1 batches
        # worker 2 should return 3 batches
        dataloader = self._get_data_loader(dataset, num_workers=num_workers, batch_size=7, prefetch_factor=prefetch_factor)
        dataloader_iter = iter(dataloader)
        fetched = list(dataloader_iter)
        self.assertEqual(len(fetched), 4)
        fetched = {tuple(t.tolist()) for t in fetched}
        self.assertEqual(fetched, {tuple(range(4)), tuple(range(7)), tuple(range(7, 14)), tuple(range(14, 20))})

    # [auto-batching] test that workers exit gracefully
    workers = dataloader_iter._workers
    del dataloader_iter
    del dataloader
    self._check_iterable_workers_exit_gracefully(workers)

    # [auto-batching & drop_last] single process loading
    dataset = CountingIterableDataset(20)
    fetched = list(self._get_data_loader(dataset, batch_size=7, drop_last=True))
    self.assertEqual(len(fetched), 2)
    self.assertEqual(fetched[0].tolist(), list(range(7)))
    self.assertEqual(fetched[1].tolist(), list(range(7, 14)))

    # [auto-batching & drop_last] multiprocessing loading
    for prefetch_factor in [2, 3, 4]:
        dataset = WorkerSpecificIterableDataset(sizes_for_all_workers)
        # worker 0 should return 0 batches
        # worker 1 should return 1 batches
        # worker 2 should return 3 batches
        dataloader = self._get_data_loader(dataset, num_workers=num_workers, batch_size=7, drop_last=True,
                                           worker_init_fn=set_faulthander_if_available,
                                           prefetch_factor=prefetch_factor)
        dataloader_iter = iter(dataloader)
        fetched = list(dataloader_iter)
        self.assertEqual(len(fetched), 2)
        fetched = {tuple(t.tolist()) for t in fetched}
        self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})

    # [auto-batching & drop_last] test that workers exit gracefully
    workers = dataloader_iter._workers
    del dataloader_iter
    del dataloader
    self._check_iterable_workers_exit_gracefully(workers)

def _check_iterable_workers_exit_gracefully(self, workers):
    """Join every worker and assert it exited with code 0.

    All workers are terminated afterwards, even if an assertion failed.
    """
    try:
        for w in workers:
            w.join(JOIN_TIMEOUT)
            self.assertFalse(w.is_alive())
            self.assertEqual(w.exitcode, 0)
    finally:
        for w in workers:
            w.terminate()
|  |  | 
def test_chain_iterable_style_dataset(self):
    """Concatenating IterableDatasets (``+`` or ChainDataset) yields them back to back.

    Chaining in a map-style dataset fails once iteration reaches it.
    """
    first = CountingIterableDataset(20)
    second = CountingIterableDataset(15)
    expected = [*range(20), *range(15)]
    for num_workers in (0, 1):
        for chained in (first + second, ChainDataset([first, second])):
            fetched = list(self._get_data_loader(chained, num_workers=num_workers))
            self.assertEqual(len(fetched), len(expected))
            for want, got in zip(expected, fetched):
                self.assertIsInstance(got, torch.Tensor)
                self.assertEqual(want, got)

    with self.assertRaisesRegex(AssertionError, "ChainDataset only supports IterableDataset"):
        list(iter(first + self.dataset))

    with self.assertRaisesRegex(AssertionError, "ChainDataset only supports IterableDataset"):
        list(iter(ChainDataset([first, self.dataset])))
|  |  | 
@unittest.skipIf(IS_MACOS or IS_JETSON, "Not working on macos or Jetson")
@skipIfRocm  # https://github.com/pytorch/pytorch/issues/90940
def test_multiprocessing_contexts(self):
    """Loading gives the same batches under every supported multiprocessing context.

    Each context is passed by name and, when it is not None, also as a context
    object from ``mp.get_context``.
    """
    reference = [
        torch.arange(3),
        torch.arange(3, 6),
        torch.arange(6, 9),
        torch.arange(9, 11),
    ]
    counting_ds_n = 11
    dl_common_args = dict(num_workers=3, batch_size=3, pin_memory=(not TEST_CUDA))
    for ctx in supported_multiprocessing_contexts:
        # Use the CUDA dataset only with spawn/forkserver. NOTE(review): presumably
        # because a forked child cannot initialize CUDA. Windows and Jetson devices
        # don't support sharing CUDA tensors. ROCm is skipped by the decorator above.
        if ctx in ['spawn', 'forkserver'] and TEST_CUDA and not IS_WINDOWS and not IS_JETSON:
            ds_cls = CUDACountingDataset
        else:
            ds_cls = CountingDataset
        self.assertEqual(
            reference, list(self._get_data_loader(ds_cls(counting_ds_n), multiprocessing_context=ctx, **dl_common_args)))
        if ctx is not None:
            # test ctx object
            ctx = mp.get_context(ctx)
            self.assertEqual(
                reference, list(self._get_data_loader(ds_cls(counting_ds_n), multiprocessing_context=ctx, **dl_common_args)))
|  |  | 
@skipIfNoNumpy
@unittest.skipIf(IS_JETSON, "Not working on Jetson")
def test_multiprocessing_iterdatapipe(self):
    """A datapipe built from module-level functions works with multi-process loading.

    Functions from global scope (as if imported from a library) must be
    serializable under every multiprocessing context. With dill available, a
    lambda filter is exercised instead of ``filter_len``.
    """
    expected = [torch.as_tensor([[2, 3, 4, 5]], dtype=torch.int64) for _ in range(2)]
    keep_rows_of_four = (lambda row: len(row) == 4) if HAS_DILL else filter_len
    datapipe: IterDataPipe = IterableWrapper([[1, 2, 3, 4], [1, 2, 3, 4, 5, 6]])
    datapipe = datapipe.map(row_processor).filter(keep_rows_of_four)

    loader_kwargs = dict(num_workers=2, batch_size=2, shuffle=True, pin_memory=(not TEST_CUDA))

    def load_as_int64(context):
        loader = self._get_data_loader(datapipe, multiprocessing_context=context, **loader_kwargs)
        return [t.type(torch.int64) for t in loader]

    for ctx in supported_multiprocessing_contexts:
        self.assertEqual(expected, load_as_int64(ctx))
        if ctx is not None:
            # A context object must be accepted as well as its name.
            self.assertEqual(expected, load_as_int64(mp.get_context(ctx)))
|  |  | 
def test_worker_seed(self):
    """Every worker must be seeded differently.

    SynchronizedSeedDataset (defined elsewhere) presumably returns the seed of
    the worker that produced each item, one item per worker. The set of seeds
    must therefore have ``num_workers`` distinct values.
    """
    num_workers = 6
    batch_size = 1
    dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
    dataloader = self._get_data_loader(dataset, batch_size=batch_size, num_workers=num_workers)
    # Convert to Python ints. Tensors hash by identity, so a set of tensors
    # would always hold one entry per batch, whatever the seed values.
    seeds = {int(batch[0]) for batch in dataloader}
    self.assertEqual(len(seeds), num_workers)
|  |  | 
def test_worker_seed_reproducibility(self):
    # Two loaders driven by identically seeded generators give workers the same seeds.
    num_workers = 6
    batch_size = 1
    dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)

    def collect_seeds():
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                            generator=torch.Generator().manual_seed(42))
        return {int(batch) for batch in loader}

    self.assertEqual(collect_seeds(), collect_seeds())
|  |  | 
def test_multi_epochs_reproducibility(self):
    # Batches alternate between workers in the same order on every epoch.
    num_workers = 2
    batch_size = 10
    num_epochs = 3

    dataset = TestMultiEpochDataset(batch_size * num_workers)
    dataloader = self._get_data_loader(dataset, batch_size=batch_size,
                                       shuffle=False, num_workers=num_workers)

    for _ in range(num_epochs):
        for batch_idx, sample in enumerate(dataloader):
            expected_worker = batch_idx % num_workers
            self.assertEqual(sample.tolist(), [expected_worker] * batch_size)
|  |  | 
def test_worker_init_fn(self):
    # init_fn reseeds every worker with 12345, so every item is expected to equal it.
    seeded_loader = self._get_data_loader(SeedDataset(4), batch_size=2, num_workers=2,
                                          worker_init_fn=init_fn)
    for batch in seeded_loader:
        for position in (0, 1):
            self.assertEqual(12345, batch[position])
|  |  | 
def test_get_worker_info(self):
    # Run `_test_get_worker_info` in a child process; it must finish within
    # JOIN_TIMEOUT and exit cleanly.
    child = ErrorTrackingProcess(target=_test_get_worker_info)
    child.start()
    child.join(JOIN_TIMEOUT)
    try:
        self.assertFalse(child.is_alive())
        self.assertEqual(child.exitcode, 0)
    finally:
        # Reap the child even when an assertion above fails.
        child.terminate()
|  |  | 
def test_shuffle(self):
    # Shuffled, unbatched-size-default loader over the shared dataset.
    loader = self._get_data_loader(self.dataset, shuffle=True)
    self._test_shuffle(loader)
|  |  | 
def test_shuffle_batch_none(self):
    # batch_size=None disables automatic batching; shuffling must still work.
    loader = DataLoader(self.dataset, batch_size=None, shuffle=True)
    self._test_shuffle(loader)
|  |  | 
def test_shuffle_batch(self):
    # Shuffled loader yielding batches of two.
    loader = self._get_data_loader(self.dataset, batch_size=2, shuffle=True)
    self._test_shuffle(loader)
|  |  | 
def test_shuffle_reproducibility(self):
    # With an explicitly seeded generator, two independently built shuffling loaders
    # must yield identical sequences, both in-process (0 workers) and with 2 workers.
    for workers in (0, 2):
        first, second = (
            list(DataLoader(self.dataset, shuffle=True, num_workers=workers,
                            generator=torch.Generator().manual_seed(42)))
            for _ in range(2)
        )
        self.assertEqual(first, second)
|  |  | 
def test_sequential_workers(self):
    # Non-shuffled loading must preserve order even with 4 workers.
    loader = self._get_data_loader(self.dataset, num_workers=4)
    self._test_sequential(loader)
|  |  | 
def test_seqential_batch_workers(self):
    # Ordered batches of two, produced by 4 workers.
    loader = self._get_data_loader(self.dataset, batch_size=2, num_workers=4)
    self._test_sequential(loader)
|  |  | 
def test_seqential_batch_workers_prefetch(self):
    # Same as the batched-workers case, but with a non-default prefetch_factor.
    loader = DataLoader(self.dataset, batch_size=2, num_workers=4, prefetch_factor=3)
    self._test_sequential(loader)
|  |  | 
def test_shuffle_workers(self):
    # Shuffling with 4 worker processes.
    loader = self._get_data_loader(self.dataset, shuffle=True, num_workers=4)
    self._test_shuffle(loader)
|  |  | 
def test_shuffle_batch_workers(self):
    # Shuffled batches of two with 4 worker processes.
    loader = self._get_data_loader(self.dataset, batch_size=2, shuffle=True, num_workers=4)
    self._test_shuffle(loader)
|  |  | 
def test_shuffle_batch_workers_prefetch(self):
    # Shuffled batches with 4 workers and a non-default prefetch_factor.
    loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, prefetch_factor=3)
    self._test_shuffle(loader)
|  |  | 
def test_random_sampler(self):
    """Check the index statistics produced by ``RandomSampler``.

    Covers sampling with replacement, without replacement (default, larger and
    smaller ``num_samples``), and rejection of a non-boolean ``replacement``.
    Specific assertion methods are used instead of ``assertTrue(a == b)`` so a
    failure reports the offending values.
    """
    from collections import Counter
    from torch.utils.data import RandomSampler

    dataset_len = len(self.dataset)

    def sample_stat(sampler):
        # Draw one pass and summarize it as
        # (#indices drawn more than once, smallest index, largest index, total draws).
        counts = Counter(sampler)
        count_repeated = sum(val > 1 for val in counts.values())
        return count_repeated, min(counts), max(counts), sum(counts.values())

    # With replacement: one more draw than there are items guarantees (pigeonhole)
    # that at least one index is drawn more than once.
    n = dataset_len + 1
    count_repeated, minval, maxval, count_total = sample_stat(
        RandomSampler(self.dataset, replacement=True, num_samples=n))
    self.assertGreater(count_repeated, 0)
    self.assertGreaterEqual(minval, 0)
    self.assertLess(maxval, dataset_len)
    self.assertEqual(count_total, n)

    # Without replacement, default num_samples: every index exactly once.
    count_repeated, minval, maxval, count_total = sample_stat(RandomSampler(self.dataset))
    self.assertEqual(count_repeated, 0)
    self.assertEqual(minval, 0)
    self.assertEqual(maxval, dataset_len - 1)
    self.assertEqual(count_total, dataset_len)

    # Without replacement, num_samples = 2 * len: every index must be repeated.
    n = dataset_len * 2
    count_repeated, minval, maxval, count_total = sample_stat(
        RandomSampler(self.dataset, num_samples=n))
    self.assertEqual(count_repeated, dataset_len)
    self.assertEqual(minval, 0)
    self.assertEqual(maxval, dataset_len - 1)
    self.assertEqual(count_total, n)

    # Without replacement, num_samples = len - 1: no index may repeat.
    n = dataset_len - 1
    count_repeated, minval, maxval, count_total = sample_stat(
        RandomSampler(self.dataset, num_samples=n))
    self.assertEqual(count_repeated, 0)
    self.assertGreaterEqual(minval, 0)
    self.assertLess(maxval, dataset_len)
    self.assertEqual(count_total, n)

    # Without replacement, num_samples = len + 1: exactly one index repeats.
    n = dataset_len + 1
    count_repeated, minval, maxval, count_total = sample_stat(
        RandomSampler(self.dataset, num_samples=n))
    self.assertEqual(count_repeated, 1)
    self.assertEqual(minval, 0)
    self.assertEqual(maxval, dataset_len - 1)
    self.assertEqual(count_total, n)

    # raise error when replacement is non-boolean
    with self.assertRaisesRegex(TypeError, "replacement should be a boolean value, but got replacement=0"):
        RandomSampler(self.dataset, replacement=0)
|  |  | 
def test_random_sampler_len_with_replacement(self):
    from torch.utils.data import RandomSampler

    # Request 5 more samples than the dataset holds (allowed with replacement).
    num_samples = len(self.dataset) + 5
    sampler = RandomSampler(self.dataset, replacement=True, num_samples=num_samples)

    # __len__ and actual iteration must both report num_samples.
    self.assertEqual(num_samples, len(sampler))
    self.assertEqual(num_samples, sum(1 for _ in sampler))

    # A DataLoader over this sampler has ceil(num_samples / batch_size) batches.
    for batch_size in (1, 6):
        loader = self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
        self.assertEqual(int(math.ceil(float(num_samples) / batch_size)), len(loader))
|  |  | 
def test_random_sampler_len_without_replacement(self):
    from torch.utils.data import RandomSampler

    # Request 5 more samples than the dataset holds, without replacement.
    num_samples = len(self.dataset) + 5
    sampler = RandomSampler(self.dataset, replacement=False, num_samples=num_samples)

    # __len__ and actual iteration must both report num_samples.
    self.assertEqual(num_samples, len(sampler))
    self.assertEqual(num_samples, sum(1 for _ in sampler))

    # A DataLoader over this sampler has ceil(num_samples / batch_size) batches
    # (integer ceiling division).
    for batch_size in (1, 6):
        loader = self._get_data_loader(self.dataset, batch_size=batch_size, sampler=sampler)
        self.assertEqual((num_samples + batch_size - 1) // batch_size, len(loader))
|  |  | 
def test_distributed_sampler_invalid_rank(self):
    from torch.utils.data.distributed import DistributedSampler
    dataset = torch.IntTensor(range(10))
    # With num_replicas=3, a rank of 3 and a negative rank must both be rejected.
    for bad_rank in (3, -1):
        with self.assertRaisesRegex(ValueError, "Invalid rank"):
            DistributedSampler(dataset, 3, bad_rank)
|  |  | 
def test_duplicating_data_with_drop_last(self):
    # With drop_last=True, the batches drawn across all DistributedSampler ranks
    # must not contain any element twice.
    from torch.utils.data.distributed import DistributedSampler

    num_processes = 4
    num_batches = 9  # also the dataset length
    data_set = torch.IntTensor(range(num_batches))
    batch_size = num_batches // num_processes

    collected = []
    for rank in range(num_processes):
        sampler = DistributedSampler(data_set, num_processes, rank)
        loader = self._get_data_loader(data_set, batch_size=batch_size, drop_last=True, sampler=sampler)
        collected.extend(loader)

    scanned_data = torch.cat([torch.IntTensor([]), *collected], 0)
    self.assertEqual(scanned_data.size(), scanned_data.unique().size())
|  |  | 
def test_sampler_reproducibility(self):
    # Samplers must be deterministic given either an explicit generator or the global seed.
    from torch.utils.data import RandomSampler, WeightedRandomSampler, SubsetRandomSampler

    weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]

    def seeded():
        return torch.Generator().manual_seed(42)

    # 1) Explicit generator: two freshly built samplers with the same seed agree.
    factories = (
        lambda: RandomSampler(self.dataset, num_samples=5, replacement=True, generator=seeded()),
        lambda: RandomSampler(self.dataset, replacement=False, generator=seeded()),
        lambda: WeightedRandomSampler(weights, num_samples=5, replacement=True, generator=seeded()),
        lambda: WeightedRandomSampler(weights, num_samples=5, replacement=False, generator=seeded()),
        lambda: SubsetRandomSampler(range(10), generator=seeded()),
    )
    for make_sampler in factories:
        self.assertEqual(list(make_sampler()), list(make_sampler()))

    # 2) No generator: reseeding the global RNG must replay the samplers.
    samplers = (
        RandomSampler(self.dataset, num_samples=5, replacement=True),
        RandomSampler(self.dataset, replacement=False),
        WeightedRandomSampler(weights, num_samples=5, replacement=True),
        WeightedRandomSampler(weights, num_samples=5, replacement=False),
        SubsetRandomSampler(range(10)),
    )
    for sampler in samplers:
        torch.manual_seed(0)
        run_a = list(sampler) + list(sampler)
        torch.manual_seed(0)
        run_b = list(sampler) + list(sampler)
        self.assertEqual(run_a, run_b)

        # Two live iterators with interleaved next() calls; each draws its first
        # element right after a reseed, so both streams must match.
        iter_a, iter_b = iter(sampler), iter(sampler)
        seq_a, seq_b = [], []
        for step in range(len(sampler)):
            if step == 0:
                torch.manual_seed(0)
            seq_a.append(next(iter_a))
            if step == 0:
                torch.manual_seed(0)
            seq_b.append(next(iter_b))
        self.assertEqual(seq_a, seq_b)
|  |  | 
def _test_sampler(self, **kwargs):
    # A plain `range` serves as the sampler, selecting dataset indices 2..11.
    loader = self._get_data_loader(self.dataset, sampler=range(2, 12), batch_size=2, **kwargs)
    self.assertEqual(len(loader), 5)
    for batch_idx, (inputs, _target) in enumerate(loader):
        start = 2 + 2 * batch_idx
        self.assertEqual(len(inputs), 2)
        self.assertEqual(inputs, self.data[start:start + 2])
|  |  | 
def test_sampler(self):
    """Run ``_test_sampler`` in-process, with 4 workers, and (when available) with
    4 workers started via the 'spawn' multiprocessing context."""
    self._test_sampler()
    self._test_sampler(num_workers=4)
    if not NO_MULTIPROCESSING_SPAWN:
        # Exercise the plain-sampler path under 'spawn' as well. This must call
        # `_test_sampler`, not `_test_batch_sampler`; the latter is already
        # covered by `test_batch_sampler`.
        self._test_sampler(num_workers=4, multiprocessing_context='spawn')
|  |  | 
def _test_batch_sampler(self, **kwargs):
    # A plain list of index tuples serves as the batch_sampler. Batches alternate
    # between 2 and 3 consecutive indices: [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
    batches = [
        tuple(range(lo, hi))
        for start in range(0, 20, 5)
        for lo, hi in ((start, start + 2), (start + 2, start + 5))
    ]

    loader = self._get_data_loader(self.dataset, batch_sampler=batches, **kwargs)
    self.assertEqual(len(loader), 8)
    for batch_idx, (inputs, _target) in enumerate(loader):
        # Each pair of batches spans 5 indices, so batch k starts at floor(5k / 2).
        start = batch_idx * 5 // 2
        width = 3 if batch_idx % 2 else 2
        self.assertEqual(len(inputs), width)
        self.assertEqual(inputs, self.data[start:start + width])
|  |  | 
def test_batch_sampler(self):
    # In-process, with 4 workers, and (if supported) with 4 'spawn'-started workers.
    configs = [{}, {'num_workers': 4}]
    if not NO_MULTIPROCESSING_SPAWN:
        configs.append({'num_workers': 4, 'multiprocessing_context': 'spawn'})
    for config in configs:
        self._test_batch_sampler(**config)
|  |  | 
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_shuffle_pin_memory(self):
    # With pin_memory=True, every yielded input and target must be pinned.
    loader = self._get_data_loader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
    for batch in loader:
        inputs, targets = batch
        self.assertTrue(inputs.is_pinned())
        self.assertTrue(targets.is_pinned())
|  |  | 
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_numpy(self):
    import numpy as np

    class OnesDataset(torch.utils.data.Dataset):
        # 1000 items; item i is an all-ones (2, 3, 4) array scaled by i.
        def __len__(self):
            return 1000

        def __getitem__(self, i):
            return np.ones((2, 3, 4)) * i

    first_batch = next(iter(self._get_data_loader(OnesDataset(), batch_size=12)))
    # Default collation stacks the float64 arrays into a single DoubleTensor.
    self.assertIsInstance(first_batch, torch.DoubleTensor)
    self.assertEqual(first_batch.size(), torch.Size([12, 2, 3, 4]))
|  |  | 
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_numpy_gen_state(self):
    from torch.utils.data._utils.worker import _generate_state

    # Reference states were generated with NumPy; `_generate_state` must match them.
    # Each row: (worker_id, base_seed, expected_state).
    reference = (
        (4, 13434589827475259383, (2884386318, 1088094898, 3523808998, 3860348662)),
        (1, 15014285634777110771, (1934848465, 763213760, 2959016433, 179751970)),
        (10, 978296274032934101, (1759791917, 3550927336, 1225977135, 1036538043)),
        (12, 11868770762134256968, (3974661794, 3331131333, 3630387033, 2885815368)),
        (9, 15378787925219019706, (3815056996, 3162224466, 2735102421, 3190253477)),
        (5, 9055612723125076328, (3522565701, 3368424109, 959377806, 621878693)),
        (15, 14617792358407278405, (3402479508, 1588702753, 1169536393, 3675067356)),
        (9, 17363320784006640087, (957989458, 2518334477, 1421725660, 3086155459)),
        (12, 480002904169484764, (2732851467, 1762620729, 4055801988, 1277640511)),
        (15, 16803975943592702950, (3479415043, 4022359553, 295994005, 3358606349)),
        (9, 11704776406047813044, (1968928009, 710113752, 2442656196, 1587420279)),
        (10, 16357891985431864516, (1271733898, 4197047399, 3727213786, 2338547348)),
        (2, 17423369006318065007, (544294336, 1911284083, 3299147734, 3231058347)),
        (2, 2889492011444113593, (3721591783, 2595811276, 2212881745, 977682627)),
        (0, 8979703111668486195, (4276723937, 2556068849, 2962827292, 233130238)),
        (6, 6269787272229682235, (2548857855, 1216457374, 1012973562, 2999759647)),
    )

    for worker_id, base_seed, expected_state in reference:
        self.assertEqual(expected_state, _generate_state(base_seed, worker_id))
|  |  | 
def test_error(self):
    # Errors raised by the dataset must propagate from an in-process loader.
    loader = self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True)
    self._test_error(loader)
|  |  | 
def test_error_workers(self):
    # Errors raised by the dataset inside 4 workers must propagate to the caller.
    loader = self._get_data_loader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4)
    self._test_error(loader)
|  |  | 
@unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
def test_partial_workers(self):
    r"""Check that workers exit even if the iterator is not exhausted."""
    # pin_memory=True is only exercised when CUDA is available.
    pin_memory_configs = (True, False) if TEST_CUDA else (False,)

    for pin_memory in pin_memory_configs:
        loader = iter(self._get_data_loader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
        # Keep handles to the workers (and pin-memory thread) so they can be joined
        # after the iterator itself has been dropped.
        workers = loader._workers
        if pin_memory:
            pin_memory_thread = loader._pin_memory_thread
        for i, _ in enumerate(loader):
            if i == 10:
                break
        # A bare `assert` is stripped under `python -O`; use the TestCase assertion.
        self.assertEqual(i, 10)
        # Drop the partially consumed iterator; its workers must then shut down.
        del loader
        for w in workers:
            w.join(JOIN_TIMEOUT)
            self.assertFalse(w.is_alive(), 'subprocess not terminated')
        if pin_memory:
            pin_memory_thread.join(JOIN_TIMEOUT)
            self.assertFalse(pin_memory_thread.is_alive())
|  |  | 
# Takes 2.5min to finish, see https://github.com/pytorch/pytorch/issues/46065
@skipIfRocm
@unittest.skipIf(not HAS_PSUTIL, "psutil not found")
@slowTest
def test_proper_exit(self):
    (r'''There might be ConnectionResetError or leaked semaphore warning '''
     r'''(due to dirty process exit), but they are all safe to ignore''')

    # For every combination of the four boolean flags below and every applicable
    # `exit_method`, run `_test_proper_exit` in a separate loader process.
    # Then check two things:
    #   - the loader and all of its child (worker) processes terminate within the
    #     join timeouts;
    #   - the loader's exit code / exception is what that exit method expects.

    # TODO: test the case where the pin_memory_thread triggers an
    #       error/fatal signal. I haven't found out how to properly do that.

    for is_iterable_dataset, use_workers, pin_memory, hold_iter_reference in \
            itertools.product([True, False], repeat=4):

        # `hold_iter_reference` specifies whether we hold a reference to the
        # iterator. This is interesting because Python3 error traces hold a
        # reference to the frames, which hold references to all the local
        # variables including the iterator, and then the iterator dtor may
        # not be called before process end. It is important to see that the
        # processes still exit in both cases.

        if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN or IS_WINDOWS):
            # This test runs in a subprocess, which can only initialize CUDA with spawn.
            # DataLoader with pin_memory=True initializes CUDA when its iterator is constructed.
            # For windows, pin_memory sometimes causes CUDA oom.
            continue

        # `exit_method` controls the way the loader process ends.
        #   - `*_kill` means that `*` is killed by OS.
        #   - `*_error` means that `*` raises an error.
        #   - `None` means that no error happens.
        # In all cases, all processes should end properly.
        if use_workers:
            # TODO: Fix test for 'loader_kill' that would cause running out of shared memory.
            # Killing loader process would prevent DataLoader iterator clean up all queues
            # and worker processes
            exit_methods = [None, 'loader_error', 'worker_error', 'worker_kill']
            persistent_workers = self.persistent_workers
        else:
            exit_methods = [None, 'loader_error', 'loader_kill']
            persistent_workers = False

        for exit_method in exit_methods:
            if exit_method == 'worker_kill':
                # FIXME: This sometimes hangs. See #16608.
                continue

            # Human-readable description of this configuration; prefixed to every
            # failure message below.
            desc = []
            desc.append(f'is_iterable_dataset={is_iterable_dataset}')
            desc.append(f'use_workers={use_workers}')
            desc.append(f'pin_memory={pin_memory}')
            desc.append(f'hold_iter_reference={hold_iter_reference}')
            desc.append(f'exit_method={exit_method}')
            desc = 'test_proper_exit with ' + ', '.join(desc)

            # Event that the loader process uses to signal testing process
            # that various things are setup, including that the worker pids
            # are specified in `worker_pids` array.
            # NOTE(review): no `worker_pids` array is passed to the loader here;
            # worker processes are discovered via psutil `children()` below. That
            # part of the comment may be stale.
            loader_setup_event = mp.Event()

            # Event that this process has finished setting up, and the
            # loader process can now proceed to trigger error events or
            # finish normally.
            tester_setup_event = mp.Event()

            loader_p = ErrorTrackingProcess(target=_test_proper_exit,
                                            args=(is_iterable_dataset, use_workers, pin_memory,
                                                  exit_method, hold_iter_reference,
                                                  loader_setup_event, tester_setup_event,
                                                  persistent_workers),
                                            disable_stderr=False)
            loader_p.start()
            loader_psutil_p = psutil.Process(loader_p.pid)

            # Wait for loader process to set everything up, e.g., starting
            # workers.
            loader_setup_event.wait(timeout=JOIN_TIMEOUT)
            if not loader_setup_event.is_set():
                # Setup timed out: describe the loader's state (exception, exit
                # code, or still alive) and fail immediately.
                fail_msg = desc + ': loader process failed to setup within given time'
                if loader_p.exception is not None:
                    fail_msg += f', and had exception {loader_p.exception}'
                elif not loader_p.is_alive():
                    fail_msg += f', and exited with code {loader_p.exitcode} but had no exception'
                else:
                    fail_msg += ', and is still alive.'
                if loader_p.is_alive():
                    # this may kill the process, needs to run after the above lines
                    loader_p.print_traces_of_all_threads()
                self.fail(fail_msg)

            # We are certain that the workers have started now.
            # Snapshot the loader's children so they can still be inspected and
            # waited on after the loader itself exits.
            worker_psutil_ps = loader_psutil_p.children()

            def fail(reason):
                # Fail the test with `reason`, appending psutil diagnostics for the
                # loader process and, when workers are used, for each worker in the
                # snapshot above.
                report_psutil_attrs = ['pid', 'name', 'cpu_times', 'io_counters',
                                       'memory_full_info', 'num_ctx_switches',
                                       'open_files', 'threads', 'status',
                                       'nice', 'ionice']
                if reason is None:
                    err_msg = desc
                else:
                    err_msg = f'{desc}: {reason}'
                err_msg += '\nLoader info:\n\t'
                if loader_psutil_p.is_running():
                    err_msg += str(loader_psutil_p.as_dict(attrs=report_psutil_attrs))
                    # this may kill the process, needs to run after the above line
                    loader_p.print_traces_of_all_threads()
                else:
                    err_msg += f'exited with code {loader_p.exitcode}'
                if use_workers:
                    err_msg += '\nWorker(s) info:'
                    for idx, worker_psutil_p in enumerate(worker_psutil_ps):
                        err_msg += f'\n\tWorker {idx}:\n\t\t'
                        if worker_psutil_p.is_running():
                            err_msg += str(worker_psutil_p.as_dict(attrs=report_psutil_attrs))
                            # this may kill the process, needs to run after the above line
                            print_traces_of_all_threads(worker_psutil_p.pid)
                        else:
                            err_msg += 'exited with unknown code'
                self.fail(err_msg)

            # Let the loader process proceed to its configured exit path.
            tester_setup_event.set()

            try:
                # Give the loader JOIN_TIMEOUT plus one status-check interval to exit.
                loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
                if loader_p.is_alive():
                    fail_reason = 'loader process did not terminate'
                    if loader_p.exception is not None:
                        fail(fail_reason + f', and had exception {loader_p.exception}')
                    else:
                        fail(fail_reason + ', and had no exception')
                # Every worker process must also be gone.
                _, alive = psutil.wait_procs(worker_psutil_ps, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
                if len(alive) > 0:
                    fail('worker process (pid(s) {}) did not terminate'.format(
                        ', '.join(str(p.pid) for p in alive)))
                # Check the loader's exit code / exception against the expectation
                # for this exit method.
                if exit_method is None:
                    if loader_p.exitcode != 0:
                        fail(f'loader process had nonzero exitcode {loader_p.exitcode}')
                else:
                    if loader_p.exitcode == 0:
                        fail('loader process had zero exitcode')
                    if exit_method == 'loader_error':
                        if not isinstance(loader_p.exception, RuntimeError) or \
                                'Loader error' not in str(loader_p.exception):
                            fail(f'loader process did not raise expected exception, but had {loader_p.exception}')
                    elif exit_method == 'worker_kill':
                        if isinstance(loader_p.exception, RuntimeError):
                            if 'DataLoader worker (pid' not in str(loader_p.exception):
                                fail('loader process did not raise expected exception, but had {}'.format(
                                    loader_p.exception))
                        elif isinstance(loader_p.exception, ConnectionRefusedError):
                            # Sometimes, when the worker is being killed and is freeing its
                            # resources, the unpickling in the loader process may hit a
                            # `ConnectionRefusedError` because it cannot open a socket to receive
                            # the resource. In such cases, the worker may not have fully exited,
                            # and the loader can't know this via `is_alive` check or `SIGCHLD`
                            # handler. So we permit this as an allowed error as well.
                            # After all, we are happy as long as it terminates.
                            pass
                        else:
                            fail(f'loader process did not raise expected exception, but had {loader_p.exception}')
                    elif exit_method == 'worker_error':
                        if not isinstance(loader_p.exception, RuntimeError) or \
                                'Worker error' not in str(loader_p.exception):
                            fail(f'loader process did not raise expected exception, but had {loader_p.exception}')
            finally:
                # Never leave the loader process running, whatever the outcome above.
                loader_p.terminate()
|  |  | 
def test_len(self):
    def check_len(dl, expected):
        # Both the reported length and the number of items actually iterated must match.
        self.assertEqual(len(dl), expected)
        self.assertEqual(sum(1 for _ in dl), expected)

    check_len(self.dataset, 100)
    # 100 items -> 50 batches of 2, and 34 batches of 3 (last one partial).
    for batch_size, expected in ((2, 50), (3, 34)):
        check_len(self._get_data_loader(self.dataset, batch_size=batch_size), expected)
|  |  | 
def test_iterabledataset_len(self):
    class SizedIterableDataset(torch.utils.data.IterableDataset):
        # Yields 0..9 and also reports a length of 10.
        def __len__(self):
            return 10

        def __iter__(self):
            return iter(range(10))

    # (batch_size, expected len, expected len with drop_last=True)
    expectations = ((1, 10, 10), (2, 5, 5), (3, 4, 3))
    for batch_size, expected, expected_drop_last in expectations:
        loader = DataLoader(SizedIterableDataset(), batch_size=batch_size)
        self.assertEqual(len(loader), expected)
        loader = DataLoader(SizedIterableDataset(), batch_size=batch_size, drop_last=True)
        self.assertEqual(len(loader), expected_drop_last)
|  |  | 
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_numpy_scalars(self):
    import numpy as np

    class ScalarDataset(torch.utils.data.Dataset):
        # Four items, each a default-constructed scalar of `dtype`.
        def __init__(self, dtype):
            self.dtype = dtype

        def __getitem__(self, i):
            return self.dtype()

        def __len__(self):
            return 4

    # numpy scalar type -> tensor type the collated batch must have
    expected_types = [
        (np.float64, torch.DoubleTensor),
        (np.float32, torch.FloatTensor),
        (np.float16, torch.HalfTensor),
        (np.int64, torch.LongTensor),
        (np.int32, torch.IntTensor),
        (np.int16, torch.ShortTensor),
        (np.int8, torch.CharTensor),
        (np.uint8, torch.ByteTensor),
    ]
    for np_type, tensor_type in expected_types:
        first_batch = next(iter(self._get_data_loader(ScalarDataset(np_type), batch_size=2)))
        self.assertIsInstance(first_batch, tensor_type)
|  |  | 
def test_default_convert_mapping_keep_type(self):
    # Per the test name, default_convert should keep the CustomDict container;
    # the result is compared against the original mapping.
    original = CustomDict({"a": 1, "b": 2})
    self.assertEqual(_utils.collate.default_convert(original), original)
|  |  | 
def test_default_convert_sequence_keep_type(self):
    # Per the test name, default_convert should keep the CustomList container;
    # the result is compared against the original sequence.
    original = CustomList([1, 2, 3])
    self.assertEqual(_utils.collate.default_convert(original), original)
|  |  | 
def test_default_convert_sequence_dont_keep_type(self):
    # A `range` is not preserved: it must come back as a plain list.
    self.assertEqual(_utils.collate.default_convert(range(2)), [0, 1])
|  |  | 
def test_default_collate_dtype(self):
    collate = _utils.collate.default_collate

    # Python ints -> int64 tensor.
    ints = [1, 2, -1]
    collated = collate(ints)
    self.assertEqual(collated, torch.tensor(ints))
    self.assertEqual(collated.dtype, torch.int64)

    # Python floats -> compared against a float64 tensor.
    floats = [1.1, 2.3, -0.9]
    self.assertEqual(collate(floats), torch.tensor(floats, dtype=torch.float64))

    # Python bools -> bool tensor.
    bools = [True, False]
    collated = collate(bools)
    self.assertEqual(collated, torch.tensor(bools))
    self.assertEqual(collated.dtype, torch.bool)

    # Should be a no-op
    strings = ['a', 'b', 'c']
    self.assertEqual(strings, collate(strings))
|  |  | 
def test_default_collate_mapping_keep_type(self):
    # Collating CustomDicts key-wise must yield a CustomDict of stacked tensors.
    samples = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
    expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
    self.assertEqual(_utils.collate.default_collate(samples), expected)
|  |  | 
def test_default_collate_sequence_keep_type(self):
    # Collating CustomLists position-wise must yield a CustomList of per-position tensors.
    samples = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
    collated = _utils.collate.default_collate(samples)
    expected = CustomList([torch.tensor([a, b]) for a, b in zip([1, 2, 3], [4, 5, 6])])
    self.assertEqual(collated, expected)
|  |  | 
def test_default_collate_sequence_dont_keep_type(self):
    # `range` samples are collated position-wise into a plain list of tensors.
    collated = _utils.collate.default_collate([range(2), range(2)])
    expected = [torch.tensor([0, 0]), torch.tensor([1, 1])]
    self.assertEqual(collated, expected)
|  |  | 
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_default_collate_bad_numpy_types(self):
    import numpy as np
    collate = _utils.collate.default_collate

    # Should be a no-op
    strings = np.array(['a', 'b', 'c'])
    self.assertEqual(strings, collate(strings))

    # Nested string arrays and object arrays must be rejected with TypeError.
    rejected = (
        np.array([[['a', 'b', 'c']]]),
        np.array([object(), object(), object()]),
        np.array([[[object(), object(), object()]]]),
    )
    for bad in rejected:
        with self.assertRaises(TypeError):
            collate(bad)
|  |  | 
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_default_collate_numpy_memmap(self):
    import numpy as np

    values = [[0, 1], [2, 3], [4, 5], [6, 7]]
    with tempfile.TemporaryFile() as f:
        arr = np.array(values)
        # Write through a writable memmap, then collate the rows of a read-only
        # memmap over the same file.
        writer = np.memmap(f, dtype=arr.dtype, mode='w+', shape=arr.shape)
        writer[:] = arr[:]
        reader = np.memmap(f, dtype=arr.dtype, mode='r', shape=arr.shape)
        tensor = _utils.collate.default_collate(list(reader))

    self.assertTrue((tensor == tensor.new_tensor(values)).all().item())
|  |  | 
def test_default_collate_bad_sequence_type(self):
    # Samples of unequal length cannot be collated, whichever comes first.
    batch = [['X'], ['X', 'X']]
    for ordering in (batch, batch[::-1]):
        with self.assertRaises(RuntimeError):
            _utils.collate.default_collate(ordering)
|  |  | 
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_default_collate_shared_tensor(self):
    import numpy as np
    collate = _utils.collate.default_collate
    tensor_input = torch.zeros(1)
    numpy_input = np.zeros(1)

    self.assertEqual(tensor_input.is_shared(), False)

    # Outside a worker, collated results must not be in shared memory.
    for item in (tensor_input, numpy_input):
        self.assertEqual(collate([item]).is_shared(), False)

    # FIXME: fix the following hack that makes `default_collate` believe
    #        that it is in a worker process (since it tests
    #        `get_worker_info() != None`), even though it is not.
    saved_worker_info = _utils.worker._worker_info
    try:
        _utils.worker._worker_info = 'x'
        for item in (tensor_input, numpy_input):
            self.assertEqual(collate([item]).is_shared(), True)
    finally:
        _utils.worker._worker_info = saved_worker_info
|  |  | 
def test_excessive_thread_creation_warning(self):
    # Merely constructing a loader with 1000 workers must emit this UserWarning.
    pattern = r"excessive worker creation might get DataLoader running slow or even freeze"
    with self.assertWarnsRegex(UserWarning, pattern):
        DataLoader(self.dataset, batch_size=2, num_workers=1000)
|  |  | 
|  |  | 
class IntegrationTestDataLoaderDataPipe(TestCase):
    r"""
    Integration tests for specific ``DataPipes`` driven by ``DataLoader``.
    """

    def test_shuffler_iterdatapipe(self):
        r"""
        ``IterDataPipe.shuffle`` must be seeded by ``DataLoader`` so that the order is
        deterministic for a given global seed and differs between seeds, across
        several buffer sizes and worker configurations.
        """
        expected = list(range(100))

        def make_pipe(buffer_size):
            # Wrap the values, shuffle with the given buffer size, then add a sharding filter.
            return dp.iter.IterableWrapper(expected).shuffle(buffer_size=buffer_size).sharding_filter()

        # persistent_workers is only combined with num_workers > 0.
        worker_configs = [
            (num_workers, persistent)
            for num_workers, persistent in itertools.product((0, 1, 2), (True, False))
            if num_workers > 0 or not persistent
        ]

        for buffer_size in (5, 20, 33):
            for num_workers, persistent in worker_configs:
                loader = DataLoader(
                    make_pipe(buffer_size),
                    num_workers=num_workers,
                    shuffle=True,
                    multiprocessing_context="spawn" if num_workers > 0 else None,
                    persistent_workers=persistent
                )

                # Unseeded epoch: just some permutation of the input.
                unseeded = list(loader)
                self.assertEqual(sorted(unseeded), expected)

                # Two epochs started from the same global seed must match exactly.
                runs = []
                for _ in range(2):
                    torch.manual_seed(123)
                    runs.append(list(loader))
                self.assertEqual(runs[0], runs[1])
                self.assertEqual(sorted(runs[0]), expected)

                # A different seed reorders the same elements.
                torch.manual_seed(321)
                runs.append(list(loader))

                self.assertEqual(len(runs[0]), len(runs[2]))
                self.assertNotEqual(runs[0], runs[2])
                self.assertEqual(sorted(runs[0]), sorted(runs[2]))

                # Shut down any iterator the loader kept alive before the next configuration.
                if loader._iterator is not None:
                    loader._iterator._shutdown_workers()
                    loader._iterator = None
                del loader
|  |  | 
|  |  | 
class StringDataset(Dataset):
    """Map-style dataset over the characters of ``'12345'``.

    Item ``ndx`` is the pair ``(character, ndx)``.
    """

    def __init__(self):
        self.s = '12345'

    def __len__(self):
        return len(self.s)

    def __getitem__(self, ndx):
        char = self.s[ndx]
        return char, ndx
|  |  | 
|  |  | 
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestStringDataLoader(TestCase):
    """DataLoader behaviour on samples that mix strings with integers."""

    def setUp(self):
        super().setUp()
        self.dataset = StringDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for strings, indices in loader:
            # The string column still holds Python ``str`` items after collation,
            # while the integer column must come back as a pinned tensor.
            self.assertIsInstance(strings[0], str)
            self.assertTrue(indices.is_pinned())
|  |  | 
|  |  | 
class DictDataset(Dataset):
    """Four-item dataset whose samples are nested dicts.

    Item ``ndx`` has a ``(4, 2)`` tensor filled with ``ndx`` under ``'a_tensor'``
    and the plain integer ``ndx`` under ``['another_dict']['a_number']``.
    """

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        filled = torch.empty(4, 2)
        filled.fill_(ndx)
        nested = {'a_number': ndx}
        return {'a_tensor': filled, 'another_dict': nested}
|  |  | 
|  |  | 
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestDictDataLoader(TestCase):
    """Collation and memory pinning of nested-dict samples."""

    def setUp(self):
        super().setUp()
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        for persistent_workers in (False, True):
            # Persistent workers need at least one worker process.
            worker_kwargs = {'num_workers': 1} if persistent_workers else {}
            loader = DataLoader(self.dataset, batch_size=2, shuffle=False,
                                persistent_workers=persistent_workers, **worker_kwargs)
            batch_size = loader.batch_size
            for batch_idx, sample in enumerate(loader):
                first = batch_idx * batch_size
                self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
                self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})

                tensor = sample['a_tensor']
                self.assertEqual(tensor.size(), torch.Size([batch_size, 4, 2]))
                self.assertTrue((tensor[0] == first).all())
                self.assertTrue((tensor[1] == first + 1).all())

                numbers = sample['another_dict']['a_number']
                self.assertEqual(numbers.size(), torch.Size([batch_size]))
                self.assertEqual(numbers[0], first)
                self.assertEqual(numbers[1], first + 1)

    def _assert_pinned(self, loader, expected, **is_pinned_kwargs):
        """Assert the pin state of both tensor fields in every batch of ``loader``.

        ``is_pinned_kwargs`` is forwarded unchanged to ``Tensor.is_pinned``.
        """
        check = self.assertTrue if expected else self.assertFalse
        for sample in loader:
            check(sample['a_tensor'].is_pinned(**is_pinned_kwargs))
            check(sample['another_dict']['a_number'].is_pinned(**is_pinned_kwargs))

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        self._assert_pinned(loader, True)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_device(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True, pin_memory_device='cuda')
        self._assert_pinned(loader, True, device='cuda')

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory_with_only_device(self):
        # Naming a pin device without pin_memory=True must not pin anything.
        loader = DataLoader(self.dataset, batch_size=2, pin_memory_device='cuda')
        self._assert_pinned(loader, False, device='cuda')
|  |  | 
class DummyDataset(torch.utils.data.Dataset):
    """Ten-integer dataset that asserts ``self.start == 0`` on every access.

    ``start`` is not set here; callers must assign it before indexing,
    otherwise ``__getitem__`` raises ``AttributeError``.
    """

    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        key = idx.tolist() if torch.is_tensor(idx) else idx
        # Persistent workers keep the dataset copy they received when first
        # spawned (at the first dataloader iteration) for the loader's whole
        # lifetime, so later changes to ``start`` in the parent must not be seen.
        assert self.start == 0
        return self.data[key]
|  |  | 
|  |  | 
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
@unittest.skipIf(
    TEST_WITH_ASAN, "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223")
class TestDataLoaderPersistentWorkers(TestDataLoader):
    r"""Inherit every ``TestDataLoader`` test with ``self.persistent_workers = True``,
    plus tests specific to workers that survive across epochs."""

    def setUp(self):
        super().setUp()
        # Presumably consumed by TestDataLoader helpers such as ``_get_data_loader``
        # (defined outside this view) — TODO confirm.
        self.persistent_workers = True

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "No 'resource' module on Windows")
    def test_fd_limit_exceeded(self):
        # See NOTE [ DataLoader on Linux and open files limit ]
        # Runs in a subprocess because it lowers RLIMIT_NOFILE for the whole process.
        import subprocess
        subprocess.check_output([sys.executable, '-c', """\
import torch
import resource
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

try:
    keep_fds_alive = []
    resource.setrlimit(resource.RLIMIT_NOFILE, (100, 100))
    for random_t in DataLoader(RandomDataset(200, (2,2)), multiprocessing_context="fork",
                               num_workers=1, persistent_workers=True):
        random_t.max(dim=0)
        keep_fds_alive.append(random_t)
except RuntimeError as e:
    assert "ulimit -n" in str(e)
    assert "set_sharing_strategy" in str(e)
"""])

    def test_dataset_not_reset(self):
        dataset = DummyDataset()
        pin_memory_configs = [False]
        if TEST_CUDA:
            pin_memory_configs.append(True)
        for pin_memory in pin_memory_configs:
            dataloader = self._get_data_loader(dataset, num_workers=2, pin_memory=pin_memory)
            dataset.start = 0
            for i in range(10):
                for x in dataloader:
                    pass
                # Changing the start value here doesn't have any effect in the dataset
                # cached by the workers. since they are not recreated between epochs
                # and can cache values safely
                dataset.start = i

    @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI")
    @unittest.skipIf(IS_WINDOWS, "Needs fork")
    def test_early_exit(self):
        # Breaking out of the first epoch with persistent, pinning workers must
        # still let the interpreter exit cleanly; check_output raises otherwise.
        import subprocess
        subprocess.check_output([sys.executable, '-c', """\
import torch
from torch.utils.data import DataLoader, IterableDataset

class RandomDataset(IterableDataset):
    def __init__(self, len, size):
        super().__init__()
        self.len = len
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        if self.len <= 0:
            raise StopIteration
        self.len -= 1
        return torch.randn(self.size)

if __name__ == '__main__':
    dl = DataLoader(
        RandomDataset(64, (28, 28)),
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        multiprocessing_context="fork",
    )

    for _ in dl:
        break
"""])
|  |  | 
|  |  | 
class NamedTupleDataset(Dataset):
    """Four-item dataset whose samples are (nested) namedtuples.

    Item ``ndx`` is ``Batch(data=Data(ndx, -ndx), label=str(ndx),
    random_tensor=<3 random values>)``.
    """
    from collections import namedtuple
    Batch = namedtuple('Batch', 'data label random_tensor')
    Data = namedtuple('Data', 'positive negative')

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        pair = self.Data(positive=ndx, negative=-ndx)
        return self.Batch(data=pair, label=str(ndx), random_tensor=torch.randn(3))
|  |  | 
|  |  | 
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestNamedTupleDataLoader(TestCase):
    """Namedtuple samples must keep their type through collation and pinning."""

    def setUp(self):
        super().setUp()
        self.dataset = NamedTupleDataset()

    def test_dataloader_with_namedtuple(self):
        expect_pinned = TEST_CUDA

        # auto-collation: integer fields are collated into tensors.
        batched = DataLoader(self.dataset, batch_size=2, pin_memory=expect_pinned)
        for item in batched:
            self.assertIsInstance(item, NamedTupleDataset.Batch)
            self.assertEqual(item.random_tensor.is_pinned(), expect_pinned)
            self.assertIsInstance(item.data, NamedTupleDataset.Data)
            self.assertIsInstance(item.data.positive, torch.Tensor)
            self.assertEqual(item.data.positive.is_pinned(), expect_pinned)

        # no auto-collation (batch_size=None): integer fields are not turned into tensors.
        unbatched = DataLoader(self.dataset, batch_size=None, pin_memory=expect_pinned)
        for item in unbatched:
            self.assertIsInstance(item, NamedTupleDataset.Batch)
            self.assertEqual(item.random_tensor.is_pinned(), expect_pinned)
            self.assertIsInstance(item.data, NamedTupleDataset.Data)
            self.assertNotIsInstance(item.data.positive, torch.Tensor)
|  |  | 
class SimpleCustomBatch:
    """User-defined batch type that implements its own ``pin_memory`` hook.

    ``data`` is a sequence of samples. The first field of every sample is
    stacked into ``inp`` and the second into ``tgt``, along a new dim 0.
    """

    def __init__(self, data):
        columns = list(zip(*data))
        self.inp = torch.stack(columns[0], dim=0)
        self.tgt = torch.stack(columns[1], dim=0)

    def pin_memory(self):
        # Pin both fields in place and return self, as DataLoader expects.
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

    def is_pinned(self):
        return all(t.is_pinned() for t in (self.inp, self.tgt))
|  |  | 
# Workaround for https://github.com/pytorch/pytorch/issues/50661
# Classes defined in `__main__` cannot be correctly unpickled by workers started
# with the spawn method. Obtain this file as a regular module named after the
# file (importing it anew when it is running as `__main__`) and reference
# picklable helpers such as SimpleCustomBatch through that module instead.
# See https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming
self_module = __import__(os.path.splitext(os.path.basename(__file__))[0])
|  |  | 
def collate_wrapper(batch):
    """Collate ``batch`` into a ``SimpleCustomBatch``.

    The class is resolved through ``self_module`` so that its pickled path is
    the importable module rather than ``__main__``.
    """
    batch_cls = self_module.SimpleCustomBatch
    return batch_cls(batch)
|  |  | 
|  |  | 
def collate_into_packed_sequence(batch):
    """Collate ``batch`` into a time-major ``PackedSequence`` with random lengths.

    The first field of every sample is stacked along dim 1, giving a
    ``(t, b)`` tensor. Each of the ``b`` sequences is assigned a random
    length in ``[1, t]``.
    """
    data = torch.stack([sample[0] for sample in batch], 1)
    t, b = data.size()
    # randint's upper bound is exclusive: use t + 1 so the full length t can be
    # drawn and t == 1 does not produce the empty range [1, 1).
    lengths = torch.randint(1, t + 1, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)
|  |  | 
|  |  | 
def collate_into_packed_sequence_batch_first(batch):
    """Collate ``batch`` into a batch-first ``PackedSequence`` with random lengths.

    The first field of every sample is stacked along dim 0, giving a
    ``(b, t)`` tensor. Each of the ``b`` sequences is assigned a random
    length in ``[1, t]``.
    """
    data = torch.stack([sample[0] for sample in batch], 0)
    b, t = data.size()
    # randint's upper bound is exclusive: use t + 1 so the full length t can be
    # drawn and t == 1 does not produce the empty range [1, 1).
    lengths = torch.randint(1, t + 1, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)
|  |  | 
|  |  | 
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
class TestCustomPinFn(TestCase):
    """Custom batch types returned by ``collate_fn`` must be pinned via their own hook."""

    def setUp(self):
        super().setUp()

        def grid():
            return torch.arange(10 * 5, dtype=torch.float32).view(10, 5)

        # Two independent (10, 5) tensors holding 0..49.
        self.dataset = TensorDataset(grid(), grid())

    def _check_custom_batch_pinning(self, **loader_kwargs):
        """Iterate every custom collate function and assert batches come out pinned.

        ``loader_kwargs`` is forwarded to ``DataLoader`` (e.g. ``num_workers``).
        """
        cases = (
            (collate_wrapper, self_module.SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
        )
        for collate_fn, expected_cls in cases:
            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
                                pin_memory=True, **loader_kwargs)
            for sample in loader:
                self.assertIsInstance(sample, expected_cls)
                self.assertTrue(sample.is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin(self):
        self._check_custom_batch_pinning()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin_worker(self):
        self._check_custom_batch_pinning(num_workers=1)
|  |  | 
|  |  | 
class TestWorkerQueueDataset(Dataset):
    """Map-style dataset that tags each item with the id of the worker that loaded it.

    Items are ``(worker_id, data[item])``. ``worker_id`` is ``None`` until
    ``worker_init_fn`` has been called on this copy of the dataset.
    """

    def __init__(self, data):
        self.data = data
        self.worker_id = None

    def worker_init_fn(self, worker_id):
        self.worker_id = worker_id

    def __getitem__(self, item):
        tag = self.worker_id
        return tag, self.data[item]

    def __len__(self):
        return len(self.data)
|  |  | 
|  |  | 
@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)")
@unittest.skipIf(
    TEST_WITH_ASAN,
    "Flaky with ASAN, see https://github.com/pytorch/pytorch/issues/65727")
class TestIndividualWorkerQueue(TestCase):
    r"""Check that batches are handed out to workers in strict round-robin order."""

    def setUp(self):
        super().setUp()
        self.dataset = TestWorkerQueueDataset(list(range(128)))

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        loader = DataLoader(
            self.dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
            timeout=5, worker_init_fn=self.dataset.worker_init_fn
        )
        for i, (worker_ids, sample) in enumerate(loader):
            # Batch i must come entirely from worker i % num_workers and, with
            # shuffle=False, contain the i-th contiguous slice of the data.
            self.assertEqual(worker_ids.tolist(), [i % num_workers] * batch_size)
            self.assertEqual(sample.tolist(), list(range(i * batch_size, (i + 1) * batch_size)))

    @staticmethod
    def _max_num_workers():
        """Return how many worker processes the test may use; always at least 1.

        Prefers the size of this process's CPU affinity mask. Otherwise it uses
        half of ``os.cpu_count()``, clamped to 1 so single-core hosts still run
        the test instead of skipping every configuration.
        """
        if hasattr(os, 'sched_getaffinity'):
            try:
                affinity = os.sched_getaffinity(0)
            except OSError:
                affinity = None
            if affinity:
                return len(affinity)
        cpu_count = os.cpu_count()
        if cpu_count is not None:
            # Use half number of CPUs
            return max(1, cpu_count // 2)
        return 1

    def test_ind_worker_queue(self):
        max_num_workers = self._max_num_workers()
        for batch_size in (8, 16, 32, 64):
            # Exercise 1 .. min(6, max_num_workers) workers.
            for num_workers in range(1, min(6, max_num_workers) + 1):
                self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers)
|  |  | 
|  |  | 
class SetAffinityDataset(IterableDataset):
    """Iterable dataset that yields the CPU ids this process may run on.

    It calls a torch op before reading the mask, presumably so that any CPU
    affinity change torch makes when first running an op happens before the
    read — TODO confirm.
    """

    def __iter__(self):
        torch.randperm(1)
        return iter(os.sched_getaffinity(0))
|  |  | 
@unittest.skipIf(
    not hasattr(os, 'sched_setaffinity'),
    "os.sched_setaffinity is not available")
class TestSetAffinity(TestCase):
    """An affinity set in ``worker_init_fn`` must still be in effect while iterating."""

    def test_set_affinity_in_worker_init(self):
        # Query the current affinity mask to avoid setting a disallowed one
        old_affinity = os.sched_getaffinity(0)
        if not old_affinity:
            self.skipTest("No affinity information")
        # Choose any CPU from the allowed set.
        expected_affinity = list(old_affinity)[-1]

        def worker_set_affinity(_):
            os.sched_setaffinity(0, [expected_affinity])

        dataloader = torch.utils.data.DataLoader(
            SetAffinityDataset(), num_workers=2, worker_init_fn=worker_set_affinity)
        for sample in dataloader:
            self.assertEqual(sample, [expected_affinity])
|  |  | 
class ConvDataset(Dataset):
    """Single-item dataset whose item is ``conv1d`` of a ``(1, 1, 24000)`` ones tensor.

    The constructor already runs the convolution once, in whichever process
    builds the dataset.
    """

    def __init__(self):
        self.x = torch.ones(1, 1, 24000)
        # Call convolution on parent process
        _ = self[0]

    def __len__(self):
        return 1

    def __getitem__(self, index):
        kernel = torch.ones(1, 1, 2)
        return torch.nn.functional.conv1d(self.x, kernel)
|  |  | 
|  |  | 
@unittest.skipIf(IS_WINDOWS, "Needs fork")
@unittest.skipIf(
    TEST_WITH_ASAN,
    "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492")
class TestConvAfterFork(TestCase):
    # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565
    def test_conv_after_fork(self):
        # ConvDataset already ran conv1d in this process; the single worker runs it again.
        loader = DataLoader(ConvDataset(), num_workers=1)
        expected_shape = (1, 1, 1, 23999)
        for batch in loader:
            self.assertEqual(batch.shape, expected_shape)
|  |  | 
|  |  | 
if __name__ == '__main__':
    # When executed as a script, hand control to the shared test runner (run_tests).
    run_tests()