1 r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes. 3 To support these two classes, in `./_utils` we define many utility methods and 4 functions to be run in multiprocessing. E.g., the data loading worker loop is 5 in `./_utils/worker.py`. 10 from .
import SequentialSampler, RandomSampler, BatchSampler
default_collate = _utils.collate.default_collate


class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be ``False``.
        batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main
            process. (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your ``collate_fn`` returns a batch that is a
            custom type, see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of the dataset is not divisible by the batch size, then the last
            batch will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by the main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.

    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` cannot be
                 an unpicklable object, e.g., a lambda function.

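    For example, a ``worker_init_fn`` along the following lines (a minimal sketch;
    it assumes the dataset draws from NumPy's global RNG in ``__getitem__``) derives
    a distinct NumPy seed from the per-worker PyTorch seed::

        def worker_init_fn(worker_id):
            import numpy as np
            # torch.initial_seed() is already base_seed + worker_id, so reusing
            # it here gives every worker a different NumPy seed as well.
            np.random.seed(torch.initial_seed() % 2 ** 32)
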
    The default memory pinning logic only recognizes Tensors and maps and iterables
    containing Tensors. By default, if the pinning logic sees a batch that is a custom type
    (which will occur if you have a ``collate_fn`` that returns a custom batch type),
    or if each element of your batch is a custom type, the pinning logic will not
    recognize them, and it will return that batch (or those elements)
    without pinning the memory. To enable memory pinning for custom batch or data types,
    define a ``pin_memory`` method on your custom type(s).

    Example::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed_data = list(zip(*data))
                self.inp = torch.stack(transposed_data[0], 0)
                self.tgt = torch.stack(transposed_data[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self

        def collate_wrapper(batch):
            return SimpleCustomBatch(batch)

        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        dataset = TensorDataset(inps, tgts)

        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                            pin_memory=True)

        for batch_ndx, sample in enumerate(loader):
            print(sample.inp.is_pinned())
            print(sample.tgt.is_pinned())
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        # Without an explicit batch_sampler, build one from the simpler options.
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True
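        # For illustration (the argument values are arbitrary): with the defaults
        # constructed above, a call such as
        #
        #     DataLoader(dataset, batch_size=4, shuffle=True)
        #
        # draws indices exactly as if the caller had passed
        #
        #     batch_sampler=BatchSampler(RandomSampler(dataset), 4, False)
        #
        # since drop_last defaults to False.
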
    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)
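    # For illustration (hypothetical session): once a DataLoader is constructed,
    # rebinding one of the guarded attributes raises, e.g.
    #
    #     loader = DataLoader(dataset, batch_size=4)
    #     loader.batch_size = 8
    #     # ValueError: batch_size attribute should not be set after DataLoader
    #     # is initialized

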
169 r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" 396 def __init__(self, loader):
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_result_queue = multiprocessing.Queue()
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                # Each worker gets its own index queue; cancel_join_thread keeps
                # the main process from blocking on it during shutdown.
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop,
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue,
                          torch.cuda.current_device(), self.done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True
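        # Summary of the pipeline built above: the main process pushes batch
        # indices into the per-worker index_queues; each worker runs
        # _utils.worker._worker_loop, loads and collates its batch, and puts the
        # result on worker_result_queue; when pin_memory is enabled, a background
        # pin_memory_thread copies tensors into pinned memory and forwards them
        # to data_queue, otherwise data_queue is simply the worker result queue.
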
    def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Try to fetch one batch from `data_queue` within `timeout` seconds.
        # Returns a 2-tuple (success, data); `data` is None on failure.
        try:
            data = self.data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # On timeout or error, check whether any worker died: a dead worker
            # will never put its result on the queue.
            if not all(w.is_alive() for w in self.workers):
                pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
            if isinstance(e, queue.Empty):
                return (False, None)
            raise
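    # The default `timeout` above, `_utils.MP_STATUS_CHECK_INTERVAL`, is what
    # lets `_get_batch` below call this method in a loop: each call returns
    # periodically even when no data arrives, so the liveness checks (the workers
    # here, the pin memory thread in `_get_batch`) run at that interval.
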
    def _get_batch(self):
        # Fetch one batch from `data_queue`, honoring `self.timeout` and, when
        # pinning is enabled, the liveness of the pin memory thread.
        if self.timeout > 0:
            success, data = self._try_get_batch(self.timeout)
            if success:
                return data
            raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        elif self.pin_memory:
            while self.pin_memory_thread.is_alive():
                success, data = self._try_get_batch()
                if success:
                    return data
            # The loop only falls through if the pin memory thread died.
            raise RuntimeError('Pin memory thread exited unexpectedly')
        else:
            while True:
                success, data = self._try_get_batch()
                if success:
                    return data

    def __next__(self):
        if self.num_workers == 0:
            # Same-process loading: collate indices from the sampler directly.
            batch = self.collate_fn([self.dataset[i] for i in next(self.sample_iter)])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

        # The next batch may already have been received out of order.
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)
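    # Ordering protocol, illustrated with arbitrary indices: batches are handed
    # to workers tagged with `send_idx` and must come back to the caller in that
    # order. If batch 6 finishes before batch 5, it is parked in `reorder_dict`
    # until batch 5 (`rcvd_idx`) has been returned, and the `reorder_dict.pop`
    # above then yields batch 6 on the following call.
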
    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, _utils.ExceptionWrapper):
            # Re-raise, in the main process, an exception caught in a worker.
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # The iterator owns worker processes, queues, and possibly a thread,
        # none of which can be pickled.
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        # If Python is already shutting down, multiprocessing queues and events
        # may no longer be usable, so do nothing.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            return
        if not self.shutdown:
            self.shutdown = True
            # Remove worker pids from the C-side bookkeeping first, so that the
            # worker exits triggered below are not reported as failures.
            if self.worker_pids_set:
                _utils.signal_handling._remove_worker_pids(id(self))
                self.worker_pids_set = False

            self.done_event.set()

            # Shut down `pin_memory_thread` before the workers, because exiting
            # workers may leave corrupted data in `worker_result_queue`, which
            # the thread reads from.
            if hasattr(self, 'pin_memory_thread'):
                self.worker_result_queue.cancel_join_thread()
                self.worker_result_queue.put(None)
                self.pin_memory_thread.join()
                # Close only after the thread is joined, since it shares the
                # queue's pipe handles with this process.
                self.worker_result_queue.close()