2 from torch._six import int_classes 
as _int_classes
     6     r"""Base class for all Samplers.     8     Every Sampler subclass has to provide an __iter__ method, providing a way     9     to iterate over indices of dataset elements, and a __len__ method that    10     returns the length of the returned iterators.    13     def __init__(self, data_source):
    17         raise NotImplementedError
    20         raise NotImplementedError
    24     r"""Samples elements sequentially, always in the same order.    27         data_source (Dataset): dataset to sample from    30     def __init__(self, data_source):
    41     r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.    42     If with replacement, then user can specify ``num_samples`` to draw.    45         data_source (Dataset): dataset to sample from    46         num_samples (int): number of samples to draw, default=len(dataset)    47         replacement (bool): samples are drawn with replacement if ``True``, default=False    50     def __init__(self, data_source, replacement=False, num_samples=None):
    55         if self.
_num_samples is not None and replacement 
is False:
    56             raise ValueError(
"With replacement=False, num_samples should not be specified, "    57                              "since a random permute will be performed.")
    60             raise ValueError(
"num_samples should be a positive integer "    61                              "value, but got num_samples={}".format(self.
num_samples))
    63             raise ValueError(
"replacement should be a boolean value, but got "    67     def num_samples(self):
    76             return iter(torch.randint(high=n, size=(self.
num_samples,), dtype=torch.int64).tolist())
    77         return iter(torch.randperm(n).tolist())
    84     r"""Samples elements randomly from a given list of indices, without replacement.    87         indices (sequence): a sequence of indices    90     def __init__(self, indices):
    94         return (self.
indices[i] 
for i 
in torch.randperm(len(self.
indices)))
   101     r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).   104         weights (sequence)   : a sequence of weights, not necessarily summing up to one   105         num_samples (int): number of samples to draw   106         replacement (bool): if ``True``, samples are drawn with replacement.   107             If not, they are drawn without replacement, which means that when a   108             sample index is drawn for a row, it cannot be drawn again for that row.   111         >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))   113         >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))   117     def __init__(self, weights, num_samples, replacement=True):
   118         if not isinstance(num_samples, _int_classes) 
or isinstance(num_samples, bool) 
or \
   120             raise ValueError(
"num_samples should be a positive integer "   121                              "value, but got num_samples={}".format(num_samples))
   122         if not isinstance(replacement, bool):
   123             raise ValueError(
"replacement should be a boolean value, but got "   124                              "replacement={}".format(replacement))
   125         self.
weights = torch.as_tensor(weights, dtype=torch.double)
   137     r"""Wraps another sampler to yield a mini-batch of indices.   140         sampler (Sampler): Base sampler.   141         batch_size (int): Size of mini-batch.   142         drop_last (bool): If ``True``, the sampler will drop the last batch if   143             its size would be less than ``batch_size``   146         >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))   147         [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]   148         >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))   149         [[0, 1, 2], [3, 4, 5], [6, 7, 8]]   152     def __init__(self, sampler, batch_size, drop_last):
   153         if not isinstance(sampler, Sampler):
   154             raise ValueError(
"sampler should be an instance of "   155                              "torch.utils.data.Sampler, but got sampler={}"   157         if not isinstance(batch_size, _int_classes) 
or isinstance(batch_size, bool) 
or \
   159             raise ValueError(
"batch_size should be a positive integer value, "   160                              "but got batch_size={}".format(batch_size))
   161         if not isinstance(drop_last, bool):
   162             raise ValueError(
"drop_last should be a boolean value, but got "   163                              "drop_last={}".format(drop_last))
   175         if len(batch) > 0 
and not self.
drop_last: