from torch import multiprocessing as mp
from torch.utils.data import _utils, Dataset, RandomSampler, TensorDataset, DataLoader, ConcatDataset
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data.dataset import random_split
from torch._utils import ExceptionWrapper
from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC,
                          IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm,
                          load_tests)
try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False
    err_msg = ("psutil not found. Some critical data loader tests relying on it "
               "(e.g., TestDataLoader.test_proper_exit) will not run.")
    if IS_PYTORCH_CI:
        raise ImportError(err_msg)
    else:
        warnings.warn(err_msg)
# load_tests from common_utils is used to automatically filter tests for sharding
load_tests = load_tests
TEST_CUDA = torch.cuda.is_available()

if not NO_MULTIPROCESSING_SPAWN:
    # Get a multiprocessing context that uses the spawn start method, since
    # several tests below require it.
    mp = mp.get_context(method='spawn')
JOIN_TIMEOUT = 17.0 if (IS_WINDOWS or IS_PPC) else 13.0
class TestDatasetRandomSplit(TestCase):

    def test_lengths_must_equal_dataset_size(self):
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])
    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 2)
        self.assertEqual(len(splits[1]), 4)
    def test_splits_are_mutually_exclusive(self):
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        all_values = []
        all_values.extend(list(splits[0]))
        all_values.extend(list(splits[1]))
        data.sort()
        all_values.sort()
        self.assertListEqual(data, all_values)
    def test_splits_indexing_type(self):
        r"""Indices generated by random_split should be of integer type."""
        class CustomDataset():
            def __init__(self, test_object, custom_list):
                self.data = custom_list
                self.test_object = test_object

            def __getitem__(self, key):
                self.test_object.assertEqual(type(key), type(0))
                return self.data[key]

            def __len__(self):
                return len(self.data)

        x = [1, 2, 3, 4, 5]
        dataset = CustomDataset(self, x)
        dataset = random_split(dataset, [5])[0]
        data_loader = DataLoader(dataset)
        for batch in data_loader:
            pass
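    # For orientation, a minimal sketch of what random_split conceptually does:
    # shuffle indices with randperm, then hand out non-overlapping Subset views.
    # This is an illustrative approximation, not the library source.
    @staticmethod
    def _random_split_sketch(dataset, lengths):
        from torch.utils.data import Subset
        indices = torch.randperm(sum(lengths)).tolist()  # plain Python ints
        splits, offset = [], 0
        for length in lengths:
            splits.append(Subset(dataset, indices[offset:offset + length]))
            offset += length
        return splits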
class TestTensorDataset(TestCase):

    def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)
    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])
    def test_getitem_1d(self):
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            self.assertEqual(t[i], source[i][0])
            self.assertEqual(l[i], source[i][1])
    def test_single_tensor(self):
        t = torch.randn(5, 10)
        source = TensorDataset(t)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t[i], source[i][0])
    def test_many_tensors(self):
        t0 = torch.randn(5, 10, 2, 3, 4, 5)
        t1 = torch.randn(5, 10)
        t2 = torch.randn(5, 10, 2, 5)
        t3 = torch.randn(5, 10, 3, 7)
        source = TensorDataset(t0, t1, t2, t3)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t0[i], source[i][0])
            self.assertEqual(t1[i], source[i][1])
            self.assertEqual(t2[i], source[i][2])
            self.assertEqual(t3[i], source[i][3])
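    # The assertions above follow from TensorDataset indexing every stored
    # tensor along dim 0. A minimal sketch of that behaviour (assumed
    # equivalent for these tests, not the library source):
    @staticmethod
    def _tensor_dataset_getitem_sketch(tensors, index):
        return tuple(tensor[index] for tensor in tensors)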
class TestConcatDataset(TestCase):

    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])
    def test_concat_two_non_singletons(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])
    def test_concat_two_non_singletons_with_empty(self):
        # An empty dataset somewhere in the middle is handled correctly.
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [],
                                [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])
    def test_concat_raises_index_error(self):
        result = ConcatDataset([[0, 1, 2, 3, 4],
                                [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this index is beyond the concatenated range
            result[11]
    def test_add_dataset(self):
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        result = d1 + d2 + d3
        self.assertEqual(21, len(result))
        self.assertEqual(0, (d1[0][0] - result[0][0]).abs().sum())
        self.assertEqual(0, (d2[0][0] - result[7][0]).abs().sum())
        self.assertEqual(0, (d3[0][0] - result[14][0]).abs().sum())
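    # ConcatDataset resolves a flat index to (dataset, local index) through a
    # running total of dataset sizes plus a bisect lookup. A minimal sketch of
    # that resolution (illustrative, not the library source):
    @staticmethod
    def _concat_lookup_sketch(datasets, idx):
        import bisect
        cumulative_sizes = []
        total = 0
        for d in datasets:
            total += len(d)
            cumulative_sizes.append(total)
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - cumulative_sizes[dataset_idx - 1]
        return datasets[dataset_idx][sample_idx]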
class ErrorTrackingProcess(mp.Process):

    def __init__(self, *args, **kwargs):
        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None

    def run(self):
        # Send stderr to /dev/null so that expected failures do not pollute the
        # test output.
        sys.stderr = open(os.devnull, "w")
        try:
            super(ErrorTrackingProcess, self).run()
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    @property
    def exception(self):
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill can't find the process, i.e., it already exited.
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise
class ErrorDataset(Dataset):
    # __getitem__ is deliberately left unimplemented, so indexing raises
    # NotImplementedError (exercised by test_error below).
    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size
class SegfaultDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return ctypes.string_at(0)  # read through a null pointer -> segfault

    def __len__(self):
        return self.size
class SleepDataset(Dataset):

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec

    def __getitem__(self, idx):
        time.sleep(self.sleep_sec)
        return idx

    def __len__(self):
        return self.size
class SeedDataset(Dataset):

    def __init__(self, size):
        self.size = size

    def __getitem__(self, idx):
        return torch.initial_seed()

    def __len__(self):
        return self.size
class SynchronizedSeedDataset(Dataset):

    def __init__(self, size, num_workers):
        assert size >= num_workers
        self.count = mp.Value('i', 0, lock=True)
        self.barrier = mp.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def __getitem__(self, idx):
        with self.count.get_lock():
            self.count.value += 1
            # The last worker to arrive releases the barrier for everyone.
            if self.count.value == self.num_workers:
                self.barrier.release()
        self.barrier.acquire()
        self.barrier.release()
        return torch.initial_seed()

    def __len__(self):
        return self.size
def _test_timeout():
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1)
    _ = next(iter(dataloader))
def _test_timeout_pin_memory():
    dataset = SleepDataset(10, 3)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, timeout=1, pin_memory=True)
    _ = next(iter(dataloader))
def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since the worker signal handler prints with low-level write(),
    this has to be done on the OS level via dup.

    This is used as worker_init_fn for test_segfault.
    """
    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
    devnull = open(os.devnull, 'w')
    os.dup2(devnull.fileno(), sys.stderr.fileno())
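# The dup2 redirection above is deliberately permanent for the worker process.
# If a caller instead needed to silence stderr only around a call and then
# restore it, the usual fd-juggling pattern looks like the sketch below
# (a hypothetical helper, not used by these tests):
def _with_stderr_silenced(fn):
    sys.stderr.flush()
    stderr_fd = sys.stderr.fileno()
    saved_fd = os.dup(stderr_fd)  # keep a duplicate of the real stderr fd
    with open(os.devnull, 'w') as devnull:
        os.dup2(devnull.fileno(), stderr_fd)
        try:
            return fn()
        finally:
            os.dup2(saved_fd, stderr_fd)  # restore the original stderr
            os.close(saved_fd)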
def _test_segfault():
    dataset = SegfaultDataset(10)
    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=disable_stderr)
    _ = next(iter(dataloader))
class TestProperExitDataset(Dataset):

    def __init__(self, size, error_event):
        self.size = size
        self.error_event = error_event

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if self.error_event is not None and self.error_event.is_set():
            raise RuntimeError('Worker error')
        return torch.tensor([idx])
def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
                      loader_setup_event, tester_setup_event):
    num_workers = 2 if use_workers else 0

    if exit_method == 'worker_error' or exit_method == 'worker_kill':
        assert use_workers is True

    if exit_method == 'worker_error':
        worker_error_event = mp.Event()
    else:
        worker_error_event = None

    ds = TestProperExitDataset(12, worker_error_event)

    loader = DataLoader(ds, batch_size=1, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
    error_it = 2

    if use_workers:
        # 2 is the per-worker prefetch count; the dataset must be long enough
        # that the error iteration is reached while the workers are still busy.
        assert len(loader) > (error_it + 2 + 1) * num_workers

    def kill_pid(pid):
        psutil_p = psutil.Process(pid)
        psutil_p.kill()
        psutil_p.wait(JOIN_TIMEOUT)
        assert not psutil_p.is_running()

    it = iter(loader)
    if use_workers:
        workers = it.workers

    for i, _ in enumerate(it):
        if i == 0:
            if not hold_iter_reference:
                del it
            loader_setup_event.set()
            tester_setup_event.wait()
            if worker_error_event is not None:
                worker_error_event.set()
        if i == error_it:
            if exit_method == 'loader_error':
                raise RuntimeError('Loader error')
            elif exit_method == 'loader_kill':
                kill_pid(os.getpid())
            elif exit_method == 'worker_kill':
                kill_pid(workers[0].pid)

    if not hold_iter_reference:
        # Try to trigger the iterator's __del__ clean-up explicitly.
        gc.collect()
def init_fn(worker_id):
    torch.manual_seed(12345)
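# init_fn above pins every worker to one fixed seed, which is what
# test_worker_init_fn checks. The more common production pattern, sketched here
# as a hypothetical alternative, derives a distinct per-worker seed from
# torch.initial_seed(), which the DataLoader already sets differently in every
# worker process:
def _distinct_seed_init_fn(worker_id):
    worker_seed = torch.initial_seed() % (2 ** 32)
    import random
    random.seed(worker_seed)  # seed other RNGs (e.g., random / numpy) per worker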
class TestDataLoader(TestCase):

    def setUp(self):
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)
    def _test_sequential(self, loader):
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size])
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
    def _test_shuffle(self, loader):
        found_data = {i: 0 for i in range(self.data.size(0))}
        found_labels = {i: 0 for i in range(self.labels.size(0))}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))
    def _test_error(self, loader):
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return
    def test_invalid_assign_after_init(self):
        dl = DataLoader(self.dataset)
        for attr in ('batch_size', 'sampler', 'drop_last'):
            def fn():
                setattr(dl, attr, {})

            self.assertRaises(ValueError, fn)
    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))
    def test_growing_dataset(self):
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = DataLoader(dataset, shuffle=False)
        dataloader_shuffle = DataLoader(dataset, shuffle=True)
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_sequential_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())
    def test_multiple_dataloaders(self):
        loader1_it = iter(DataLoader(self.dataset, num_workers=1))
        loader2_it = iter(DataLoader(self.dataset, num_workers=2))
    @unittest.skip("temporarily disable until flaky failures are fixed")
    def test_segfault(self):
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r'access violation reading ')
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception),
                                 r'DataLoader worker \(pid \d+\) is killed by signal: ')
        finally:
            p.terminate()
    def test_timeout(self):
        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
            # The pin-memory variant needs CUDA and a spawn-capable environment.
            targets = (_test_timeout, _test_timeout_pin_memory)
        else:
            targets = (_test_timeout,)
        for target in targets:
            p = ErrorTrackingProcess(target=target)
            p.start()
            p.join(JOIN_TIMEOUT)
            try:
                self.assertFalse(p.is_alive())
                self.assertNotEqual(p.exitcode, 0)
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception),
                                 r'DataLoader timed out after \d+ seconds')
            finally:
                p.terminate()
    def test_worker_seed(self):
        num_workers = 6
        dataset = SynchronizedSeedDataset(num_workers, num_workers)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)
        seeds = set()
        for batch in dataloader:
            seeds.add(batch[0])
        self.assertEqual(len(seeds), num_workers)
    def test_worker_init_fn(self):
        dataset = SeedDataset(4)
        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
                                worker_init_fn=init_fn)
        for batch in dataloader:
            self.assertEqual(12345, batch[0])
            self.assertEqual(12345, batch[1])
    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    def test_sequential_batch_workers(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))
    def _test_batch_sampler(self, **kwargs):
        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
        batches = []
        for i in range(0, 100, 5):
            batches.append(tuple(range(i, i + 2)))
            batches.append(tuple(range(i + 2, i + 5)))

        dl = DataLoader(self.dataset, batch_sampler=batches, **kwargs)
        self.assertEqual(len(dl), 40)
        for i, (input, _target) in enumerate(dl):
            offset = i * 5 // 2
            if i % 2 == 0:
                self.assertEqual(len(input), 2)
                self.assertEqual(input, self.data[offset:offset + 2])
            else:
                self.assertEqual(len(input), 3)
                self.assertEqual(input, self.data[offset:offset + 3])
    def test_RandomSampler(self):
        from collections import Counter

        def sample_stat(sampler, num_samples):
            counts = Counter(sampler)
            count_repeated = sum(val > 1 for val in counts.values())
            return (count_repeated, min(counts.keys()), max(counts.keys()))

        # test sampling with replacement
        n = len(self.dataset) + 1  # more draws than items guarantees a repeat
        sampler_with_replacement = RandomSampler(self.dataset, replacement=True, num_samples=n)
        count_repeated, minval, maxval = sample_stat(sampler_with_replacement, n)
        self.assertTrue(count_repeated > 0)
        self.assertTrue(minval >= 0)
        self.assertTrue(maxval < len(self.dataset))

        # test sampling without replacement
        sampler_without_replacement = RandomSampler(self.dataset)
        count_repeated, minval, maxval = sample_stat(sampler_without_replacement, len(self.dataset))
        self.assertTrue(count_repeated == 0)
        self.assertTrue(minval == 0)
        self.assertTrue(maxval == len(self.dataset) - 1)

        # raise an error when replacement is False and num_samples is given
        self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=len(self.dataset)))
        self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=0))
    def test_random_sampler_len_with_replacement(self):
        num_samples = len(self.dataset) + 5
        sampler = RandomSampler(self.dataset,
                                replacement=True,
                                num_samples=num_samples)
        # test len method
        self.assertEqual(num_samples, len(sampler))

        # test with iteration
        count_num_samples = sum(1 for _ in sampler)
        self.assertEqual(num_samples, count_num_samples)

        # test with dataloader, batch_size = 1
        batch_size = 1
        count_num_samples_in_data_loader = len(DataLoader(
            self.dataset, batch_size=batch_size, sampler=sampler))
        self.assertEqual(num_samples, count_num_samples_in_data_loader)

        # test with dataloader, batch_size = 6
        batch_size = 6
        count_num_samples_in_data_loader = len(DataLoader(
            self.dataset, batch_size=batch_size, sampler=sampler))
        self.assertEqual(int(math.ceil(float(num_samples) / batch_size)),
                         count_num_samples_in_data_loader)
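    # The two regimes tested above map onto two generators: sampling with
    # replacement draws indices via randint (duplicates possible), while the
    # default draws a randperm permutation. A minimal sketch (illustrative, not
    # the library source):
    @staticmethod
    def _random_sampler_sketch(n, replacement=False, num_samples=None):
        if replacement:
            return iter(torch.randint(0, n, (num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())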
    def test_duplicating_data_with_drop_last(self):
        from torch.utils.data.distributed import DistributedSampler

        num_processes = 4
        num_batches = 9
        data_set = torch.IntTensor(range(num_batches))
        scanned_data = torch.IntTensor([])
        for i in range(num_processes):
            s = DistributedSampler(data_set, num_processes, i)
            d_loader = DataLoader(data_set, batch_size=int(num_batches / num_processes),
                                  drop_last=True, sampler=s)
            for data in d_loader:
                scanned_data = torch.cat((scanned_data, data), 0)

        self.assertEqual(scanned_data.size(), scanned_data.unique().size())
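    # DistributedSampler pads the index list so that it divides evenly among
    # ranks (that padding is what can duplicate samples) and gives rank i every
    # num_replicas-th entry; drop_last above then discards the batch holding the
    # padded duplicates. A minimal sketch of the partitioning without shuffling
    # (illustrative, not the library source):
    @staticmethod
    def _distributed_partition_sketch(dataset_len, num_replicas, rank):
        indices = list(range(dataset_len))
        num_samples = int(math.ceil(dataset_len / float(num_replicas)))
        total_size = num_samples * num_replicas
        indices += indices[:(total_size - len(indices))]  # pad by repeating head
        return indices[rank:total_size:num_replicas]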
    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
    def test_batch_sampler(self):
        self._test_batch_sampler()
        self._test_batch_sampler(num_workers=4)
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))
    def test_error(self):
        self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))
    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
    def test_error_workers(self):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))
    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
    def test_partial_workers(self):
        r"""Check that workers exit even if the iterator is not exhausted."""
        if TEST_CUDA:
            pin_memory_configs = (True, False)
        else:
            pin_memory_configs = (False,)

        for pin_memory in pin_memory_configs:
            loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
            workers = loader.workers
            if pin_memory:
                pin_memory_thread = loader.pin_memory_thread
            for i, _ in enumerate(loader):
                if i == 10:
                    break
            assert i == 10
            del loader
            for w in workers:
                w.join(JOIN_TIMEOUT)
                self.assertFalse(w.is_alive(), 'subprocess not terminated')
            if pin_memory:
                pin_memory_thread.join(JOIN_TIMEOUT)
                self.assertFalse(pin_memory_thread.is_alive())
    @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
    def test_proper_exit(self):
        (r'''There might be ConnectionResetError or leaked semaphore warning '''
         r'''(due to dirty process exit), but they are all safe to ignore''')

        for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
            # pin_memory requires CUDA, and initializing CUDA in a subprocess
            # only works with the spawn start method.
            if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN):
                continue

            if use_workers:
                exit_methods = [None, 'loader_error', 'loader_kill', 'worker_kill', 'worker_error']
            else:
                exit_methods = [None, 'loader_error', 'loader_kill']

            for exit_method in exit_methods:
                desc = []
                desc.append('use_workers={}'.format(use_workers))
                desc.append('pin_memory={}'.format(pin_memory))
                desc.append('hold_iter_reference={}'.format(hold_iter_reference))
                desc.append('exit_method={}'.format(exit_method))
                desc = 'test_proper_exit with ' + ', '.join(desc)

                # event that the loader process sets when it is fully set up
                loader_setup_event = mp.Event()
                # event that this process sets to allow the loader to continue
                tester_setup_event = mp.Event()

                loader_p = ErrorTrackingProcess(
                    target=_test_proper_exit,
                    args=(use_workers, pin_memory, exit_method,
                          hold_iter_reference, loader_setup_event,
                          tester_setup_event))
                loader_p.start()

                # Wait for the loader process to set up.
                loader_setup_event.wait(timeout=JOIN_TIMEOUT)
                if not loader_setup_event.is_set():
                    fail_msg = desc + ': loader process failed to setup within given time'
                    if loader_p.exception is not None:
                        self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
                    elif not loader_p.is_alive():
                        self.fail(fail_msg + ', and exited with code {} but had no exception'.format(loader_p.exitcode))
                    else:
                        self.fail(fail_msg + ', and is still alive.')

                worker_psutil_p = psutil.Process(loader_p.pid).children()

                tester_setup_event.set()

                try:
                    loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
                    if loader_p.is_alive():
                        fail_msg = desc + ': loader process did not terminate'
                        if loader_p.exception is not None:
                            self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
                        else:
                            self.fail(fail_msg + ', and had no exception')
                    _, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
                    if len(alive) > 0:
                        self.fail(desc + ': worker process (pid(s) {}) did not terminate'.format(
                            ', '.join(str(p.pid) for p in alive)))
                    if exit_method is None:
                        self.assertEqual(loader_p.exitcode, 0)
                    else:
                        self.assertNotEqual(loader_p.exitcode, 0)
                        if exit_method == 'loader_error':
                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
                            self.assertIn('Loader error', str(loader_p.exception), desc)
                        elif exit_method == 'worker_kill':
                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
                            self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc)
                        elif exit_method == 'worker_error':
                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
                            self.assertIn('Worker error', str(loader_p.exception), desc)
                finally:
                    loader_p.terminate()
    def test_len(self):
        def check_len(dl, expected):
            self.assertEqual(len(dl), expected)
            n = 0
            for _ in dl:
                n += 1
            self.assertEqual(n, expected)

        check_len(self.dataset, 100)
        check_len(DataLoader(self.dataset, batch_size=2), 50)
        check_len(DataLoader(self.dataset, batch_size=3), 34)
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)
    def test_default_collate_dtype(self):
        arr = [1, 2, -1]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.int64)

        arr = [1.1, 2.3, -0.9]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))
        self.assertEqual(collated.dtype, torch.float64)

        arr = [True, False]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.uint8)

        # Should be a no-op
        arr = ['a', 'b', 'c']
        self.assertEqual(arr, _utils.collate.default_collate(arr))
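    # The dtype expectations above come from default_collate dispatching on the
    # element type of the batch: bools collate to uint8 tensors (in this torch
    # version), ints to int64, floats to float64, and strings pass through
    # untouched. A minimal sketch of that dispatch (illustrative, not the
    # library source):
    @staticmethod
    def _collate_scalars_sketch(batch):
        elem = batch[0]
        if isinstance(elem, bool):
            return torch.tensor(batch, dtype=torch.uint8)
        if isinstance(elem, int):
            return torch.tensor(batch, dtype=torch.int64)
        if isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        if isinstance(elem, str):
            return batch  # no-op, as asserted above
        raise TypeError(type(elem))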
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        self.assertEqual(arr, _utils.collate.default_collate(arr))

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_shared_tensor(self):
        import numpy as np
        t_in = torch.zeros(1)
        n_in = np.zeros(1)

        self.assertEqual(t_in.is_shared(), False)

        self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
        self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)

        old = _utils.collate._use_shared_memory
        try:
            # simulate a worker process: collate outputs shared-memory tensors
            _utils.collate._use_shared_memory = True
            self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
            self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
        finally:
            _utils.collate._use_shared_memory = old
class StringDataset(Dataset):
    def __init__(self):
        self.s = '12345'

    def __len__(self):
        return len(self.s)

    def __getitem__(self, ndx):
        return (self.s[ndx], ndx)


class TestStringDataLoader(TestCase):

    def setUp(self):
        self.dataset = StringDataset()
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for (s, n) in loader:
            self.assertIsInstance(s[0], str)
            self.assertTrue(n.is_pinned())
class DictDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return {
            'a_tensor': torch.Tensor(4, 2).fill_(ndx),
            'another_dict': {
                'a_number': ndx,
            },
        }


class TestDictDataLoader(TestCase):

    def setUp(self):
        self.dataset = DictDataset()
    def test_sequential_batch(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
        batch_size = loader.batch_size
        for i, sample in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
            self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})

            t = sample['a_tensor']
            self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
            self.assertTrue((t[0] == idx).all())
            self.assertTrue((t[1] == idx + 1).all())

            n = sample['another_dict']['a_number']
            self.assertEqual(n.size(), torch.Size([batch_size]))
            self.assertEqual(n[0], idx)
            self.assertEqual(n[1], idx + 1)
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for sample in loader:
            self.assertTrue(sample['a_tensor'].is_pinned())
            self.assertTrue(sample['another_dict']['a_number'].is_pinned())
class NamedTupleDataset(Dataset):
    from collections import namedtuple
    Batch = namedtuple('Batch', ['data', 'label'])
    Data = namedtuple('Data', ['positive', 'negative'])

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        return self.Batch(data=self.Data(positive=ndx, negative=-ndx),
                          label=str(ndx))


class TestNamedTupleDataLoader(TestCase):

    def setUp(self):
        self.dataset = NamedTupleDataset()
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_collate_and_pin_memory_with_namedtuple(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for batch in loader:
            self.assertIsInstance(batch, NamedTupleDataset.Batch)
            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
class SimpleCustomBatch(object):
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

    def is_pinned(self):
        return self.inp.is_pinned() and self.tgt.is_pinned()
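# With pin_memory=True, the loader pins whatever collate_fn produced; custom
# batch types participate by exposing pin_memory(), as SimpleCustomBatch does
# above. A minimal sketch of that recursive dispatch (illustrative, not the
# library source):
def _pin_memory_sketch(batch):
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    if isinstance(batch, list):
        return [_pin_memory_sketch(b) for b in batch]
    if hasattr(batch, 'pin_memory'):
        return batch.pin_memory()  # e.g., SimpleCustomBatch or PackedSequence
    return batch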
def collate_wrapper(batch):
    return SimpleCustomBatch(batch)
def collate_into_packed_sequence(batch):
    data = torch.stack([sample[0] for sample in batch], 1)
    t, b = data.size()
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)


def collate_into_packed_sequence_batch_first(batch):
    data = torch.stack([sample[0] for sample in batch], 0)
    b, t = data.size()
    lengths = torch.randint(1, t, size=(b,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, batch_first=True, enforce_sorted=False)
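# A small usage sketch of the packing done above: enforce_sorted=False lets
# pack_padded_sequence accept lengths in any order (values below are arbitrary):
def _packed_sequence_example():
    data = torch.zeros(5, 3)  # (seq_len, batch) layout, i.e., batch_first=False
    lengths = torch.tensor([5, 2, 4], dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(data, lengths, enforce_sorted=False)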
class TestCustomPinFn(TestCase):

    def setUp(self):
        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        self.dataset = TensorDataset(inps, tgts)
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin(self):
        test_cases = [
            (collate_wrapper, SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
                                pin_memory=True)
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_custom_batch_pin_worker(self):
        test_cases = [
            (collate_wrapper, SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
                                pin_memory=True, num_workers=1)
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())
class TestWorkerQueueDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.worker_id = None

    def worker_init_fn(self, worker_id):
        self.worker_id = worker_id

    def __getitem__(self, item):
        return self.worker_id, self.data[item]

    def __len__(self):
        return len(self.data)
class TestIndividualWorkerQueue(TestCase):

    def setUp(self):
        self.dataset = TestWorkerQueueDataset([i for i in range(128)])

    def _run_ind_worker_queue_test(self, batch_size, num_workers):
        loader = DataLoader(
            self.dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
            worker_init_fn=self.dataset.worker_init_fn
        )
        current_worker_idx = 0
        for i, (worker_ids, sample) in enumerate(loader):
            self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
            self.assertEqual(sample.tolist(),
                             [j for j in range(i * batch_size, (i + 1) * batch_size)])
            current_worker_idx += 1
            if current_worker_idx == num_workers:
                current_worker_idx = 0
    def test_ind_worker_queue(self):
        for batch_size in (8, 16, 32, 64):
            for num_workers in range(1, 6):
                self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers)
if __name__ == '__main__':
    run_tests()