Caffe2 - C++ API
A deep learning, cross-platform ML framework
distributed.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/data/samplers/base.h>
5 
6 #include <cstddef>
7 #include <vector>
8 
// Forward declarations of the archive types. This header only mentions them
// by reference (in the `save()` / `load()` signatures further down), so the
// full definitions are not needed here.
9 namespace torch {
10 namespace serialize {
11 class OutputArchive;
12 class InputArchive;
13 } // namespace serialize
14 } // namespace torch
15 
16 namespace torch {
17 namespace data {
18 namespace samplers {
19 
25 template <typename BatchRequest = std::vector<size_t>>
26 class DistributedSampler : public Sampler<BatchRequest> {
27  public:
28  TORCH_API DistributedSampler(
29  size_t size,
30  size_t num_replicas = 1,
31  size_t rank = 0,
32  bool allow_duplicates = true)
33  : size_(size),
34  num_replicas_(num_replicas),
35  rank_(rank),
36  epoch_(0),
37  allow_duplicates_(allow_duplicates) {}
38 
41  void set_epoch(size_t epoch) {
42  epoch_ = epoch;
43  }
44 
45  size_t epoch() const {
46  return epoch_;
47  }
48 
49  protected:
50  size_t local_sample_count() {
51  if (allow_duplicates_) {
52  return (size_ + num_replicas_ - 1) / num_replicas_;
53  } else {
54  return size_ / num_replicas_;
55  }
56  }
57 
58  size_t size_;
59  size_t num_replicas_;
60  size_t rank_;
61  size_t epoch_;
62  bool allow_duplicates_;
63 };
64 
/// NOTE(review): the class-head line (listing line 67) is missing from this
/// extract. The constructor below names the class `DistributedRandomSampler`;
/// its base class is not visible here -- presumably `DistributedSampler<>`,
/// given the identical constructor parameters and the `override` specifiers --
/// confirm against the original header.
68  public:
 /// Constructor. Takes the same parameters, with the same defaults, as the
 /// `DistributedSampler` constructor. Declared only; defined elsewhere.
69  TORCH_API DistributedRandomSampler(
70  size_t size,
71  size_t num_replicas = 1,
72  size_t rank = 0,
73  bool allow_duplicates = true);
74 
 /// Overrides a base-class virtual. Accepts an optional new size, which
 /// defaults to `nullopt`. Presumably restarts the enumeration (and reshuffles)
 /// -- confirm in the implementation file.
76  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
77 
 /// Overrides a base-class virtual. Returns an optional vector of indices for
 /// the requested `batch_size`; presumably `nullopt` once the sampler is
 /// exhausted -- confirm in the implementation file.
79  TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
80 
 /// Overrides a base-class virtual. Takes the output `archive` by reference
 /// and does not modify the sampler (`const`).
82  TORCH_API void save(serialize::OutputArchive& archive) const override;
83 
 /// Overrides a base-class virtual. Takes the input `archive` by reference.
85  TORCH_API void load(serialize::InputArchive& archive) override;
86 
 /// Non-throwing const accessor returning a `size_t`; presumably the value of
 /// `sample_index_` -- confirm in the implementation file.
88  TORCH_API size_t index() const noexcept;
89 
90  private:
 /// Private helper, declared only; presumably (re)builds `all_indices_`.
91  void populate_indices();
92 
 // NOTE(review): exact roles of the members below are not shown in this
 // header; judging by the names, `begin_index_` / `end_index_` bound this
 // replica's range, `sample_index_` is the cursor, and `all_indices_` holds
 // the index order -- confirm.
93  size_t begin_index_;
94  size_t end_index_;
95  size_t sample_index_;
96  std::vector<size_t> all_indices_;
97 };
98 
/// NOTE(review): two lines of this class are missing from the extract: the
/// class head (listing line 100) and the first line of the constructor
/// declaration (listing line 102, which carries the constructor name). The
/// class name therefore cannot be read from this listing; presumably this is
/// the sequential counterpart of the random sampler above
/// (`DistributedSequentialSampler` upstream) -- confirm against the original
/// header before relying on this text.
101  public:
 /// Parameter list of the constructor (its name line is missing, see note
 /// above). Same parameters and defaults as the `DistributedSampler`
 /// constructor.
103  size_t size,
104  size_t num_replicas = 1,
105  size_t rank = 0,
106  bool allow_duplicates = true);
107 
 /// Overrides a base-class virtual. Accepts an optional new size, which
 /// defaults to `nullopt`. Presumably restarts the enumeration -- confirm in
 /// the implementation file.
109  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
110 
 /// Overrides a base-class virtual. Returns an optional vector of indices for
 /// the requested `batch_size`; presumably `nullopt` once the sampler is
 /// exhausted -- confirm in the implementation file.
112  TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
113 
 /// Overrides a base-class virtual. Takes the output `archive` by reference
 /// and does not modify the sampler (`const`).
115  TORCH_API void save(serialize::OutputArchive& archive) const override;
116 
 /// Overrides a base-class virtual. Takes the input `archive` by reference.
118  TORCH_API void load(serialize::InputArchive& archive) override;
119 
 /// Non-throwing const accessor returning a `size_t`; presumably the value of
 /// `sample_index_` -- confirm in the implementation file.
121  TORCH_API size_t index() const noexcept;
122 
123  private:
 /// Private helper, declared only; presumably (re)builds `all_indices_`.
124  void populate_indices();
125 
 // NOTE(review): member roles are inferred from names only (range bounds,
 // cursor, index storage) -- confirm in the implementation file.
126  size_t begin_index_;
127  size_t end_index_;
128  size_t sample_index_;
129  std::vector<size_t> all_indices_;
130 };
131 
132 } // namespace samplers
133 } // namespace data
134 } // namespace torch
void set_epoch(size_t epoch)
Set the epoch for the current enumeration.
Definition: distributed.h:41
A Sampler is an object that yields an index with which to access a dataset.
Definition: base.h:23
A Sampler that selects a subset of indices to sample from and defines a sampling behavior.
Definition: distributed.h:26
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32