r"""Contains definitions of the methods used by the _DataLoaderIter to put
fetched tensors into pinned memory.

These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import sys

import torch
from torch._six import queue, container_abcs, string_classes

from . import collate, MP_STATUS_CHECK_INTERVAL, ExceptionWrapper
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    r"""Worker loop that moves fetched batches into pinned (page-locked) memory.

    Repeatedly takes ``(idx, batch)`` pairs from ``in_queue``, pins the batch
    via :func:`pin_memory_batch`, and puts the result on ``out_queue``.
    Exceptions raised while pinning are forwarded as ``ExceptionWrapper``
    entries instead of crashing the loop.

    Args:
        in_queue: queue producing ``(idx, batch)`` pairs, ``(idx,
            ExceptionWrapper)`` pairs, or ``None`` as the shutdown sentinel.
        out_queue: queue receiving ``(idx, pinned_batch)`` or
            ``(idx, ExceptionWrapper)`` pairs.
        device_id: CUDA device to make current in this thread, so pinned
            allocations are associated with the right device.
        done_event: event set by the owner to signal shutdown.
    """
    # NOTE(review): the mangled original never used `device_id`; upstream
    # binds the CUDA device here so that pinning targets the right device.
    torch.cuda.set_device(device_id)

    # Keep looping until the `None` sentinel arrives. The timeout on `get`
    # lets us periodically re-check `done_event` instead of blocking forever.
    while True:
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            # Timed out with nothing to do; poll again.
            continue
        except Exception:
            if done_event.is_set():
                # During shutdown, queue internals (e.g. shared fds) may be
                # torn down under us; treat any error as a normal exit.
                break
            raise
        if r is None:
            # Shutdown sentinel: only ever sent after `done_event` is set.
            assert done_event.is_set()
            return
        elif done_event.is_set():
            # Shutting down but the sentinel hasn't arrived yet; drain and
            # discard until we see it.
            continue
        elif isinstance(r[1], ExceptionWrapper):
            # A worker already failed on this index; forward its error as-is.
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            batch = pin_memory_batch(batch)
        except Exception:
            # Pinning failed: ship the traceback to the consumer rather than
            # killing this loop.
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))
def pin_memory_batch(batch):
    r"""Recursively copy every tensor in ``batch`` into pinned memory.

    Containers (mappings, namedtuples, sequences) are rebuilt with their
    elements pinned; tensors and objects exposing a ``pin_memory`` method are
    pinned directly; everything else is returned unchanged.

    Args:
        batch: a tensor, string, mapping, namedtuple, sequence, object with a
            ``pin_memory`` method, or any nesting thereof.

    Returns:
        A structure mirroring ``batch`` with all tensors pinned.
    """
    if isinstance(batch, torch.Tensor):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        # Strings are Sequences but must not be rebuilt element-wise.
        # BUG FIX: the original branch had no `return`, so strings fell
        # through to the Sequence case.
        return batch
    elif isinstance(batch, container_abcs.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, tuple) and hasattr(batch, '_fields'):  # namedtuple
        return type(batch)(*(pin_memory_batch(sample) for sample in batch))
    elif isinstance(batch, container_abcs.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    elif hasattr(batch, "pin_memory"):
        return batch.pin_memory()
    else:
        # Unpinnable leaf (int, float, None, ...): pass through unchanged
        # instead of implicitly returning None.
        return batch