5 from torch
import randperm
9 """An abstract class representing a Dataset. 11 All other datasets should subclass it. All subclasses should override 12 ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 supporting integer indexing in range from 0 to len(self) exclusive. 16 def __getitem__(self, index):
17 raise NotImplementedError
20 raise NotImplementedError
22 def __add__(self, other):
27 """Dataset wrapping tensors. 29 Each sample will be retrieved by indexing tensors along the first dimension. 32 *tensors (Tensor): tensors that have the same size of the first dimension. 35 def __init__(self, *tensors):
36 assert all(tensors[0].size(0) == tensor.size(0)
for tensor
in tensors)
def __getitem__(self, index):
    """Return the sample at ``index`` as a tuple.

    Each object held in ``self.tensors`` is indexed with the same
    ``index``; the results are collected in the order the tensors are
    stored and returned as one tuple with one entry per tensor.
    """
    sample = []
    for wrapped in self.tensors:
        sample.append(wrapped[index])
    return tuple(sample)
48 Dataset to concatenate multiple datasets. 49 Purpose: useful to assemble different existing datasets, possibly 50 large-scale datasets as the concatenation operation is done in an 54 datasets (sequence): List of datasets to be concatenated 66 def __init__(self, datasets):
67 super(ConcatDataset, self).__init__()
68 assert len(datasets) > 0,
'datasets should not be an empty iterable' 75 def __getitem__(self, idx):
78 raise ValueError(
"absolute value of index should not exceed dataset length")
85 return self.
datasets[dataset_idx][sample_idx]
def cummulative_sizes(self):
    """Emit a deprecation warning for the misspelled ``cummulative_sizes`` name.

    The warning tells callers that the attribute has been renamed to
    ``cumulative_sizes``.
    """
    # NOTE(review): no return statement is visible for this accessor in the
    # portion of the file shown; presumably it should hand back
    # ``self.cumulative_sizes`` -- confirm against the full file.
    notice = ("cummulative_sizes attribute is renamed to "
              "cumulative_sizes")
    # stacklevel=2 attributes the warning to the caller, not to this function.
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
96 Subset of a dataset at specified indices. 99 dataset (Dataset): The whole Dataset 100 indices (sequence): Indices in the whole set selected for subset 102 def __init__(self, dataset, indices):
106 def __getitem__(self, idx):
def random_split(dataset, lengths):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced

    Returns a list with one ``Subset`` of ``dataset`` per entry in
    ``lengths``. Raises ``ValueError`` when the lengths do not add up to
    ``len(dataset)``.
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    # One random permutation of all positions, carved into consecutive runs.
    shuffled = randperm(sum(lengths)).tolist()

    splits = []
    start = 0
    for length in lengths:
        stop = start + length
        splits.append(Subset(dataset, shuffled[start:stop]))
        start = stop
    return splits