Caffe2 - Python API
A deep learning, cross platform ML framework
Packages
Classes
Files
C++ API
Python API
GitHub
File List
torch
utils
data
distributed.py
1
import
math
2
import
torch
3
from
.
import
Sampler
4
import
torch.distributed
as
dist
5
6
7
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    Every replica draws the same epoch-seeded permutation of the dataset
    indices and then keeps every ``num_replicas``-th index starting at its
    own ``rank``. When the dataset size is not divisible by ``num_replicas``
    the permutation is padded by repeating indices from its beginning, so
    all replicas yield exactly ``num_samples`` indices.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training. Defaults to ``dist.get_world_size()``.
        rank (optional): Rank of the current process within num_replicas.
            Defaults to ``dist.get_rank()``.

    Raises:
        RuntimeError: if ``num_replicas`` or ``rank`` is omitted and the
            distributed package is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        # The distributed package is only needed when a value has to be
        # looked up from the current process group.
        if num_replicas is None or rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Number of indices each replica yields per epoch (rounded up).
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        # Length of the padded index list shared by all replicas.
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        """Return an iterator over this replica's indices for the current epoch."""
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()

        # add extra samples to make it evenly divisible; the permutation is
        # repeated as often as needed because the required padding can exceed
        # the dataset length when the dataset is much smaller than
        # num_replicas (a single prefix slice would come up short).
        padding = self.total_size - len(indices)
        if padding > 0 and indices:
            repeats = -(-padding // len(indices))  # integer ceil division
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size

        # subsample: every num_replicas-th index, offset by this rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of indices this replica yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle in the next ``__iter__``."""
        self.epoch = epoch
torch.distributed
Definition:
__init__.py:1
torch.utils.data.distributed.DistributedSampler.num_samples
num_samples
Definition:
distributed.py:38
torch.utils.data.distributed.DistributedSampler.total_size
total_size
Definition:
distributed.py:39
torch.utils.data.distributed.DistributedSampler.epoch
epoch
Definition:
distributed.py:37
torch.utils.data.distributed.DistributedSampler.dataset
dataset
Definition:
distributed.py:34
torch.utils.data.distributed.DistributedSampler.num_replicas
num_replicas
Definition:
distributed.py:35
torch.utils.data.distributed.DistributedSampler.rank
rank
Definition:
distributed.py:36
torch.utils.data.distributed.DistributedSampler
Definition:
distributed.py:7
Generated on Thu Mar 21 2019 13:06:37 for Caffe2 - Python API by
1.8.11