Caffe2 - C++ API
A deep learning, cross-platform ML framework
random.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/data/samplers/base.h>
5 #include <torch/types.h>
6 
7 #include <cstddef>
8 #include <vector>
9 
// Forward declarations only: `RandomSampler::save` and `RandomSampler::load`
// take the archive types by reference, so this header does not need their
// full definitions.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
16 
17 namespace torch {
18 namespace data {
19 namespace samplers {
20 
22 class RandomSampler : public Sampler<> {
23  public:
29  TORCH_API explicit RandomSampler(
30  int64_t size,
31  Dtype index_dtype = torch::kInt64);
32 
34  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
35 
37  TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
38 
40  TORCH_API void save(serialize::OutputArchive& archive) const override;
41 
43  TORCH_API void load(serialize::InputArchive& archive) override;
44 
46  TORCH_API size_t index() const noexcept;
47 
48  private:
49  Tensor indices_;
50  int64_t index_ = 0;
51 };
52 } // namespace samplers
53 } // namespace data
54 } // namespace torch
RandomSampler: A Sampler that returns random indices.
Definition: random.h:22
Sampler: An object that yields an index with which to access a dataset.
Definition: base.h:23
Definition: jit_type.h:17
InputArchive: A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32