Caffe2 - Python API
A deep learning, cross-platform ML framework
dataset.py
import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm


class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

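A concrete subclass only needs to implement ``__len__`` and ``__getitem__``. The ``SquaresDataset`` below is a hypothetical minimal sketch for illustration, not part of this module:

class SquaresDataset(Dataset):
    # Hypothetical example: sample i is simply i squared.
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index ** 2

ds = SquaresDataset(4)
print(len(ds))       # 4
print(ds[2])         # 4
print((ds + ds)[5])  # __add__ returns a ConcatDataset; global index 5 maps to ds[1] == 1
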
class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
        *tensors (Tensor): tensors that have the same size in the first dimension.
    """

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

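A short usage sketch for TensorDataset, assuming the definitions above; the tensor shapes are arbitrary example values:

import torch

features = torch.randn(10, 3)  # 10 samples, 3 features each
labels = torch.arange(10)      # 10 matching labels

ds = TensorDataset(features, labels)
x, y = ds[4]                   # both tensors are indexed along dim 0
assert x.shape == (3,) and int(y) == 4
assert len(ds) == 10
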
class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Useful for assembling different existing datasets, possibly large-scale
    datasets, as the concatenation operation is done on the fly.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        # Running total of dataset lengths: r[i] is the end offset of dataset i.
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # Locate the dataset containing idx, then convert the global index
        # into an index local to that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes

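To make the index arithmetic concrete, here is a small sketch (plain lists stand in for datasets, since only ``__len__`` and ``__getitem__`` are required): ``cumsum`` records the end offset of each dataset, and ``bisect_right`` picks the first dataset whose end offset is greater than the global index.

a = list("abc")
b = list("de")
cat = ConcatDataset([a, b])

print(cat.cumulative_sizes)  # [3, 5]: end offsets of the two datasets
print(len(cat))              # 5
print(cat[3])                # 'd': bisect_right([3, 5], 3) == 1, so sample_idx = 3 - 3 = 0
print(cat[-1])               # 'e': a negative index wraps to len(cat) + idx == 4
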
class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

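A brief sketch of Subset; the plain list stands in for any indexable dataset:

base = list(range(100, 110))
sub = Subset(base, [0, 5, 9])

assert len(sub) == 3
assert sub[1] == base[5] == 105  # subset index 1 maps to base index 5
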
def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
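
Finally, a usage sketch for random_split, assuming the definitions above; the 80/20 split sizes are arbitrary example values:

import torch

data = TensorDataset(torch.randn(100, 3), torch.arange(100))
train, val = random_split(data, [80, 20])

assert len(train) == 80 and len(val) == 20
# Both Subsets slice the same permutation, so their indices never overlap.
assert set(train.indices).isdisjoint(val.indices)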