1 #include <torch/data/samplers/distributed.h> 2 #include <torch/serialize/archive.h> 3 #include <torch/types.h> 14 DistributedRandomSampler::DistributedRandomSampler(
18 bool allow_duplicates)
19 : DistributedSampler(size, num_replicas, rank, allow_duplicates),
29 if (sample_index_ == end_index_) {
33 size_t end = sample_index_ + batch_size;
34 if (end > end_index_) {
38 auto iter = all_indices_.begin();
39 std::vector<size_t> res(iter + sample_index_, iter + end);
45 size_ = new_size.value_or(size_);
48 std::mt19937 rand(epoch_);
49 std::shuffle(all_indices_.begin(), all_indices_.end(), rand);
50 sample_index_ = begin_index_;
53 void DistributedRandomSampler::populate_indices() {
54 size_t num_local_samples = local_sample_count();
56 num_replicas_ == 1 ? size_ : num_local_samples * num_replicas_;
57 all_indices_.resize(sample_count);
58 std::iota(std::begin(all_indices_), std::end(all_indices_), 0);
59 for (
size_t i = size_; i < sample_count; ++i) {
62 all_indices_[i] = i - size_;
64 begin_index_ = rank_ * num_local_samples;
65 end_index_ = begin_index_ + num_local_samples;
66 sample_index_ = begin_index_;
72 torch::tensor(static_cast<int64_t>(sample_index_)),
76 torch::tensor(static_cast<int64_t>(epoch_)),
81 auto tensor = torch::empty(1, torch::kInt64);
82 archive.
read(
"epoch_", tensor,
true);
83 epoch_ = tensor.item<int64_t>();
87 tensor = torch::empty(1, torch::kInt64);
88 archive.
read(
"sample_index_", tensor,
true);
89 sample_index_ = tensor.item<int64_t>();
92 size_t DistributedRandomSampler::index() const noexcept {
96 DistributedSequentialSampler::DistributedSequentialSampler(
100 bool allow_duplicates)
110 if (sample_index_ == end_index_) {
114 size_t end = sample_index_ + batch_size;
115 if (end > end_index_) {
119 std::vector<size_t> res(end - sample_index_);
120 std::iota(std::begin(res), std::end(res), sample_index_);
122 for (
size_t& index : res) {
123 index = index % size_;
131 size_t size = new_size.value_or(size_);
136 sample_index_ = begin_index_;
140 void DistributedSequentialSampler::populate_indices() {
141 begin_index_ = rank_ * local_sample_count();
142 end_index_ = begin_index_ + local_sample_count();
143 sample_index_ = begin_index_;
146 void DistributedSequentialSampler::save(
150 torch::tensor(static_cast<int64_t>(sample_index_)),
155 auto tensor = torch::empty(1, torch::kInt64);
156 archive.
read(
"sample_index_", tensor,
true);
157 sample_index_ = tensor.item<int64_t>();
160 size_t DistributedSequentialSampler::index() const noexcept {
161 return sample_index_;
optional< size_t > size() const override
Returns the size of the dataset.
A Sampler that selects a subset of indices to sample from and defines a sampling behavior.
void write(const std::string &key, const Tensor &tensor, bool is_buffer=false)
Writes a (key, tensor) pair to the OutputArchive, and marks it as being or not being a buffer (non-differentiable tensor).