dataloader.py
1 r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
2 
3 To support these two classes, in `./_utils` we define many utility methods and
4 functions to be run in multiprocessing. E.g., the data loading worker loop is
5 in `./_utils/worker.py`.
6 """

import torch
import torch.multiprocessing as multiprocessing
from . import SequentialSampler, RandomSampler, BatchSampler
from . import _utils
import threading
from torch._six import queue


# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate

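# For example, legacy code may still reach it via (a sketch of the BC path,
# not a recommended import for new code):
#
#     from torch.utils.data.dataloader import default_collate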

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be ``False``.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your ``collate_fn`` returns a batch that is a
            custom type, see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)

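    For example, given some ``dataset``, passing ``shuffle=True`` is shorthand
    for wiring the samplers manually. A sketch of the equivalence that
    ``__init__`` implements (using the samplers imported at the top of this
    file)::

        loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
        # ... is equivalent to ...
        loader = DataLoader(
            dataset,
            batch_sampler=BatchSampler(RandomSampler(dataset), 4, True))
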
    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.

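    For example, a minimal :attr:`worker_init_fn` that derives a distinct
    NumPy seed for each worker from the PyTorch seed (a sketch; NumPy is just
    an illustration, substitute whatever libraries you use)::

        import numpy as np

        def seed_numpy_worker(worker_id):
            # torch.initial_seed() already differs per worker; fold it into
            # NumPy's 32-bit seed range.
            np.random.seed(torch.initial_seed() % 2 ** 32)

        loader = DataLoader(dataset, num_workers=4,
                            worker_init_fn=seed_numpy_worker)
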
    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.

    The default memory pinning logic only recognizes Tensors and maps and iterables
    containing Tensors. By default, if the pinning logic sees a batch that is a custom type
    (which will occur if you have a ``collate_fn`` that returns a custom batch type),
    or if each element of your batch is a custom type, the pinning logic will not
    recognize them, and it will return that batch (or those elements)
    without pinning the memory. To enable memory pinning for custom batch or data types,
    define a ``pin_memory`` method on your custom type(s).

    Example::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed_data = list(zip(*data))
                self.inp = torch.stack(transposed_data[0], 0)
                self.tgt = torch.stack(transposed_data[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self

        def collate_wrapper(batch):
            return SimpleCustomBatch(batch)

        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        dataset = TensorDataset(inps, tgts)

        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                            pin_memory=True)

        for batch_ndx, sample in enumerate(loader):
            print(sample.inp.is_pinned())
            print(sample.tgt.is_pinned())

    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        return _DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)


class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone or it is depleted.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want cleaning
    #      up code in the workers to be executed (e.g., releasing GPU memory).
    #      Naturally, we implement the shutdown logic in `__del__` of
    #      DataLoaderIterator.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the loader process and/or worker
    #      processes exit normally or with error.
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #
    #      You may ask, why can't we make the workers non-daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)?
    #
    #      First of all, `__del__` is **not** guaranteed to be called when the
    #      interpreter exits. Even if it is called, by the time it executes,
    #      many Python core library resources may already be freed, and even
    #      simple things like acquiring an internal lock of a queue may hang.
    #      Therefore, in this case, we actually need to prevent `__del__` from
    #      being executed, and rely on the automatic termination of daemonic
    #      children. Thus, we register an `atexit` hook that sets a global flag
    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
    #      reverse order of registration, we are guaranteed that this flag is
    #      set before library resources we use are freed. (Hooks freeing those
    #      resources are registered at importing the Python core libraries at
    #      the top of this file.) So in `__del__`, we check if
    #      `_utils.python_exit_status` is set or `None` (freed), and perform
    #      a no-op if so (the flag pattern is sketched below).
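    #
    #      A minimal sketch of that exit-flag pattern (hypothetical names, not
    #      the actual hook in `_utils`):
    #
    #          import atexit
    #
    #          python_exit_status = False
    #
    #          def _set_python_exit_flag():
    #              global python_exit_status
    #              python_exit_status = True
    #
    #          # atexit hooks run in reverse order of registration, so this
    #          # later-registered hook runs before earlier cleanup hooks.
    #          atexit.register(_set_python_exit_flag)
    #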
    #      Another problem with `__del__` is also related to the library cleanup
    #      calls. When a process ends, it shuts all its daemonic children
    #      down with a SIGTERM (instead of joining them without a timeout).
    #      Similarly for threads, but by a different mechanism. This fact,
    #      together with a few implementation details of multiprocessing, forces
    #      us to make workers daemonic. All of our problems arise when a
    #      DataLoader is used in a subprocess, and are caused by multiprocessing
    #      code which looks more or less like this:
    #
    #          try:
    #              your_function_using_a_dataloader()
    #          finally:
    #              multiprocessing.util._exit_function()
    #
    #      The joining/termination mentioned above happens inside
    #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #      throws, the stack trace stored in the exception will prevent the
    #      frame which uses `DataLoaderIter` from being freed. If the frame has
    #      any reference to the `DataLoaderIter` (e.g., in a method of the iter),
    #      its `__del__`, which starts the shutdown procedure, will not be
    #      called. That, in turn, means that workers aren't notified. Attempting
    #      to join in `_exit_function` will then result in a hang.
    #
    #      For context, `_exit_function` is also registered as an `atexit` call.
    #      So it is unclear to me (@ssnl) why this is needed in a finally block.
    #      The code dates back to 2008 and there is no comment on the original
    #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
    #      the finally block and the `atexit` registration) that explains this.
    #
    #      Another choice is to just shutdown workers with logic in 1 above
    #      whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-catch to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly by fatal signals.
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes only
    #      happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even if the process that should send/receive data to/from it was
    #      killed, i.e.,
    #
    #        a. A process won't hang when getting from a queue.
    #
    #           Even with carefully designed data dependencies (i.e., a `put()`
    #           always corresponding to a `get()`), hanging on `get()` can still
    #           happen when data in queue is corrupted (e.g., due to
    #           `cancel_join_thread` or unexpected exit).
    #
    #           For child exit, we set a timeout whenever we try to get data
    #           from `data_queue`, and check the workers' status on each timeout
    #           and error.
    #           See `_DataLoaderIter._get_batch()` and
    #           `_DataLoaderIter._try_get_batch()` for details.
    #
    #           Additionally, for child exit on non-Windows platforms, we also
    #           register a SIGCHLD handler (which is unsupported on Windows) on
    #           the main process, which checks if any of the workers fail in the
    #           (Python) handler. This is more efficient and faster in detecting
    #           worker failures, compared to only using the above mechanism.
    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
    #
    #           For `.get()` calls where the sender(s) is not the workers, we
    #           guard them with timeouts, and check the status of the sender
    #           when a timeout happens:
    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
    #               checks the status of the main process.
    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #               check `pin_memory_thread` status periodically until `.get()`
    #               returns or we see that `pin_memory_thread` died.
    #
    #        b. A process won't hang when putting into a queue;
    #
    #           We use `mp.Queue` which has a separate background thread to put
    #           objects from an unbounded buffer array. The background thread is
    #           daemonic and usually automatically joined when the process
    #           exits.
    #
    #           However, in case the receiver has ended abruptly while
    #           reading from the pipe, the join will hang forever. Therefore,
    #           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
    #           and each `index_queue` (main process -> worker), we use
    #           `q.cancel_join_thread()` in the sender process before any `q.put`
    #           to prevent this automatic join, as sketched below.
    #
    #           Moreover, having all queues call `cancel_join_thread` makes
    #           implementing graceful shutdown logic in `__del__` much easier.
    #           It won't need to get from any queue, which would also need to be
    #           guarded by periodic status checks.
    #
    #           Note that this may leave corrupted data in the queue, but we
    #           don't care about the data anyways once we are shutting down.
    #
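    #           A minimal sketch of the sender-side pattern (standalone and
    #           hypothetical, not the actual loader wiring):
    #
    #               import torch.multiprocessing as mp
    #
    #               q = mp.Queue()
    #               # Prevent process exit from blocking on the queue's feeder
    #               # thread if the receiver has already died.
    #               q.cancel_join_thread()
    #               q.put(('some_index', 'some_payload'))
    #
    #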
    # Now let's get back to 1:
    #   how we gracefully exit the workers when the last reference to the
    #   iterator is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # [worker processes]
    #   While loader process is alive:
    #     Get from index_queue.
    #       If got a `None`, exit.
    #       If got anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until we see the `None`, then exit.
    #            Otherwise, process data.
    #     If timed out,
    #        Whether or not `done_event` is set (we still need to see `None`),
    #        we must continue to the next iteration.
    #
    # [pin_memory_thread]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While True:
    #     Get from worker_result_queue.
    #       If got a `None`, exit.
    #       If got anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until we see the `None`, then exit.
    #            Otherwise, process data.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.
    #
    # [main process]
    #   In the DataLoader Iter's `__del__`
    #     a. Set `done_event` (shared with `pin_memory_thread` and workers).
    #
    #        Note: from here on, the workers & `pin_memory_thread` may exit at
    #              any time after they receive `None`.
    #
    #     b. Exit `pin_memory_thread`
    #          i.   Put `None` in `worker_result_queue`.
    #          ii.  Join the `pin_memory_thread`.
    #
    #     c. Exit the workers.
    #          i.   Put `None` in each worker's `index_queue`.
    #          ii.  Join the workers.
    #
    #        NOTE: This has to be after (b) because it may leave corrupted data
    #              in `worker_result_queue`, which `pin_memory_thread` reads
    #              from.
    #
    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #         can be omitted.
    #
    # NB: `done_event` isn't strictly needed. E.g., we can just check for
    #     `None` from `index_queue`, but it allows us to skip wasting resources
    #     processing indices already in `index_queue` if we are already shutting
    #     down.

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop,
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                # NB: Process.start() actually takes some time as it needs to
                # start a process and pass the arguments over via a pipe.
                # Therefore, we only add a worker to the self.workers list after
                # it has started, so that we do not call .join() if the program
                # dies before it starts, and __del__ tries to join but would get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue,
                          torch.cuda.current_device(), self.done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                # Similar to workers (see comment above), we only register
                # pin_memory_thread once it is started.
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `data_queue` for a given timeout. This can
        # also be used as the inner loop of fetching without a timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This
        # error can come from either the SIGCHLD handler in
        # `_utils/signal_handling.py` (only for non-Windows platforms), or the
        # manual check below on errors and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether data was successfully fetched,
        #    any:  the data if successful, else None)
        try:
            data = self.data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            if not all(w.is_alive() for w in self.workers):
                pids_str = ', '.join(str(w.pid) for w in self.workers if not w.is_alive())
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
            if isinstance(e, queue.Empty):
                return (False, None)
            raise

    def _get_batch(self):
        # Fetches data from `self.data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need to check if `pin_memory_thread`
        # has died at timeouts.
        if self.timeout > 0:
            success, data = self._try_get_batch(self.timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        elif self.pin_memory:
            while self.pin_memory_thread.is_alive():
                success, data = self._try_get_batch()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self.data_queue` is a `queue.Queue`. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_batch()
                if success:
                    return data

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, _utils.ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do a no-op.
            return
        # Normal exit when the last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self.shutdown:
            self.shutdown = True
            # Remove pids from the C side data structure first so worker
            # termination afterwards won't trigger a false positive error report.
            if self.worker_pids_set:
                _utils.signal_handling._remove_worker_pids(id(self))
                self.worker_pids_set = False

            self.done_event.set()

            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue` which `pin_memory_thread`
            # reads from.
            if hasattr(self, 'pin_memory_thread'):
                # Use hasattr in case an error happened before we set the
                # attribute. This is the first time this process puts into
                # `worker_result_queue`.

                # `cancel_join_thread` in case `pin_memory_thread` exited.
                self.worker_result_queue.cancel_join_thread()
                self.worker_result_queue.put(None)
                self.pin_memory_thread.join()
                # Indicate that no more data will be put on this queue by the
                # current process. This **must** be called after
                # `pin_memory_thread` is joined because that thread shares the
                # same pipe handles with this loader thread. If the handle is
                # closed, Py3 will error in this case, but Py2 will just time
                # out even if there is data in the queue.
                self.worker_result_queue.close()

            # Exit workers now.
            for q in self.index_queues:
                q.put(None)
                # Indicate that no more data will be put on this queue by the
                # current process.
                q.close()
            for w in self.workers:
                w.join()

    def __del__(self):
        if self.num_workers > 0:
            self._shutdown_workers()