# Caffe2 - Python API
# A deep learning, cross platform ML framework
# test_dataloader.py
1 import math
2 import sys
3 import errno
4 import os
5 import ctypes
6 import signal
7 import torch
8 import gc
9 import time
10 import traceback
11 import unittest
12 import subprocess
13 import itertools
14 import warnings
15 from torch import multiprocessing as mp
16 from torch.utils.data import _utils, Dataset, TensorDataset, DataLoader, ConcatDataset
17 from torch.utils.data._utils import ExceptionWrapper, MP_STATUS_CHECK_INTERVAL
18 from torch.utils.data.dataset import random_split
19 from common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_PPC,
20  IS_PYTORCH_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm,
21  load_tests)
22 
# psutil is optional: it is only needed by the process-lifetime tests
# (e.g. test_proper_exit). On CI its absence is a hard error so those
# critical tests can never be silently skipped.
try:
    import psutil
    HAS_PSUTIL = True
except ImportError:
    HAS_PSUTIL = False
    err_msg = ("psutil not found. Some critical data loader tests relying on it "
               "(e.g., TestDataLoader.test_proper_exit) will not run.")
    if IS_PYTORCH_CI:
        raise ImportError(err_msg)
    else:
        warnings.warn(err_msg)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# We cannot import TEST_CUDA from common_cuda here, because if we do that,
# the TEST_CUDNN line from common_cuda will be executed multiple times
# as well during the execution of this test suite, and it will cause
# CUDA OOM error on Windows.
TEST_CUDA = torch.cuda.is_available()

if not NO_MULTIPROCESSING_SPAWN:
    # Get a multiprocessing context because some test / third party library will
    # set start_method when imported, and setting again triggers RuntimeError.
    mp = mp.get_context(method='spawn')


# Upper bound (seconds) used throughout this file when joining worker
# processes/threads; larger on Windows/PPC, presumably because process
# startup is slower there — TODO confirm.
JOIN_TIMEOUT = 17.0 if (IS_WINDOWS or IS_PPC) else 13.0
53 
54 
class TestDatasetRandomSplit(TestCase):
    """Tests for torch.utils.data.dataset.random_split."""

    def test_lengths_must_equal_dataset_size(self):
        # A partition whose lengths do not sum to len(dataset) is rejected.
        with self.assertRaises(ValueError):
            random_split([1, 2, 3, 4], [1, 2])

    def test_splits_have_correct_size(self):
        splits = random_split([1, 2, 3, 4, 5, 6], [2, 4])
        self.assertEqual(len(splits), 2)
        self.assertEqual(len(splits[0]), 2)
        self.assertEqual(len(splits[1]), 4)

    def test_splits_are_mutually_exclusive(self):
        # Together the splits must contain every element exactly once.
        data = [5, 2, 3, 4, 1, 6]
        splits = random_split(data, [2, 4])
        combined = list(splits[0]) + list(splits[1])
        data.sort()
        combined.sort()
        self.assertListEqual(data, combined)

    def test_splits_indexing_type(self):
        r"""Indices generated by random_split
        should be of integer type
        """
        class CustomDataset():
            def __init__(self, test_object, custom_list):
                self.data = custom_list
                self.test_object = test_object

            def __getitem__(self, key):
                # The indices forwarded here must be plain Python ints.
                self.test_object.assertEqual(type(key), type(0))
                return self.data[key]

            def __len__(self):
                return len(self.data)

        dataset = random_split(CustomDataset(self, [1, 2, 3, 4, 5]), [5])[0]
        for _ in DataLoader(dataset):
            pass
98 
99 
class TestTensorDataset(TestCase):
    """Length and indexing behavior of TensorDataset."""

    def test_len(self):
        source = TensorDataset(torch.randn(15, 10, 2, 3, 4, 5), torch.randperm(15))
        self.assertEqual(len(source), 15)

    def test_getitem(self):
        t = torch.randn(15, 10, 2, 3, 4, 5)
        l = torch.randn(15, 10)
        source = TensorDataset(t, l)
        for i in range(15):
            sample, label = source[i]
            self.assertEqual(t[i], sample)
            self.assertEqual(l[i], label)

    def test_getitem_1d(self):
        # Same contract when the backing tensors are one-dimensional.
        t = torch.randn(15)
        l = torch.randn(15)
        source = TensorDataset(t, l)
        for i in range(15):
            sample, label = source[i]
            self.assertEqual(t[i], sample)
            self.assertEqual(l[i], label)

    def test_single_tensor(self):
        t = torch.randn(5, 10)
        source = TensorDataset(t)
        self.assertEqual(len(source), 5)
        for i in range(5):
            self.assertEqual(t[i], source[i][0])

    def test_many_tensors(self):
        tensors = (torch.randn(5, 10, 2, 3, 4, 5), torch.randn(5, 10),
                   torch.randn(5, 10, 2, 5), torch.randn(5, 10, 3, 7))
        source = TensorDataset(*tensors)
        self.assertEqual(len(source), 5)
        for i in range(5):
            for expected, actual in zip(tensors, source[i]):
                self.assertEqual(expected[i], actual)
141 
142 
class TestConcatDataset(TestCase):
    """Length, indexing, and `+` composition of ConcatDataset."""

    def test_concat_two_singletons(self):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

    def test_concat_two_non_singletons(self):
        result = ConcatDataset([list(range(5)), list(range(5, 10))])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_two_non_singletons_with_empty(self):
        # Adding an empty dataset somewhere is correctly handled
        result = ConcatDataset([list(range(5)), [], list(range(5, 10))])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

    def test_concat_raises_index_error(self):
        result = ConcatDataset([list(range(5)), list(range(5, 10))])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_add_dataset(self):
        # Dataset.__add__ should behave like concatenation.
        d1 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d2 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        d3 = TensorDataset(torch.rand(7, 3, 28, 28), torch.rand(7))
        result = d1 + d2 + d3
        self.assertEqual(21, len(result))
        for offset, ds in zip((0, 7, 14), (d1, d2, d3)):
            self.assertEqual(0, (ds[0][0] - result[offset][0]).abs().sum())
183 
184 
# Stores the first encountered exception in .exception.
# Inspired by https://stackoverflow.com/a/33599967
class ErrorTrackingProcess(mp.Process):
    """mp.Process subclass that captures any exception raised in run() and
    exposes it to the parent process via the `exception` property."""

    def __init__(self, *args, **kwargs):
        super(ErrorTrackingProcess, self).__init__(*args, **kwargs)
        # Parent/child ends of a pipe used to ship the exception back.
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None

    def run(self):
        # Disable polluting stderr with errors that are supposed to happen.
        sys.stderr = open(os.devnull, "w")
        try:
            super(ErrorTrackingProcess, self).run()
            # Send None so the parent can tell "finished cleanly" apart
            # from "pipe still empty".
            self._cconn.send(None)
        except Exception:
            self._cconn.send(ExceptionWrapper(sys.exc_info()))
            raise

    @property
    def exception(self):
        # Drain the pipe lazily on access; reconstruct the child's
        # exception (same type, same message) in the parent process.
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        if self._exception is None:
            return None
        else:
            return self._exception.exc_type(self._exception.exc_msg)

    # ESRCH means that os.kill can't find an alive proc
    def send_signal(self, signum, ignore_ESRCH=False):
        try:
            os.kill(self.pid, signum)
        except OSError as e:
            if not ignore_ESRCH or e.errno != errno.ESRCH:
                raise
220 
221 
class ErrorDataset(Dataset):
    """Dataset that reports a length but deliberately implements no
    __getitem__, so element access fails (see _test_error, which expects
    NotImplementedError from the base class)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size
229 
230 
class SegfaultDataset(Dataset):
    """Dataset whose __getitem__ dereferences address 0 via ctypes,
    hard-crashing the accessing (worker) process."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Reading from the NULL address crashes the process on purpose.
        return ctypes.string_at(0)
241 
242 
class SleepDataset(Dataset):
    """Dataset that blocks for ``sleep_sec`` seconds on the first item it
    serves (per process), then returns indices immediately afterwards."""

    def __init__(self, size, sleep_sec):
        self.size = size
        self.sleep_sec = sleep_sec
        self.sleeped = False  # set to True after the one-time sleep

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if self.sleeped:
            return idx
        time.sleep(self.sleep_sec)
        self.sleeped = True
        return idx
258 
259 
class SeedDataset(Dataset):
    """Dataset that returns the accessing process's torch initial seed for
    every index, so tests can observe per-worker seeding."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return torch.initial_seed()
270 
271 
# Inspired by https://stackoverflow.com/a/26703365
# This will ensure that each worker at least processes one data
class SynchronizedSeedDataset(Dataset):
    """Dataset whose __getitem__ blocks until all `num_workers` workers have
    fetched at least one item, then returns each caller's torch seed."""

    def __init__(self, size, num_workers):
        assert size >= num_workers
        # Shared counter of fetches performed across all workers.
        self.count = mp.Value('i', 0, lock=True)
        # Released once the last worker checks in; acts as a barrier.
        self.barrier = mp.Semaphore(0)
        self.num_workers = num_workers
        self.size = size

    def __getitem__(self, idx):
        with self.count.get_lock():
            self.count.value += 1
            if self.count.value == self.num_workers:
                self.barrier.release()
        # Wait for the last worker, then immediately re-release so every
        # other blocked worker is woken in turn (cascading wake-up).
        self.barrier.acquire()
        self.barrier.release()
        return torch.initial_seed()

    def __len__(self):
        return self.size
294 
295 
def _test_timeout():
    # Workers sleep 3s before producing their first sample while the loader
    # only waits 1s, so fetching a batch must raise a timeout error.
    loader = DataLoader(SleepDataset(10, 3), batch_size=2, num_workers=2, timeout=1)
    _ = next(iter(loader))
300 
301 
def _test_timeout_pin_memory():
    # Same as _test_timeout, but with the pin_memory thread in the pipeline.
    loader = DataLoader(SleepDataset(10, 3), batch_size=2, num_workers=2,
                        timeout=1, pin_memory=True)
    _ = next(iter(loader))
306 
307 
def disable_stderr(worker_id):
    r"""
    Avoids printing "ERROR: Unexpected segmentation fault encountered in worker."
    from workers. Since worker signal handler prints with low-level write(),
    this has to be done on OS level via dup.

    This is used as worker_init_fn for test_segfault.
    """
    sys.stderr.flush()  # flush library buffers that dup2 knows nothing about
    # Can't use a with-block because otherwise the fd will be closed when this
    # function ends.
    devnull = open(os.devnull, 'w')
    # Redirect fd 2 itself (not just sys.stderr) so low-level writes go to
    # /dev/null too.
    os.dup2(devnull.fileno(), sys.stderr.fileno())
321 
322 
def _test_segfault():
    # Each worker crashes on its first fetch; stderr is silenced so the
    # expected segfault message does not pollute the test logs.
    loader = DataLoader(SegfaultDataset(10), batch_size=2, num_workers=2,
                        worker_init_fn=disable_stderr)
    _ = next(iter(loader))
327 
328 
class TestProperExitDataset(object):
    """Dataset used by _test_proper_exit: yields wrapped indices until
    ``error_event`` is set, after which every access raises RuntimeError."""

    def __init__(self, size, error_event):
        self.size = size
        self.error_event = error_event

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Once the tester sets the event, simulate a failure inside a worker.
        if self.error_event is not None and self.error_event.is_set():
            raise RuntimeError('Worker error')
        return torch.tensor([idx])
341 
342 
# See TestDataLoader.test_proper_exit for usage
def _test_proper_exit(use_workers, pin_memory, exit_method, hold_iter_reference,
                      loader_setup_event, tester_setup_event):
    """Runs inside an ErrorTrackingProcess: iterate a DataLoader and end in
    the way `exit_method` prescribes (clean finish, raised error, or kill)
    so that the parent test can verify all processes shut down properly."""
    num_workers = 2 if use_workers else 0

    # Worker-side exit methods only make sense when workers exist.
    if exit_method == 'worker_error' or exit_method == 'worker_kill':
        assert use_workers is True

    if exit_method == 'worker_error':
        worker_error_event = mp.Event()
    else:
        worker_error_event = None

    ds = TestProperExitDataset(12, worker_error_event)

    loader = DataLoader(ds, batch_size=1, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
    # Iteration index at which the chosen exit method is triggered.
    error_it = 2

    if use_workers:
        # 2 is the magical per-worker prefetch number...
        # FIXME: change this after the number becomes configurable.
        assert len(loader) > (error_it + 2 + 1) * num_workers

    it = iter(loader)
    if use_workers:
        workers = it.workers

    def kill_pid(pid):
        # Kill `pid` and block until it is really gone.
        psutil_p = psutil.Process(pid)
        psutil_p.kill()
        psutil_p.wait(JOIN_TIMEOUT)
        assert not psutil_p.is_running()

    for i, _ in enumerate(it):
        if i == 0:
            if not hold_iter_reference:
                del it
            loader_setup_event.set()
            tester_setup_event.wait()
            # ensure that the workers are still alive
            if use_workers:
                for w in workers:
                    assert w.is_alive()
            if worker_error_event is not None:
                worker_error_event.set()

        if i == error_it:
            if exit_method == 'loader_error':
                raise RuntimeError('Loader error')
            elif exit_method == 'loader_kill':
                kill_pid(os.getpid())
            elif exit_method == 'worker_kill':
                kill_pid(workers[0].pid)

    if not hold_iter_reference:
        # Tries to trigger the __del__ clean-up rather than the automatic
        # exiting of daemonic children. Technically it should be automatically
        # triggered, but I don't want to rely on the implementation detail of
        # Python gc.
        gc.collect()
404 
405 
# test custom init function
def init_fn(worker_id):
    # Give every worker the same fixed seed so test_worker_init_fn can
    # assert that worker_init_fn actually ran in each worker process.
    torch.manual_seed(12345)
409 
410 
class TestDataLoader(TestCase):
    """Core DataLoader behavior: ordering, shuffling, multi-worker loading,
    pin_memory, timeouts, samplers, collation, and clean process shutdown."""

    def setUp(self):
        # 100 samples; labels are 0..49 repeated twice.
        self.data = torch.randn(100, 2, 3, 5)
        self.labels = torch.randperm(50).repeat(2)
        self.dataset = TensorDataset(self.data, self.labels)

    def _test_sequential(self, loader):
        # Batches must come back in dataset order and cover every sample.
        batch_size = loader.batch_size
        for i, (sample, target) in enumerate(loader):
            idx = i * batch_size
            self.assertEqual(sample, self.data[idx:idx + batch_size])
            self.assertEqual(target, self.labels[idx:idx + batch_size])
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_shuffle(self, loader):
        # Every sample must appear exactly once, in any order, with its label.
        found_data = {i: 0 for i in range(self.data.size(0))}
        found_labels = {i: 0 for i in range(self.labels.size(0))}
        batch_size = loader.batch_size
        for i, (batch_samples, batch_targets) in enumerate(loader):
            for sample, target in zip(batch_samples, batch_targets):
                for data_point_idx, data_point in enumerate(self.data):
                    if data_point.eq(sample).all():
                        self.assertFalse(found_data[data_point_idx])
                        found_data[data_point_idx] += 1
                        break
                self.assertEqual(target, self.labels[data_point_idx])
                found_labels[data_point_idx] += 1
            self.assertEqual(sum(found_data.values()), (i + 1) * batch_size)
            self.assertEqual(sum(found_labels.values()), (i + 1) * batch_size)
        self.assertEqual(i, math.floor((len(self.dataset) - 1) / batch_size))

    def _test_error(self, loader):
        # Count the NotImplementedErrors raised by ErrorDataset accesses;
        # one is expected per batch.
        it = iter(loader)
        errors = 0
        while True:
            try:
                next(it)
            except NotImplementedError:
                errors += 1
            except StopIteration:
                self.assertEqual(errors,
                                 math.ceil(float(len(loader.dataset)) / loader.batch_size))
                return

    def test_invalid_assign_after_init(self):
        # Key loader attributes must be immutable once constructed.
        dl = DataLoader(self.dataset)
        for attr in ('batch_size', 'sampler', 'drop_last'):
            def fn():
                setattr(dl, attr, {})

            self.assertRaises(ValueError, fn)

    def test_sequential(self):
        self._test_sequential(DataLoader(self.dataset))

    def test_sequential_batch(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2))

    def test_growing_dataset(self):
        # len(loader) reflects the dataset size at call time, not at
        # construction time.
        dataset = [torch.ones(4) for _ in range(4)]
        dataloader_seq = DataLoader(dataset, shuffle=False)
        dataloader_shuffle = DataLoader(dataset, shuffle=True)
        dataset.append(torch.ones(4))
        self.assertEqual(len(dataloader_seq), 5)
        self.assertEqual(len(dataloader_shuffle), 5)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_sequential_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    def test_multiple_dataloaders(self):
        # Interleaved iteration over two live multi-worker loaders.
        loader1_it = iter(DataLoader(self.dataset, num_workers=1))
        loader2_it = iter(DataLoader(self.dataset, num_workers=2))
        next(loader1_it)
        next(loader1_it)
        next(loader2_it)
        next(loader2_it)
        next(loader1_it)
        next(loader2_it)

    @unittest.skip("temporarily disable until flaky failures are fixed")
    def test_segfault(self):
        # A worker segfault should surface as a descriptive error in the
        # parent, not a hang.
        p = ErrorTrackingProcess(target=_test_segfault)
        p.start()
        p.join(JOIN_TIMEOUT)
        try:
            self.assertFalse(p.is_alive())
            self.assertNotEqual(p.exitcode, 0)
            if IS_WINDOWS:
                self.assertIsInstance(p.exception, OSError)
                self.assertRegex(str(p.exception), r'access violation reading ')
            else:
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception), r'DataLoader worker \(pid \d+\) is killed by signal: ')
        finally:
            p.terminate()

    def test_timeout(self):
        if TEST_CUDA and not NO_MULTIPROCESSING_SPAWN:
            # The pin_memory variant needs CUDA and the spawn start method.
            targets = (_test_timeout, _test_timeout_pin_memory)
        else:
            targets = (_test_timeout,)
        for target in targets:
            p = ErrorTrackingProcess(target=target)
            p.start()
            p.join(JOIN_TIMEOUT)
            try:
                self.assertFalse(p.is_alive())
                self.assertNotEqual(p.exitcode, 0)
                self.assertIsInstance(p.exception, RuntimeError)
                self.assertRegex(str(p.exception), r'DataLoader timed out after \d+ seconds')
            finally:
                p.terminate()

    def test_worker_seed(self):
        # Each worker must get a distinct seed.
        num_workers = 6
        dataset = SynchronizedSeedDataset(num_workers, num_workers)
        dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers)
        seeds = set()
        for batch in dataloader:
            seeds.add(batch[0])
        self.assertEqual(len(seeds), num_workers)

    def test_worker_init_fn(self):
        # init_fn pins every worker's seed to 12345 (see init_fn above).
        dataset = SeedDataset(4)
        dataloader = DataLoader(dataset, batch_size=2, num_workers=2,
                                worker_init_fn=init_fn)
        for batch in dataloader:
            self.assertEqual(12345, batch[0])
            self.assertEqual(12345, batch[1])

    def test_shuffle(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True))

    def test_shuffle_batch(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True))

    def test_sequential_workers(self):
        self._test_sequential(DataLoader(self.dataset, num_workers=4))

    # NOTE(review): "seqential" is a typo for "sequential"; kept as-is to
    # preserve the externally visible test id.
    def test_seqential_batch_workers(self):
        self._test_sequential(DataLoader(self.dataset, batch_size=2, num_workers=4))

    def test_shuffle_workers(self):
        self._test_shuffle(DataLoader(self.dataset, shuffle=True, num_workers=4))

    def test_shuffle_batch_workers(self):
        self._test_shuffle(DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4))

    def _test_batch_sampler(self, **kwargs):
        # [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...]
        batches = []
        for i in range(0, 100, 5):
            batches.append(tuple(range(i, i + 2)))
            batches.append(tuple(range(i + 2, i + 5)))

        dl = DataLoader(self.dataset, batch_sampler=batches, **kwargs)
        self.assertEqual(len(dl), 40)
        for i, (input, _target) in enumerate(dl):
            # Even batches have 2 elements, odd batches have 3 (see above).
            if i % 2 == 0:
                offset = i * 5 // 2
                self.assertEqual(len(input), 2)
                self.assertEqual(input, self.data[offset:offset + 2])
            else:
                offset = i * 5 // 2
                self.assertEqual(len(input), 3)
                self.assertEqual(input, self.data[offset:offset + 3])

    def test_RandomSampler(self):

        from collections import Counter
        from torch.utils.data import RandomSampler

        def sample_stat(sampler, num_samples):
            # Returns (#indices drawn more than once, min index, max index).
            counts = Counter(sampler)
            count_repeated = sum(val > 1 for val in counts.values())
            return (count_repeated, min(counts.keys()), max(counts.keys()))

        # test sample with replacement
        n = len(self.dataset) + 1  # ensure at least one sample is drawn more than once
        sampler_with_replacement = RandomSampler(self.dataset, replacement=True, num_samples=n)
        count_repeated, minval, maxval = sample_stat(sampler_with_replacement, n)
        self.assertTrue(count_repeated > 0)
        self.assertTrue(minval >= 0)
        self.assertTrue(maxval < len(self.dataset))

        # test sample without replacement
        sampler_without_replacement = RandomSampler(self.dataset)
        count_repeated, minval, maxval = sample_stat(sampler_without_replacement, len(self.dataset))
        self.assertTrue(count_repeated == 0)
        self.assertTrue(minval == 0)
        self.assertTrue(maxval == len(self.dataset) - 1)

        # raise error when replacement=False and num_samples is not None
        self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=len(self.dataset)))

        self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=0))

    def test_random_sampler_len_with_replacement(self):
        from torch.utils.data import RandomSampler
        # add 5 extra samples
        num_samples = len(self.dataset) + 5
        sampler = RandomSampler(self.dataset,
                                replacement=True,
                                num_samples=num_samples)
        # test len method
        self.assertEqual(num_samples, len(sampler))

        # test with iteration
        count_num_samples = sum(1 for _ in sampler)
        self.assertEqual(num_samples, count_num_samples)

        # test with dataloader, batch_size = 1
        batch_size = 1
        count_num_samples_in_data_loader = len(DataLoader(
            self.dataset, batch_size=batch_size, sampler=sampler))
        self.assertEqual(num_samples, count_num_samples_in_data_loader)

        # test with dataloader, batch_size = 6
        batch_size = 6
        count_num_samples_in_data_loader = len(DataLoader(
            self.dataset, batch_size=batch_size, sampler=sampler))
        self.assertEqual(int(math.ceil(float(num_samples) / batch_size)),
                         count_num_samples_in_data_loader)

    def test_duplicating_data_with_drop_last(self):

        from torch.utils.data.distributed import DistributedSampler

        # With drop_last, no sample may be seen by more than one "process"
        # (rank) of the distributed sampler.
        num_processes = 4
        num_batches = 9
        data_set = torch.IntTensor(range(num_batches))
        scanned_data = torch.IntTensor([])
        for i in range(num_processes):
            s = DistributedSampler(data_set, num_processes, i)
            d_loader = DataLoader(data_set, batch_size=int(num_batches / num_processes), drop_last=True, sampler=s)
            for data in d_loader:
                scanned_data = torch.cat((scanned_data, data), 0)

        self.assertEqual(scanned_data.size(), scanned_data.unique().size())

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
        don't support multiprocessing with spawn start method")
    def test_batch_sampler(self):
        self._test_batch_sampler()
        self._test_batch_sampler(num_workers=4)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
        for input, target in loader:
            self.assertTrue(input.is_pinned())
            self.assertTrue(target.is_pinned())

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy(self):
        import numpy as np

        class TestDataset(torch.utils.data.Dataset):
            def __getitem__(self, i):
                return np.ones((2, 3, 4)) * i

            def __len__(self):
                return 1000

        # numpy float64 arrays collate into DoubleTensors.
        loader = DataLoader(TestDataset(), batch_size=12)
        batch = next(iter(loader))
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))

    def test_error(self):
        self._test_error(DataLoader(ErrorDataset(100), batch_size=2, shuffle=True))

    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
        don't support multiprocessing with spawn start method")
    def test_error_workers(self):
        self._test_error(DataLoader(ErrorDataset(41), batch_size=2, shuffle=True, num_workers=4))

    @unittest.skipIf(IS_WINDOWS, "FIXME: stuck test")
    def test_partial_workers(self):
        r"""Check that workers exit even if the iterator is not exhausted."""
        if TEST_CUDA:
            pin_memory_configs = (True, False)
        else:
            pin_memory_configs = (False,)

        for pin_memory in pin_memory_configs:
            loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
            workers = loader.workers
            if pin_memory:
                pin_memory_thread = loader.pin_memory_thread
            for i, _ in enumerate(loader):
                if i == 10:
                    break
            assert i == 10
            # Dropping the iterator mid-stream must still tear everything down.
            del loader
            for w in workers:
                w.join(JOIN_TIMEOUT)
                self.assertFalse(w.is_alive(), 'subprocess not terminated')
            if pin_memory:
                pin_memory_thread.join(JOIN_TIMEOUT)
                self.assertFalse(pin_memory_thread.is_alive())

    @skipIfRocm
    @unittest.skipIf(not HAS_PSUTIL, "psutil not found")
    def test_proper_exit(self):
        (r'''There might be ConnectionResetError or leaked semaphore warning '''
         r'''(due to dirty process exit), but they are all safe to ignore''')

        # TODO: test the case where the pin_memory_thread triggers an
        # error/fatal signal. I haven't found out how to properly do that.

        for use_workers, pin_memory, hold_iter_reference in itertools.product([True, False], repeat=3):
            # `hold_iter_reference` specifies whether we hold a reference to the
            # iterator. This is interesting because Python3 error traces holds a
            # reference to the frames, which hold references to all the local
            # variables including the iterator, and then the iterator dtor may
            # not be called before process end. It is important to see that the
            # processes still exit in both cases.

            if pin_memory and (not TEST_CUDA or NO_MULTIPROCESSING_SPAWN):
                # Can't use CUDA without spawn
                continue

            # `exit_method` controls the way the loader process ends.
            # - `*_kill` means that `*` is killed by OS.
            # - `*_error` means that `*` raises an error.
            # - `None` means that no error happens.
            # In all cases, all processes should end properly.
            if use_workers:
                exit_methods = [None, 'loader_error', 'loader_kill', 'worker_kill', 'worker_error']
            else:
                exit_methods = [None, 'loader_error', 'loader_kill']

            for exit_method in exit_methods:

                # Human-readable description of this configuration, used in
                # every failure message below.
                desc = []
                desc.append('use_workers={}'.format(use_workers))
                desc.append('pin_memory={}'.format(pin_memory))
                desc.append('hold_iter_reference={}'.format(hold_iter_reference))
                desc.append('exit_method={}'.format(exit_method))
                desc = 'test_proper_exit with ' + ', '.join(desc)

                # Event that the loader process uses to signal testing process
                # that various things are setup, including that the worker pids
                # are specified in `worker_pids` array.
                loader_setup_event = mp.Event()

                # Event that this process has finished setting up, and the
                # loader process can now proceed to trigger error events or
                # finish normally.
                tester_setup_event = mp.Event()

                loader_p = ErrorTrackingProcess(target=_test_proper_exit,
                                                args=(use_workers, pin_memory, exit_method,
                                                      hold_iter_reference, loader_setup_event,
                                                      tester_setup_event))
                loader_p.start()

                # Wait for loader process to set everything up, e.g., starting
                # workers.
                loader_setup_event.wait(timeout=JOIN_TIMEOUT)
                if not loader_setup_event.is_set():
                    fail_msg = desc + ': loader process failed to setup within given time'
                    if loader_p.exception is not None:
                        self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
                    elif not loader_p.is_alive():
                        self.fail(fail_msg + ', and exited with code {} but had no exception'.format(loader_p.exitcode))
                    else:
                        self.fail(fail_msg + ', and is still alive.')

                # Snapshot the worker processes before triggering the exit.
                worker_psutil_p = psutil.Process(loader_p.pid).children()

                tester_setup_event.set()

                try:
                    loader_p.join(JOIN_TIMEOUT + MP_STATUS_CHECK_INTERVAL)
                    if loader_p.is_alive():
                        fail_msg = desc + ': loader process did not terminate'
                        if loader_p.exception is not None:
                            self.fail(fail_msg + ', and had exception {}'.format(loader_p.exception))
                        else:
                            self.fail(fail_msg + ', and had no exception')
                    _, alive = psutil.wait_procs(worker_psutil_p, timeout=(MP_STATUS_CHECK_INTERVAL + JOIN_TIMEOUT))
                    if len(alive) > 0:
                        self.fail(desc + ': worker process (pid(s) {}) did not terminate'.format(
                            ', '.join(str(p.pid) for p in alive)))
                    if exit_method is None:
                        self.assertEqual(loader_p.exitcode, 0)
                    else:
                        self.assertNotEqual(loader_p.exitcode, 0)
                        if exit_method == 'loader_error':
                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
                            self.assertIn('Loader error', str(loader_p.exception), desc)
                        elif exit_method == 'worker_kill':
                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
                            self.assertIn('DataLoader worker (pid', str(loader_p.exception), desc)
                        elif exit_method == 'worker_error':
                            self.assertIsInstance(loader_p.exception, RuntimeError, desc)
                            self.assertIn('Worker error', str(loader_p.exception), desc)
                finally:
                    loader_p.terminate()

    def test_len(self):
        def check_len(dl, expected):
            # len() must agree with the number of items actually yielded.
            self.assertEqual(len(dl), expected)
            n = 0
            for _ in dl:
                n += 1
            self.assertEqual(n, expected)
        check_len(self.dataset, 100)
        check_len(DataLoader(self.dataset, batch_size=2), 50)
        check_len(DataLoader(self.dataset, batch_size=3), 34)

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_scalars(self):
        import numpy as np

        class ScalarDataset(torch.utils.data.Dataset):
            def __init__(self, dtype):
                self.dtype = dtype

            def __getitem__(self, i):
                return self.dtype()

            def __len__(self):
                return 4

        # Each numpy scalar dtype should collate into the matching tensor type.
        dtypes = {
            np.float64: torch.DoubleTensor,
            np.float32: torch.FloatTensor,
            np.float16: torch.HalfTensor,
            np.int64: torch.LongTensor,
            np.int32: torch.IntTensor,
            np.int16: torch.ShortTensor,
            np.int8: torch.CharTensor,
            np.uint8: torch.ByteTensor,
        }
        for dt, tt in dtypes.items():
            dset = ScalarDataset(dt)
            loader = DataLoader(dset, batch_size=2)
            batch = next(iter(loader))
            self.assertIsInstance(batch, tt)

    def test_default_collate_dtype(self):
        arr = [1, 2, -1]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.int64)

        arr = [1.1, 2.3, -0.9]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.float64)

        arr = [True, False]
        collated = _utils.collate.default_collate(arr)
        self.assertEqual(collated, torch.tensor(arr))
        self.assertEqual(collated.dtype, torch.uint8)

        # Should be a no-op
        arr = ['a', 'b', 'c']
        self.assertEqual(arr, _utils.collate.default_collate(arr))

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_bad_numpy_types(self):
        import numpy as np

        # Should be a no-op
        arr = np.array(['a', 'b', 'c'])
        self.assertEqual(arr, _utils.collate.default_collate(arr))

        arr = np.array([[['a', 'b', 'c']]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([object(), object(), object()])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

        arr = np.array([[[object(), object(), object()]]])
        self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_default_collate_shared_tensor(self):
        import numpy as np
        t_in = torch.zeros(1)
        n_in = np.zeros(1)

        self.assertEqual(t_in.is_shared(), False)

        self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
        self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)

        # Flipping the module flag makes collated tensors land in shared
        # memory; restore it no matter what so other tests are unaffected.
        old = _utils.collate._use_shared_memory
        try:
            _utils.collate._use_shared_memory = True
            self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
            self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
        finally:
            _utils.collate._use_shared_memory = old
914 
915 
class StringDataset(Dataset):
    """Tiny dataset yielding (character, index) pairs from the string '12345'."""

    def __init__(self):
        self.s = '12345'

    def __getitem__(self, ndx):
        return (self.s[ndx], ndx)

    def __len__(self):
        return len(self.s)
925 
926 
class TestStringDataLoader(TestCase):
    """DataLoader tests over a dataset whose samples contain strings."""

    def setUp(self):
        self.dataset = StringDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_shuffle_pin_memory(self):
        # Strings cannot be pinned; only the numeric half of each batch is.
        loader = DataLoader(self.dataset, batch_size=2, shuffle=True,
                            num_workers=4, pin_memory=True)
        for batch_strs, batch_idxs in loader:
            self.assertIsInstance(batch_strs[0], str)
            self.assertTrue(batch_idxs.is_pinned())
937 
938 
class DictDataset(Dataset):
    """Four samples, each a nested dict: a 4x2 tensor filled with the index,
    plus the raw index under a nested key."""

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        filled = torch.Tensor(4, 2).fill_(ndx)
        return {
            'a_tensor': filled,
            'another_dict': {'a_number': ndx},
        }
950 
951 
class TestDictDataLoader(TestCase):
    """Checks that dict-structured samples are collated and pinned per key."""

    def setUp(self):
        self.dataset = DictDataset()

    def test_sequential_batch(self):
        loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
        bs = loader.batch_size
        for batch_num, sample in enumerate(loader):
            base = batch_num * bs
            self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
            self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})

            tensors = sample['a_tensor']
            self.assertEqual(tensors.size(), torch.Size([bs, 4, 2]))
            # Without shuffling, batch k holds samples k*bs and k*bs + 1.
            self.assertTrue((tensors[0] == base).all())
            self.assertTrue((tensors[1] == base + 1).all())

            numbers = sample['another_dict']['a_number']
            self.assertEqual(numbers.size(), torch.Size([bs]))
            self.assertEqual(numbers[0], base)
            self.assertEqual(numbers[1], base + 1)

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_pin_memory(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for sample in loader:
            # Every tensor leaf in the nested dict must be pinned.
            self.assertTrue(sample['a_tensor'].is_pinned())
            self.assertTrue(sample['another_dict']['a_number'].is_pinned())
980 
981 
class NamedTupleDataset(Dataset):
    """Dataset whose samples are nested namedtuples (Batch wrapping Data)."""
    from collections import namedtuple
    Batch = namedtuple('Batch', ['data', 'label'])
    Data = namedtuple('Data', ['positive', 'negative'])

    def __len__(self):
        return 4

    def __getitem__(self, ndx):
        payload = self.Data(positive=ndx, negative=-ndx)
        return self.Batch(data=payload, label=str(ndx))
993 
994 
class TestNamedTupleDataLoader(TestCase):
    """Collation and pinning must preserve namedtuple structure."""

    def setUp(self):
        self.dataset = NamedTupleDataset()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_collate_and_pin_memory_with_namedtuple(self):
        loader = DataLoader(self.dataset, batch_size=2, pin_memory=True)
        for batch in loader:
            # Both the outer Batch and the nested Data types must survive.
            self.assertIsInstance(batch, NamedTupleDataset.Batch)
            self.assertIsInstance(batch.data, NamedTupleDataset.Data)
1005 
1006 
class SimpleCustomBatch(object):
    """Custom batch type produced by a collate_fn; knows how to pin itself.

    Stacks the (inp, tgt) sample pairs into two tensors so the DataLoader's
    pin-memory logic can call :meth:`pin_memory` on the batch object.
    """

    def __init__(self, data):
        columns = list(zip(*data))
        self.inp = torch.stack(columns[0], 0)
        self.tgt = torch.stack(columns[1], 0)

    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

    def is_pinned(self):
        # Pinned only when both constituent tensors are pinned.
        return self.inp.is_pinned() and self.tgt.is_pinned()
1020 
1021 
def collate_wrapper(batch):
    """collate_fn that wraps the raw sample list in a SimpleCustomBatch."""
    wrapped = SimpleCustomBatch(batch)
    return wrapped
1024 
1025 
def collate_into_packed_sequence(batch):
    """Collate 1-D inputs into a time-major PackedSequence with random lengths."""
    inputs = [sample[0] for sample in batch]
    data = torch.stack(inputs, 1)  # shape (T, B): time-major layout
    seq_len, batch_size = data.size()
    # One random valid length in [1, seq_len) per sequence.
    lengths = torch.randint(1, seq_len, size=(batch_size,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(
        data, lengths, enforce_sorted=False)
1031 
1032 
def collate_into_packed_sequence_batch_first(batch):
    """Collate 1-D inputs into a batch-first PackedSequence with random lengths."""
    inputs = [sample[0] for sample in batch]
    data = torch.stack(inputs, 0)  # shape (B, T): batch-major layout
    batch_size, seq_len = data.size()
    # One random valid length in [1, seq_len) per sequence.
    lengths = torch.randint(1, seq_len, size=(batch_size,), dtype=torch.int64)
    return torch.nn.utils.rnn.pack_padded_sequence(
        data, lengths, batch_first=True, enforce_sorted=False)
1038 
1039 
class TestCustomPinFn(TestCase):
    """Pin-memory support for custom batch types returned by collate_fn.

    Covers a user-defined batch class exposing ``pin_memory()`` as well as
    ``PackedSequence``, both with and without worker processes.
    """

    def setUp(self):
        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        self.dataset = TensorDataset(inps, tgts)

    def _run_pin_tests(self, **loader_kwargs):
        # Shared driver: the two public tests previously duplicated this
        # case table and loop verbatim, differing only in num_workers.
        test_cases = [
            (collate_wrapper, SimpleCustomBatch),
            (collate_into_packed_sequence, torch.nn.utils.rnn.PackedSequence),
            (collate_into_packed_sequence_batch_first, torch.nn.utils.rnn.PackedSequence),
        ]
        for collate_fn, elem_cls in test_cases:
            loader = DataLoader(self.dataset, batch_size=2, collate_fn=collate_fn,
                                pin_memory=True, **loader_kwargs)
            for sample in loader:
                self.assertIsInstance(sample, elem_cls)
                self.assertTrue(sample.is_pinned())

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_custom_batch_pin(self):
        # Batches produced and pinned in the main process (no workers).
        self._run_pin_tests()

    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @skipIfRocm
    def test_custom_batch_pin_worker(self):
        # Batches produced by a worker process, then pinned.
        self._run_pin_tests(num_workers=1)
1075 
1076 
class TestWorkerQueueDataset(Dataset):
    """Dataset that reports which worker served each item.

    ``worker_init_fn`` records the worker id on the per-worker copy of the
    dataset; ``__getitem__`` returns it alongside the data so tests can tell
    which worker produced each batch.
    """

    def __init__(self, data):
        self.data = data
        self.worker_id = None  # set in each worker via worker_init_fn

    def worker_init_fn(self, worker_id):
        self.worker_id = worker_id

    def __getitem__(self, item):
        return self.worker_id, self.data[item]

    def __len__(self):
        return len(self.data)
1090 
1091 
1093  def setUp(self):
1094  self.dataset = TestWorkerQueueDataset([i for i in range(128)])
1095 
1096  def _run_ind_worker_queue_test(self, batch_size, num_workers):
1097  loader = DataLoader(
1098  self.dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
1099  worker_init_fn=self.dataset.worker_init_fn
1100  )
1101  current_worker_idx = 0
1102  for i, (worker_ids, sample) in enumerate(loader):
1103  self.assertEqual(worker_ids.tolist(), [current_worker_idx] * batch_size)
1104  self.assertEqual(sample.tolist(), [j for j in range(i * batch_size, (i + 1) * batch_size)])
1105  current_worker_idx += 1
1106  if current_worker_idx == num_workers:
1107  current_worker_idx = 0
1108 
1109  def test_ind_worker_queue(self):
1110  for batch_size in (8, 16, 32, 64):
1111  for num_workers in range(1, 6):
1112  self._run_ind_worker_queue_test(batch_size=batch_size, num_workers=num_workers)
1113 
1114 
if __name__ == '__main__':
    # Entry point: run this file's tests via the shared common_utils harness.
    run_tests()