Caffe2 - Python API
A deep learning, cross platform ML framework
sampler.py
import torch
from torch._six import int_classes as _int_classes


class Sampler(object):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

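As a brief illustration of the contract above (not part of the original file), a concrete sampler only needs to implement __iter__ and __len__; the hypothetical EveryOtherSampler below yields every other index of its data_source.

class EveryOtherSampler(Sampler):
    """Illustrative subclass: yields indices 0, 2, 4, ... of data_source."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Indices must be valid positions into data_source.
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        # Number of indices that __iter__ will yield.
        return (len(self.data_source) + 1) // 2
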
class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

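For example (illustrative, not taken from the original listing), SequentialSampler simply walks the index range of whatever it wraps:

>>> list(SequentialSampler(range(5)))
[0, 1, 2, 3, 4]
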
class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify ``num_samples`` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        num_samples (int): number of samples to draw, default=len(dataset)
        replacement (bool): samples are drawn with replacement if ``True``, default=False
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if self._num_samples is not None and replacement is False:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples

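A quick illustrative use (the indices shown are one possible draw, not a fixed output): without replacement the sampler yields a permutation of all indices, while with replacement it may repeat indices and honors ``num_samples``.

>>> list(RandomSampler(range(5)))                                   # a permutation of 0..4
[2, 0, 4, 3, 1]
>>> list(RandomSampler(range(5), replacement=True, num_samples=3))  # indices may repeat
[4, 4, 1]
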
class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)

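One common use, sketched here as an assumption rather than taken from the file, is shuffling within a fixed train/validation split of a dataset of, say, length 10:

indices = list(range(10))
train_sampler = SubsetRandomSampler(indices[:8])  # shuffles within the first 8 indices
val_sampler = SubsetRandomSampler(indices[8:])    # shuffles within the last 2 indices
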
class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [0, 0, 0, 1, 0]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """

    def __init__(self, weights, num_samples, replacement=True):
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples

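A typical pattern, shown here as an illustrative sketch, is oversampling a minority class by giving each element a weight inversely proportional to its class frequency; ``labels`` below is a hypothetical list of integer class labels, not something defined in this file.

import collections

labels = [0, 0, 0, 0, 1]                     # hypothetical, heavily imbalanced labels
counts = collections.Counter(labels)         # {0: 4, 1: 1}
weights = [1.0 / counts[y] for y in labels]  # rarer classes get larger weights
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
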
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
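
As a closing sketch (assuming the standard torch.utils.data.DataLoader API and a hypothetical dataset object ``ds``), a BatchSampler is usually handed to a DataLoader through its ``batch_sampler`` argument; in that case ``batch_size``, ``shuffle``, ``sampler``, and ``drop_last`` are left at their defaults, since the batch sampler already decides the batching.

from torch.utils.data import DataLoader

batch_sampler = BatchSampler(RandomSampler(ds), batch_size=4, drop_last=True)
loader = DataLoader(ds, batch_sampler=batch_sampler)
for batch in loader:
    # each iteration collates the 4 samples indexed by one list yielded above
    pass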