3 #include <torch/csrc/WindowsTorchApiMacro.h> 4 #include <torch/data/samplers/base.h> 25 template <
typename BatchRequest = std::vector<
size_t>>
30 size_t num_replicas = 1,
32 bool allow_duplicates =
true)
34 num_replicas_(num_replicas),
37 allow_duplicates_(allow_duplicates) {}
45 size_t epoch()
const {
50 size_t local_sample_count() {
51 if (allow_duplicates_) {
52 return (size_ + num_replicas_ - 1) / num_replicas_;
54 return size_ / num_replicas_;
62 bool allow_duplicates_;
71 size_t num_replicas = 1,
73 bool allow_duplicates =
true);
88 TORCH_API
size_t index()
const noexcept;
91 void populate_indices();
96 std::vector<size_t> all_indices_;
104 size_t num_replicas = 1,
106 bool allow_duplicates =
true);
121 TORCH_API
size_t index()
const noexcept;
124 void populate_indices();
128 size_t sample_index_;
129 std::vector<size_t> all_indices_;
void set_epoch(size_t epoch)
Set the epoch for the current enumeration.
A Sampler is an object that yields an index with which to access a dataset.
A Sampler that selects a subset of indices to sample from and defines a sampling behavior.
Select samples sequentially.