Caffe2 - Python API
A deep learning, cross-platform ML framework
pin_memory.py
r"""Contains definitions of the methods used by the _DataLoaderIter to put
fetched tensors into pinned memory.

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""
7 
8 import torch
9 from torch._six import queue, container_abcs, string_classes
10 from . import collate, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper
11 
12 
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    """Pin batches taken from ``in_queue`` and forward them to ``out_queue``.

    Each item read from ``in_queue`` is either ``None`` (the final signal)
    or an ``(idx, data)`` pair whose ``data`` is a batch or an
    ``ExceptionWrapper``:

    * wrapped exceptions are forwarded to ``out_queue`` unchanged;
    * batches go through ``pin_memory_batch`` and are put on ``out_queue``
      as ``(idx, pinned_batch)``;
    * if pinning raises, the exception is wrapped and forwarded as
      ``(idx, ExceptionWrapper(sys.exc_info()))`` so the loop keeps running.

    Once ``done_event`` is set, remaining items are drained (and dropped)
    until ``None`` arrives.

    Args:
        in_queue: source queue, polled with a ``MP_STATUS_CHECK_INTERVAL``
            timeout so the loop can react to ``done_event``.
        out_queue: destination queue for results.
        device_id: device made current via ``torch.cuda.set_device`` before
            the loop starts.
        done_event: event whose set state indicates shutdown is in progress.

    Raises:
        Exception: any non-``queue.Empty`` error from ``in_queue.get`` is
            re-raised unless ``done_event`` is already set (then the loop
            just stops).
        AssertionError: if ``None`` is received while ``done_event`` is not set.
    """
    # ``sys.exc_info()`` is needed on the error path below, but the
    # module-level imports do not provide ``sys``; without this import a
    # failure in pin_memory_batch surfaced as NameError instead of being
    # forwarded to the consumer.
    import sys

    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while True:
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            # Timeout: loop again so done_event is re-checked regularly.
            continue
        except Exception:
            if done_event.is_set():
                # Weird things can happen when shutting down, e.g., fd being
                # closed when tensors are shared via fds.
                break
            raise
        if r is None:
            # The final signal is only expected after shutdown was requested.
            assert done_event.is_set()
            return
        elif done_event.is_set():
            # Haven't seen the final signal yet. Keep getting until None.
            continue
        elif isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
        else:
            idx, batch = r
            try:
                batch = pin_memory_batch(batch)
            except Exception:
                out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                out_queue.put((idx, batch))
42  out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
43  else:
44  out_queue.put((idx, batch))
45 
46 
def pin_memory_batch(batch):
    """Return ``batch`` with every tensor inside it pinned via ``.pin_memory()``.

    The batch is walked recursively:

    * ``torch.Tensor``: ``batch.pin_memory()``
    * string: returned unchanged (not split as a sequence)
    * mapping: a new ``dict`` with every value processed
    * namedtuple: the same namedtuple type with every field processed
    * any other sequence: a ``list`` of processed elements
    * any other object with a ``pin_memory`` attribute: ``batch.pin_memory()``
    * anything else: returned unchanged
    """
    # Order matters: strings and namedtuples are Sequences too, so they have
    # to be recognized before the generic Sequence branch.
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    if isinstance(batch, string_classes):
        return batch
    if isinstance(batch, container_abcs.Mapping):
        return dict((key, pin_memory_batch(value)) for key, value in batch.items())
    is_namedtuple = isinstance(batch, tuple) and hasattr(batch, '_fields')
    if is_namedtuple:
        return type(batch)(*map(pin_memory_batch, batch))
    if isinstance(batch, container_abcs.Sequence):
        return list(map(pin_memory_batch, batch))
    # Duck-typed fallback for custom batch objects that can pin themselves.
    if hasattr(batch, "pin_memory"):
        return batch.pin_memory()
    return batch
def set_device(device)
Definition: __init__.py:253