r"""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import re

import torch

try:
    from torch._six import container_abcs, string_classes, int_classes
except ImportError:
    # ``torch._six`` (or some of these names) no longer exists in newer
    # PyTorch releases; these are the Python 3 equivalents it used to export.
    import collections.abc as container_abcs
    string_classes = (str, bytes)
    int_classes = int

_use_shared_memory = False
r"""Whether to use shared memory in default_collate"""

# Matches numpy array-protocol type codes that cannot become tensors:
# byte strings ('S', 'a'), unicode strings ('U') and Python objects ('O').
# Searched against ``ndarray.dtype.str`` (e.g. '<U5', '|O').
np_str_obj_array_pattern = re.compile(r'[SaUO]')

error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"

# Maps a numpy dtype name to the legacy torch tensor constructor used when
# collating a batch of numpy scalars.
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the type of the first sample, ``batch[0]``:

    * ``torch.Tensor`` -- stacked along a new dimension 0 (directly into a
      shared-memory tensor when ``_use_shared_memory`` is set).
    * numpy ``ndarray`` -- converted with ``torch.from_numpy`` and collated as
      tensors; arrays of strings or objects are rejected.
    * numpy scalars -- converted with the constructor in ``numpy_type_map``.
    * ``float`` -- a ``torch.float64`` tensor.
    * ``int`` -- a tensor built by ``torch.tensor``.
    * strings (including numpy string scalars) -- returned unchanged.
    * mappings -- collated per key into a ``dict``.
    * namedtuples -- collated per field and rebuilt as the same type.
    * other sequences -- transposed and collated per position into a list.

    Args:
        batch: non-empty, indexable collection of samples that all share the
            structure of ``batch[0]``.

    Returns:
        The collated batch, structured like a single sample.

    Raises:
        IndexError: if ``batch`` is empty.
        TypeError: if a sample (or a nested field) has an unsupported type,
            or is a numpy array of strings/objects.
        RuntimeError: if sequence samples have different lengths.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None  # torch.stack allocates its own output unless shared
        if _use_shared_memory:
            # Stack straight into shared memory (presumably when running in a
            # DataLoader worker) to avoid an extra copy of the batch.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif (elem_type.__module__ == 'numpy'
          and elem_type.__name__ not in ('str_', 'string_', 'bytes_')):
        # numpy string scalars ('str_', and 'bytes_' -- the Py3 name of
        # numpy.string_) are excluded so they reach the string branch below.
        if elem_type.__name__ == 'ndarray':
            # Arrays of strings or Python objects cannot become tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))
            return default_collate([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # numpy scalars, e.g. numpy.float32(1.0)
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
        # Any other numpy object falls through to the TypeError below.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # zip() silently truncates to the shortest sample, which would drop
        # data; require every sample to have the same length instead.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        return [default_collate(samples) for samples in zip(*batch)]

    raise TypeError(error_msg_fmt.format(elem_type))