Caffe2 - C++ API
A deep learning, cross-platform ML framework
distributed.cpp
#include <torch/data/samplers/distributed.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

namespace torch {
namespace data {
namespace samplers {

DistributedRandomSampler::DistributedRandomSampler(
    size_t size,
    size_t num_replicas,
    size_t rank,
    bool allow_duplicates)
    : DistributedSampler(size, num_replicas, rank, allow_duplicates),
      begin_index_(0),
      end_index_(0),
      sample_index_(0) {
  // Shuffle once at construction time (epoch_ starts at zero).
  reset(size_);
}

optional<std::vector<size_t>> DistributedRandomSampler::next(
    size_t batch_size) {
  if (sample_index_ == end_index_) {
    return nullopt;
  }

  size_t end = sample_index_ + batch_size;
  if (end > end_index_) {
    end = end_index_;
  }

  auto iter = all_indices_.begin();
  std::vector<size_t> res(iter + sample_index_, iter + end);
  sample_index_ = end;
  return res;
}

void DistributedRandomSampler::reset(optional<size_t> new_size) {
  size_ = new_size.value_or(size_);
  populate_indices();

  std::mt19937 rand(epoch_);
  std::shuffle(all_indices_.begin(), all_indices_.end(), rand);
  sample_index_ = begin_index_;
}

void DistributedRandomSampler::populate_indices() {
  size_t num_local_samples = local_sample_count();
  // With more than one replica, size the index list so that every replica
  // receives exactly num_local_samples indices.
  size_t sample_count =
      num_replicas_ == 1 ? size_ : num_local_samples * num_replicas_;
  all_indices_.resize(sample_count);
  std::iota(std::begin(all_indices_), std::end(all_indices_), 0);
  for (size_t i = size_; i < sample_count; ++i) {
    // We may have added duplicate samples so that all replicas have the same
    // number of samples; wrap the extra indices back to the start.
    all_indices_[i] = i - size_;
  }
  begin_index_ = rank_ * num_local_samples;
  end_index_ = begin_index_ + num_local_samples;
  sample_index_ = begin_index_;
}

void DistributedRandomSampler::save(serialize::OutputArchive& archive) const {
  archive.write(
      "sample_index_",
      torch::tensor(static_cast<int64_t>(sample_index_)),
      /*is_buffer=*/true);
  archive.write(
      "epoch_",
      torch::tensor(static_cast<int64_t>(epoch_)),
      /*is_buffer=*/true);
}

void DistributedRandomSampler::load(serialize::InputArchive& archive) {
  auto tensor = torch::empty(1, torch::kInt64);
  archive.read("epoch_", tensor, /*is_buffer=*/true);
  epoch_ = tensor.item<int64_t>();
  // Call reset() after loading epoch_ to repopulate (and reshuffle) the
  // indices before restoring sample_index_.
  reset(size_);

  tensor = torch::empty(1, torch::kInt64);
  archive.read("sample_index_", tensor, /*is_buffer=*/true);
  sample_index_ = tensor.item<int64_t>();
}

size_t DistributedRandomSampler::index() const noexcept {
  return sample_index_;
}

DistributedSequentialSampler::DistributedSequentialSampler(
    size_t size,
    size_t num_replicas,
    size_t rank,
    bool allow_duplicates)
    : DistributedSampler(size, num_replicas, rank, allow_duplicates),
      begin_index_(0),
      end_index_(0),
      sample_index_(0) {
  populate_indices();
}

optional<std::vector<size_t>> DistributedSequentialSampler::next(
    size_t batch_size) {
  if (sample_index_ == end_index_) {
    return nullopt;
  }

  size_t end = sample_index_ + batch_size;
  if (end > end_index_) {
    end = end_index_;
  }

  std::vector<size_t> res(end - sample_index_);
  std::iota(std::begin(res), std::end(res), sample_index_);
  if (end >= size_) {
    // Indices past size_ are duplicated padding samples; wrap them back into
    // the valid range [0, size_).
    for (size_t& index : res) {
      index = index % size_;
    }
  }
  sample_index_ = end;
  return res;
}

void DistributedSequentialSampler::reset(optional<size_t> new_size) {
  size_t size = new_size.value_or(size_);
  if (size != size_) {
    size_ = size;
    populate_indices();
  } else {
    sample_index_ = begin_index_;
  }
}

void DistributedSequentialSampler::populate_indices() {
  begin_index_ = rank_ * local_sample_count();
  end_index_ = begin_index_ + local_sample_count();
  sample_index_ = begin_index_;
}

void DistributedSequentialSampler::save(
    serialize::OutputArchive& archive) const {
  archive.write(
      "sample_index_",
      torch::tensor(static_cast<int64_t>(sample_index_)),
      /*is_buffer=*/true);
}

void DistributedSequentialSampler::load(serialize::InputArchive& archive) {
  auto tensor = torch::empty(1, torch::kInt64);
  archive.read("sample_index_", tensor, /*is_buffer=*/true);
  sample_index_ = tensor.item<int64_t>();
}

size_t DistributedSequentialSampler::index() const noexcept {
  return sample_index_;
}

} // namespace samplers
} // namespace data
} // namespace torch
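
Example: iterating a DistributedRandomSampler

A minimal usage sketch, assuming the declarations in torch/data/samplers/distributed.h shown above; in particular, it assumes the DistributedSampler base exposes a set_epoch() accessor for the epoch_ member that reset() uses to seed the shuffle. Rank 0 of two replicas draws its shard of a 10-sample dataset in batches of at most four indices, reshuffling between epochs.

#include <torch/data/samplers/distributed.h>

#include <cstddef>
#include <iostream>

int main() {
  // Partition a 10-element dataset across 2 replicas; this process is rank 0.
  torch::data::samplers::DistributedRandomSampler sampler(
      /*size=*/10, /*num_replicas=*/2, /*rank=*/0, /*allow_duplicates=*/true);

  for (size_t epoch = 0; epoch < 2; ++epoch) {
    // Assumption: set_epoch() is provided by the DistributedSampler base; it
    // changes the seed used by reset() when reshuffling all_indices_.
    sampler.set_epoch(epoch);
    // Reshuffle and rewind to the beginning of this rank's shard.
    sampler.reset();

    // Draw local indices until the shard is exhausted (next() returns nullopt).
    while (auto batch = sampler.next(/*batch_size=*/4)) {
      for (size_t index : *batch) {
        std::cout << index << ' ';
      }
      std::cout << '\n';
    }
  }
  return 0;
}

Each replica constructs the sampler with the same size and num_replicas but its own rank, so the ranks cover disjoint slices of the (possibly padded) index list; with allow_duplicates set, a few indices are repeated so that every rank yields the same number of samples.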