Caffe2 - Python API
A deep learning, cross-platform ML framework
collate.py
r"""Contains definitions of the methods used by the _DataLoaderIter workers to
collate samples fetched from the dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""
7 
8 import torch
9 import re
10 from torch._six import container_abcs, string_classes, int_classes
11 
# When True, default_collate stacks tensors straight into a newly allocated
# shared-memory buffer instead of letting torch.stack allocate the output.
_use_shared_memory = False

# Matches the dtype.str code of numpy arrays holding byte strings ('S'/'a'),
# unicode strings ('U') or arbitrary Python objects ('O'); default_collate
# refuses to turn such arrays into tensors.
np_str_obj_array_pattern = re.compile(r'[SaUO]')

# Message template for the TypeError raised on unsupported batch elements.
error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"

# Maps a numpy dtype name to the legacy tensor constructor used when
# collating a batch of numpy scalars.
numpy_type_map = dict(
    float64=torch.DoubleTensor,
    float32=torch.FloatTensor,
    float16=torch.HalfTensor,
    int64=torch.LongTensor,
    int32=torch.IntTensor,
    int16=torch.ShortTensor,
    int8=torch.CharTensor,
    uint8=torch.ByteTensor,
)
29 
30 
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    The type of ``batch[0]`` decides how the whole batch is collated:

    * ``torch.Tensor``: stacked along a new leading dimension; when
      ``_use_shared_memory`` is set, the result is written into a freshly
      allocated shared-memory buffer.
    * numpy ``ndarray`` / ``memmap``: each converted with ``torch.from_numpy``
      and then collated as tensors; arrays whose dtype holds strings or
      Python objects are rejected with ``TypeError``.
    * numpy scalars (``shape == ()``): turned into a 1-D tensor built by the
      ``numpy_type_map`` constructor for that dtype.
    * ``float``: a ``torch.float64`` tensor.
    * ints (``int_classes``): ``torch.tensor(batch)``.
    * strings: the ``batch`` object itself is returned unchanged.
    * mappings: a ``dict`` whose values collate each key separately.
    * namedtuples: the same namedtuple type, collated field by field.
    * other sequences: a ``list`` collating each position separately.

    Args:
        batch: an indexable, non-empty sequence of samples sharing the same
            structure (``batch[0]`` is read, so an empty batch raises
            ``IndexError``).

    Returns:
        The collated batch, shaped as described above.

    Raises:
        TypeError: if the element type (or numpy dtype) is not supported.
        RuntimeError: if sequence samples do not all have the same length.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # numpy.memmap is an ndarray subclass with its own __name__; without
        # accepting it here it would fall through to the TypeError below.
        if elem_type.__name__ in ('ndarray', 'memmap'):
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))

            return default_collate([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # zip() silently truncates to the shortest sample, which would drop
        # data from ragged batches; refuse them instead.
        elem_size = len(elem)
        if any(len(sample) != elem_size for sample in batch):
            raise RuntimeError('each element in list of batch should be of equal size')
        return [default_collate(samples) for samples in zip(*batch)]

    raise TypeError(error_msg_fmt.format(elem_type))