Caffe2 - C++ API
A deep learning, cross-platform ML framework
stream.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/data/samplers/base.h>
5 #include <torch/data/samplers/custom_batch_request.h>
6 #include <torch/types.h>
7 
8 #include <cstddef>
9 
// Forward declarations only: the sampler interface below takes these archive
// types by reference, so their full definitions are not required here.
namespace torch {
namespace serialize {
class OutputArchive;
class InputArchive;
} // namespace serialize
} // namespace torch
16 
17 namespace torch {
18 namespace data {
19 namespace samplers {
20 
23 struct TORCH_API BatchSize : public CustomBatchRequest {
24  explicit BatchSize(size_t size);
25  size_t size() const noexcept override;
26  operator size_t() const noexcept;
27  size_t size_;
28 };
29 
35 class StreamSampler : public Sampler<BatchSize> {
36  public:
39  TORCH_API explicit StreamSampler(size_t epoch_size);
40 
42  TORCH_API void reset(optional<size_t> new_size = nullopt) override;
43 
48  TORCH_API optional<BatchSize> next(size_t batch_size) override;
49 
51  TORCH_API void save(serialize::OutputArchive& archive) const override;
52 
54  TORCH_API void load(serialize::InputArchive& archive) override;
55 
56  private:
57  size_t examples_retrieved_so_far_ = 0;
58  size_t epoch_size_;
59 };
60 
61 } // namespace samplers
62 } // namespace data
63 } // namespace torch
A Sampler is an object that yields an index with which to access a dataset.
Definition: base.h:23
Definition: jit_type.h:17
A wrapper around a batch size value, which implements the CustomBatchRequest interface.
Definition: stream.h:23
A sampler for (potentially infinite) streams of data.
Definition: stream.h:35
A base class for custom index types.
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32