Caffe2 - C++ API
A deep learning, cross platform ML framework
sequencers.h
1 #pragma once
2 
3 #include <torch/types.h>
4 
5 #include <algorithm>
6 #include <cstddef>
7 #include <vector>
8 
9 namespace torch {
10 namespace data {
11 namespace detail {
12 namespace sequencers {
13 namespace detail {
14 template<typename Result>
15 bool buffer_contains_result(const std::vector<optional<Result>>& buffer) {
16  return std::any_of(
17  buffer.begin(), buffer.end(), [](const optional<Result>& result) {
18  return result.has_value();
19  });
20 }
21 } // namespace detail
22 
28 template <typename Result>
29 struct Sequencer {
30  using ResultProducer = std::function<optional<Result>()>;
31  virtual ~Sequencer() = default;
32  virtual optional<Result> next(ResultProducer next_result) = 0;
33 };
34 
37 template <typename Result>
38 struct NoSequencer final : public Sequencer<Result> {
39  using typename Sequencer<Result>::ResultProducer;
40  optional<Result> next(ResultProducer next_result) override {
41  return next_result();
42  }
43 };
44 
62 template <typename Result>
63 struct OrderedSequencer : public Sequencer<Result> {
64  using typename Sequencer<Result>::ResultProducer;
65 
68  explicit OrderedSequencer(size_t max_jobs) : buffer_(max_jobs) {}
69 
71  optional<Result> next(ResultProducer next_result) override {
72  // If we already have the result for the next sqn, return it.
73  if (auto& maybe_result = buffer(next_sequence_number_)) {
74  auto result = std::move(*maybe_result);
75  buffer(next_sequence_number_++).reset();
76  return result;
77  }
78  // Otherwise wait for the next result.
79  while (true) {
80  auto result = next_result();
81  if (!result) {
82  AT_ASSERT(!detail::buffer_contains_result(buffer_));
83  break;
84  }
85  // If it was not nullopt and the sequence numbers match, return it
86  // directly and bump the sequence number.
87  if (result->sequence_number == next_sequence_number_) {
88  ++next_sequence_number_;
89  return result;
90  }
91  // Stash the result for later.
92  AT_ASSERT(!buffer(result->sequence_number).has_value());
93  buffer(result->sequence_number) = std::move(result);
94  }
95  // The result was an empty optional, so we are done with this epoch.
96  return nullopt;
97  }
98 
100  optional<Result>& buffer(size_t index) {
101  return buffer_.at(index % buffer_.size());
102  }
103 
105  size_t next_sequence_number_ = 0;
106 
108  std::vector<optional<Result>> buffer_;
109 };
110 } // namespace sequencers
111 } // namespace detail
112 } // namespace data
113 } // namespace torch
A Sequencer accepts a function that yields the next result of a DataLoader and then has the opportunity to influence the order in which these results are returned.
Definition: sequencers.h:29
A Sequencer that does not enforce any ordering.
Definition: sequencers.h:38
optional< Result > next(ResultProducer next_result) override
Buffers results until the next one in the expected order is received.
Definition: sequencers.h:71
OrderedSequencer(size_t max_jobs)
Constructs the OrderedSequencer with the maximum number of results it will ever hold at one point in time.
Definition: sequencers.h:68
A Sequencer that buffers results and returns them in order of their sequence number.
Definition: sequencers.h:63
optional< Result > & buffer(size_t index)
Accesses the buffer at the index modulo the buffer size.
Definition: sequencers.h:100
Definition: jit_type.h:17
std::vector< optional< Result > > buffer_
A fixed-size buffer (after construction).
Definition: sequencers.h:108