2 from torch._six import int_classes
as _int_classes
6 r"""Base class for all Samplers. 8 Every Sampler subclass has to provide an __iter__ method, providing a way 9 to iterate over indices of dataset elements, and a __len__ method that 10 returns the length of the returned iterators. 13 def __init__(self, data_source):
17 raise NotImplementedError
20 raise NotImplementedError
24 r"""Samples elements sequentially, always in the same order. 27 data_source (Dataset): dataset to sample from 30 def __init__(self, data_source):
41 r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. 42 If with replacement, then user can specify ``num_samples`` to draw. 45 data_source (Dataset): dataset to sample from 46 num_samples (int): number of samples to draw, default=len(dataset) 47 replacement (bool): samples are drawn with replacement if ``True``, default=False 50 def __init__(self, data_source, replacement=False, num_samples=None):
55 if self.
_num_samples is not None and replacement
is False:
56 raise ValueError(
"With replacement=False, num_samples should not be specified, " 57 "since a random permute will be performed.")
60 raise ValueError(
"num_samples should be a positive integer " 61 "value, but got num_samples={}".format(self.
num_samples))
63 raise ValueError(
"replacement should be a boolean value, but got " 67 def num_samples(self):
76 return iter(torch.randint(high=n, size=(self.
num_samples,), dtype=torch.int64).tolist())
77 return iter(torch.randperm(n).tolist())
84 r"""Samples elements randomly from a given list of indices, without replacement. 87 indices (sequence): a sequence of indices 90 def __init__(self, indices):
94 return (self.
indices[i]
for i
in torch.randperm(len(self.
indices)))
101 r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 104 weights (sequence) : a sequence of weights, not necessary summing up to one 105 num_samples (int): number of samples to draw 106 replacement (bool): if ``True``, samples are drawn with replacement. 107 If not, they are drawn without replacement, which means that when a 108 sample index is drawn for a row, it cannot be drawn again for that row. 111 >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) 113 >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) 117 def __init__(self, weights, num_samples, replacement=True):
118 if not isinstance(num_samples, _int_classes)
or isinstance(num_samples, bool)
or \
120 raise ValueError(
"num_samples should be a positive integer " 121 "value, but got num_samples={}".format(num_samples))
122 if not isinstance(replacement, bool):
123 raise ValueError(
"replacement should be a boolean value, but got " 124 "replacement={}".format(replacement))
125 self.
weights = torch.as_tensor(weights, dtype=torch.double)
137 r"""Wraps another sampler to yield a mini-batch of indices. 140 sampler (Sampler): Base sampler. 141 batch_size (int): Size of mini-batch. 142 drop_last (bool): If ``True``, the sampler will drop the last batch if 143 its size would be less than ``batch_size`` 146 >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 147 [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 148 >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 149 [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 152 def __init__(self, sampler, batch_size, drop_last):
153 if not isinstance(sampler, Sampler):
154 raise ValueError(
"sampler should be an instance of " 155 "torch.utils.data.Sampler, but got sampler={}" 157 if not isinstance(batch_size, _int_classes)
or isinstance(batch_size, bool)
or \
159 raise ValueError(
"batch_size should be a positive integer value, " 160 "but got batch_size={}".format(batch_size))
161 if not isinstance(drop_last, bool):
162 raise ValueError(
"drop_last should be a boolean value, but got " 163 "drop_last={}".format(drop_last))
175 if len(batch) > 0
and not self.
drop_last: