Caffe2 - C++ API
A deep learning, cross-platform ML framework
base.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/types.h>
5 
6 #include <cstddef>
7 #include <vector>
8 #include <mutex>
9 
// Forward declarations of the serialization archive types used by
// Sampler::save() and Sampler::load(). Those functions take the archives only
// by reference, so the complete class definitions are not needed here.
namespace torch {
namespace serialize {
class InputArchive;
class OutputArchive;
} // namespace serialize
} // namespace torch
16 
17 namespace torch {
18 namespace data {
19 namespace samplers {
22 template <typename BatchRequest = std::vector<size_t>>
23 class Sampler {
24  public:
25  using BatchRequestType = BatchRequest;
26 
27  virtual ~Sampler() = default;
28 
32  TORCH_API virtual void reset(optional<size_t> new_size) = 0;
33 
36  TORCH_API virtual optional<BatchRequest> next(size_t batch_size) = 0;
37 
39  TORCH_API virtual void save(serialize::OutputArchive& archive) const = 0;
40 
42  TORCH_API virtual void load(serialize::InputArchive& archive) = 0;
43 };
44 
45 } // namespace samplers
46 } // namespace data
47 } // namespace torch
A Sampler is an object that yields an index with which to access a dataset.
Definition: base.h:23
Definition: jit_type.h:17
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32