Caffe2 - C++ API
A deep learning, cross-platform ML framework
base.h
1 #pragma once
2 
3 #include <torch/data/dataloader_options.h>
4 #include <torch/data/detail/data_shuttle.h>
5 #include <torch/data/detail/sequencers.h>
6 #include <torch/data/iterator.h>
7 #include <torch/data/samplers/random.h>
8 #include <torch/data/worker_exception.h>
9 #include <torch/types.h>
10 
11 #include <torch/csrc/utils/memory.h>
12 #include <torch/csrc/utils/variadic.h>
13 
14 #include <c10/util/Exception.h>
15 
16 #include <cstddef>
17 #include <exception>
18 #include <memory>
19 #include <thread>
20 #include <type_traits>
21 #include <utility>
22 #include <vector>
23 
24 namespace torch {
25 namespace data {
26 template <typename Dataset, typename Batch, typename BatchRequest>
28  public:
29  using BatchType = Batch;
30  using BatchRequestType = BatchRequest;
31 
37  std::unique_ptr<Dataset> main_thread_dataset = nullptr)
38  : options_(std::move(options)),
39  main_thread_dataset_(std::move(main_thread_dataset)),
41 
42  virtual ~DataLoaderBase() {
43  join();
44  }
45 
58  AT_CHECK(
59  shuttle_.in_flight_jobs() == 0,
60  "Attempted to get a new DataLoader iterator "
61  "while another iterator is not yet exhausted");
62  reset();
63  return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
64  [this] { return this->next(); }));
65  }
66 
70  return Iterator<Batch>(
71  torch::make_unique<detail::SentinelIterator<Batch>>());
72  }
73 
77  void join() {
78  if (joined_) {
79  return;
80  }
81  shuttle_.drain();
82  // Send one 'quit' message per worker. Since a worker dies (exits its
83  // thread) after receiving this message, each `QuitWorker()` message will be
84  // read by exactly one worker.
85  for (size_t w = 0; w < options_.workers; ++w) {
86  push_job(QuitWorker());
87  }
88  for (auto& worker : workers_) {
89  worker.join();
90  }
91  joined_ = true;
92  }
93 
95  const FullDataLoaderOptions& options() const noexcept {
96  return options_;
97  }
98 
99  protected:
/// Simple mix-in to give something a sequence number.
struct Sequenced {
  Sequenced() = default;
  // Intentionally non-explicit, kept as-is for source compatibility.
  Sequenced(size_t sqn) : sequence_number(sqn) {}
  /// Sequence number of the job/result. Zero-initialized so that a
  /// default-constructed object (e.g. a default-constructed `Job` or
  /// `Result`) never carries an indeterminate value; reading an
  /// uninitialized `size_t` is undefined behavior.
  size_t sequence_number = 0;
};

/// Empty tag type. Pushing it as a job tells the worker thread that reads it
/// to leave its processing loop.
struct QuitWorker {};
108 
111  struct Job : Sequenced {
112  Job() = default;
113  Job(QuitWorker q, size_t sqn) : Sequenced(sqn), quit(q) {}
114  Job(BatchRequest&& i, size_t sqn)
115  : Sequenced(sqn), batch_request(std::move(i)) {}
117  optional<BatchRequest> batch_request;
118  };
119 
121  struct Result : Sequenced {
122  Result() = default;
123  Result(optional<Batch>&& b, size_t sqn)
124  : Sequenced(sqn), batch(std::move(b)) {}
125  Result(std::exception_ptr exception, size_t sqn)
126  : Sequenced(sqn), exception(std::move(exception)) {}
127  optional<Batch> batch;
128  std::exception_ptr exception;
129  };
130 
135 
138  virtual void reset() {
139  shuttle_.drain();
140  sequence_number_ = 0;
142  prefetch();
143  }
144 
147  void prefetch(size_t requested_jobs) {
148  for (size_t r = 0; r < requested_jobs; ++r) {
149  if (auto batch_request = get_batch_request()) {
150  this->push_job(std::move(*batch_request));
151  } else {
152  break;
153  }
154  }
155  }
156 
158  void prefetch() {
159  prefetch(options_.max_jobs);
160  }
161 
166  if (options_.workers > 0) {
167  while (optional<Result> result = this->pop_result()) {
168  if (result->exception) {
169  throw WorkerException(result->exception);
170  } else if (result->batch) {
171  prefetch(1);
172  return std::move(result->batch);
173  }
174  }
175  } else if (auto batch_request = get_batch_request()) {
176  return this->main_thread_dataset_->get_batch(std::move(*batch_request));
177  }
178  return nullopt;
179  }
180 
182  void worker_thread(Dataset& dataset) {
183  while (true) {
184  auto job = shuttle_.pop_job();
185  if (job.quit) {
186  break;
187  }
188  try {
189  auto batch = dataset.get_batch(std::move(*job.batch_request));
190  shuttle_.push_result({std::move(batch), job.sequence_number});
191  } catch (...) {
192  shuttle_.push_result({std::current_exception(), job.sequence_number});
193  }
194  }
195  }
196 
199  template <typename T>
200  void push_job(T value) {
201  shuttle_.push_job({std::move(value), sequence_number_++});
202  }
203 
206  return sequencer_->next(
207  [this] { return this->shuttle_.pop_result(this->options_.timeout); });
208  }
209 
212  std::unique_ptr<detail::sequencers::Sequencer<Result>> new_sequencer() {
213  if (options_.enforce_ordering) {
214  return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
215  options_.max_jobs);
216  }
217  return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
218  }
219 
222 
227  std::unique_ptr<Dataset> main_thread_dataset_;
228 
231  size_t sequence_number_ = 0;
232 
234  std::vector<std::thread> workers_;
235 
238 
240  std::unique_ptr<detail::sequencers::Sequencer<Result>> sequencer_;
241 
243  bool joined_ = false;
244 };
245 } // namespace data
246 } // namespace torch
void prefetch()
Schedules the maximum number of jobs (based on the max_jobs option).
Definition: base.h:158
void join()
Joins the DataLoader's worker threads and drains internal queues.
Definition: base.h:77
size_t sequence_number_
The sequence number for the next batch to be retrieved from the dataset.
Definition: base.h:231
Like DataLoaderOptions, but without any unconfigured state.
const FullDataLoaderOptions & options() const noexcept
Returns the options with which the DataLoader was configured.
Definition: base.h:95
bool joined_
True if the DataLoader has joined its worker threads.
Definition: base.h:243
const FullDataLoaderOptions options_
The options the DataLoader was configured with.
Definition: base.h:221
Simple mix-in to give something a sequence number.
Definition: base.h:101
The finished result of a job.
Definition: base.h:121
virtual optional< BatchRequestType > get_batch_request()=0
Subclass hook for getting the next batch request.
Options to configure a DataLoader.
Iterator< Batch > end()
Returns a special "sentinel" iterator that compares equal with a non-sentinel iterator once the DataLoader is exhausted.
Definition: base.h:69
DataLoaderBase(DataLoaderOptions options, std::unique_ptr< Dataset > main_thread_dataset=nullptr)
Constructs a new DataLoader from the options to configure it with and, optionally, a dataset for the main thread to sample from (used when no worker threads are configured).
Definition: base.h:35
Encapsulates the full life cycle of DataLoader jobs.
Definition: data_shuttle.h:26
virtual void reset()
Resets the internal state of the DataLoader, optionally pre-fetching new jobs.
Definition: base.h:138
Definition: jit_type.h:17
void prefetch(size_t requested_jobs)
Schedules requested_jobs many new batches to be fetched.
Definition: base.h:147
optional< Result > pop_result()
Convenience method that gets the next result from the sequencer.
Definition: base.h:205
An exception thrown when a DataLoader's worker thread throws an exception, which is caught in the worker and rethrown from the DataLoader's next().
detail::DataShuttle< Job, Result > shuttle_
The DataShuttle which takes care of the life cycle of a job.
Definition: base.h:237
std::vector< std::thread > workers_
The worker threads, running the worker_thread() method.
Definition: base.h:234
std::unique_ptr< detail::sequencers::Sequencer< Result > > sequencer_
The Sequencer, which handles optional ordering of batches.
Definition: base.h:240
Iterator< Batch > begin()
Returns an iterator into the DataLoader.
Definition: base.h:57
void push_job(T value)
Convenience method that calls shuttle_.push_job() with the next sequence number.
Definition: base.h:200
optional< BatchType > next()
Returns the next batch of data, or an empty optional if the DataLoader is exhausted.
Definition: base.h:165
std::unique_ptr< detail::sequencers::Sequencer< Result > > new_sequencer()
Convenience method that creates a new sequencer based on the enforce_ordering option.
Definition: base.h:212
A Job is either a BatchRequest (new indices to fetch data at) or a QuitWorker object, to indicate the worker should shut down.
Definition: base.h:111
std::unique_ptr< Dataset > main_thread_dataset_
The dataset for the main thread; only has a value if the number of worker threads was configured as zero.
Definition: base.h:227
void worker_thread(Dataset &dataset)
The function that worker threads run.
Definition: base.h:182