Caffe2 - C++ API
A deep learning, cross platform ML framework
iterator.h
1 #pragma once
2 
3 #include <torch/csrc/utils/variadic.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 
8 #include <functional>
9 #include <iterator>
10 #include <memory>
11 #include <type_traits>
12 #include <utility>
13 
14 namespace torch {
15 namespace data {
16 namespace detail {
17 // For increased safety and more separated logic, this implementation of
18 // `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A
19 // `ValidIterator` yields new batches until the `DataLoader` is exhausted.
20 // While the `DataLoader` is not exhausted, `ValidIterator`s compare equal if
21 // and only if they are the same object. Once a `ValidIterator` is exhausted,
22 // it compares equal to the `SentinelIterator`, but not before. Half of the
23 // code here implements double dispatch for the comparison, which C++ lacks.
24 
template <typename Batch>
struct ValidIterator;

template <typename Batch>
struct SentinelIterator;

/// Base class for the `ValidIterator` and `SentinelIterator`.
///
/// Equality between iterators is resolved by double dispatch. The overload
/// taking an `IteratorImpl` is the entry point. Implementations are expected
/// to forward it to the argument's typed overload, passing `*this`, so that
/// the final comparison runs with both concrete types known.
template <typename Batch>
struct IteratorImpl {
  virtual ~IteratorImpl() = default;

  /// Advances to the next batch.
  virtual void next() = 0;

  /// Returns a reference to the current batch.
  virtual Batch& get() = 0;

  /// Double-dispatch entry point: compares against an iterator whose
  /// concrete type is not known statically.
  virtual bool operator==(const IteratorImpl& other) const = 0;

  /// Comparison against a `ValidIterator`.
  virtual bool operator==(const ValidIterator<Batch>& other) const = 0;

  /// Comparison against a `SentinelIterator`.
  virtual bool operator==(const SentinelIterator<Batch>& other) const = 0;
};
41 
42 template <typename Batch>
43 struct ValidIterator : public IteratorImpl<Batch> {
44  using BatchProducer = std::function<optional<Batch>()>;
45 
46  explicit ValidIterator(BatchProducer next_batch)
47  : next_batch_(std::move(next_batch)) {}
48 
50  void next() override {
51  // If we didn't get the very first batch yet, get it now.
52  lazy_initialize();
53  AT_CHECK(
54  batch_.has_value(), "Attempted to increment iterator past the end");
55  // Increment to the next batch.
56  batch_ = next_batch_();
57  }
58 
62  Batch& get() override {
63  // If we didn't get the very first batch yet, get it now.
64  lazy_initialize();
65  AT_CHECK(
66  batch_.has_value(),
67  "Attempted to dereference iterator that was past the end");
68  return batch_.value();
69  }
70 
72  bool operator==(const IteratorImpl<Batch>& other) const override {
73  return other == *this;
74  }
75 
78  bool operator==(const SentinelIterator<Batch>& /* unused */) const override {
79  lazy_initialize();
80  return !batch_;
81  }
82 
84  bool operator==(const ValidIterator<Batch>& other) const override {
85  return &other == this;
86  }
87 
89  void lazy_initialize() const {
90  if (!initialized_) {
91  batch_ = next_batch_();
92  initialized_ = true;
93  }
94  }
95 
96  BatchProducer next_batch_;
97  mutable optional<Batch> batch_;
98  mutable bool initialized_ = false;
99 };
100 
101 template <typename Batch>
102 struct SentinelIterator : public IteratorImpl<Batch> {
103  void next() override {
104  AT_ERROR(
105  "Incrementing the DataLoader's past-the-end iterator is not allowed");
106  }
107 
108  Batch& get() override {
109  AT_ERROR(
110  "Dereferencing the DataLoader's past-the-end iterator is not allowed");
111  }
112 
114  bool operator==(const IteratorImpl<Batch>& other) const override {
115  return other == *this;
116  }
117 
120  bool operator==(const ValidIterator<Batch>& other) const override {
121  return other == *this;
122  }
123 
125  bool operator==(const SentinelIterator<Batch>& other) const override {
126  return true;
127  }
128 };
129 } // namespace detail
130 
131 template <typename Batch>
132 class Iterator {
133  public:
134  // Type aliases to make the class recognized as a proper iterator.
135  using difference_type = std::ptrdiff_t;
136  using value_type = Batch;
137  using pointer = Batch*;
138  using reference = Batch&;
139  using iterator_category = std::input_iterator_tag;
140 
141  explicit Iterator(std::unique_ptr<detail::IteratorImpl<Batch>> impl)
142  : impl_(std::move(impl)) {}
143 
147  impl_->next();
148  return *this;
149  }
150 
153  Batch& operator*() {
154  return impl_->get();
155  }
156 
159  Batch* operator->() {
160  return &impl_->get();
161  }
162 
164  bool operator==(const Iterator& other) const {
165  return *impl_ == *other.impl_;
166  }
167 
169  bool operator!=(const Iterator& other) const {
170  return !(*this == other);
171  }
172 
173  private:
175  std::shared_ptr<detail::IteratorImpl<Batch>> impl_;
176 };
177 } // namespace data
178 } // namespace torch
bool operator==(const IteratorImpl< Batch > &other) const override
Does double dispatch.
Definition: iterator.h:114
bool operator==(const ValidIterator< Batch > &other) const override
Returns true if the memory address of other equals that of this.
Definition: iterator.h:84
Base class for the ValidIterator and SentinelIterator
Definition: iterator.h:33
Batch * operator->()
Returns a pointer to the current batch.
Definition: iterator.h:159
bool operator==(const Iterator &other) const
Compares two iterators for equality.
Definition: iterator.h:164
bool operator==(const SentinelIterator< Batch > &) const override
A ValidIterator is equal to the SentinelIterator iff it has been exhausted, i.e. no further batch is available.
Definition: iterator.h:78
bool operator==(const ValidIterator< Batch > &other) const override
Calls the comparison operator between ValidIterator and SentinelIterator.
Definition: iterator.h:120
Iterator & operator++()
Increments the iterator.
Definition: iterator.h:146
bool operator==(const SentinelIterator< Batch > &other) const override
Sentinel iterators always compare equal.
Definition: iterator.h:125
Batch & operator*()
Returns the current batch.
Definition: iterator.h:153
Definition: jit_type.h:17
bool operator==(const IteratorImpl< Batch > &other) const override
Does double dispatch.
Definition: iterator.h:72
bool operator!=(const Iterator &other) const
Compares two iterators for inequality.
Definition: iterator.h:169
void next() override
Fetches the next batch.
Definition: iterator.h:50
void lazy_initialize() const
Gets the very first batch if it has not yet been fetched.
Definition: iterator.h:89