Caffe2 - C++ API
A deep learning, cross-platform ML framework
chunk.h
#pragma once

#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

namespace torch {
namespace data {
namespace datasets {

/// Interface for chunk reader, which performs data chunking and reading of
/// entire chunks.
template <typename Chunk = std::vector<Example<>>>
class ChunkDataReader {
 public:
  using ChunkType = Chunk;
  // Batch container type; ChunkDataset below refers to this as
  // ChunkReader::BatchType.
  using BatchType = ChunkType;

  /// Read an entire chunk.
  virtual ChunkType read_chunk(size_t chunk_index) = 0;

  /// Returns the number of chunks available in this reader.
  virtual size_t chunk_count() = 0;

  /// This will clear any internal state associated with this reader.
  virtual void reset() = 0;
};

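// For illustration, a minimal reader sketch — hypothetical, not part of this
// header — that serves a fixed in-memory dataset split into equal chunks:
//
//   class InMemoryChunkReader : public ChunkDataReader<> {
//    public:
//     ChunkType read_chunk(size_t chunk_index) override {
//       ChunkType chunk;
//       chunk.reserve(kChunkSize);
//       for (size_t i = 0; i < kChunkSize; ++i) {
//         auto value = static_cast<int64_t>(chunk_index * kChunkSize + i);
//         // Each example is a (data, target) pair of 1-element tensors.
//         chunk.push_back({torch::full({1}, value), torch::full({1}, value)});
//       }
//       return chunk;
//     }
//
//     size_t chunk_count() override { return 4; }
//
//     void reset() override {} // no per-epoch state to clear
//
//    private:
//     static constexpr size_t kChunkSize = 100;
//   };
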
namespace detail {

/// BatchDataBuffer manages a queue of UnwrappedBatchData. After a new chunk is
/// loaded, it is split into small batches and cached in this queue until the
/// data loader consumes them.
template <
    typename UnwrappedBatch = std::vector<Example<>>,
    typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
 public:
  using UnwrappedBatchType = UnwrappedBatch;
  using BatchType = torch::optional<UnwrappedBatchType>;
  using BatchRequestType = typename ExampleSampler::BatchRequestType;

  BatchDataBuffer(
      size_t batch_size,
      ExampleSampler& example_sampler,
      size_t queue_capacity)
      : batch_size_(batch_size),
        example_sampler_(example_sampler),
        queue_capacity_(queue_capacity) {}

  /// Return batch data from the queue. Called from the ChunkDataset main
  /// thread.
  BatchType get_batch() {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_read_.wait(lock, [this] {
      // Wait until there is available data in the queue, or all chunks are
      // loaded (i.e. the dataset is exhausted for this epoch).
      return (
          this->total_example_count_in_queue_ >= batch_size_ || this->stop_);
    });
    if (batch_queue_.empty()) {
      AT_ASSERT(stop_);
      // All batches have been retrieved. Return an empty batch.
      return nullopt;
    }

    UnwrappedBatchData batch = std::move(batch_queue_.front());
    batch_queue_.pop();
    if (batch.exception) {
      throw WorkerException(batch.exception);
    }

    total_example_count_in_queue_ -= batch.batch_data.size();
    lock.unlock();
    cv_write_.notify_all();

    return batch.batch_data;
  }

  /// Push preloaded chunks to the batch queue. Called from ChunkDataset worker
  /// threads.
  void add_chunk_data(UnwrappedBatchType data) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // Stop loading if we have preloaded enough data.
      return this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_;
    });
    if (stop_) {
      // When stop_ is true, no further chunk loading is necessary. Return
      // without any further processing.
      return;
    }

    auto data_size = data.size();
    auto remaining_size = data_size;
    example_sampler_.reset(data_size);

    auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
      auto batch_example_indices = this->example_sampler_.next(example_count);
      AT_ASSERT(
          batch_example_indices &&
          batch_example_indices.value().size() == example_count);
      BatchRequestType& indices = batch_example_indices.value();
      for (size_t i : indices) {
        AT_CHECK(i < data_size, "Index out of range");
        batch.emplace_back(std::move(data[i]));
      }
      remaining_size -= example_count;
    };

    if (!batch_queue_.empty()) {
      // If the queue has existing data and the last batch doesn't have enough
      // examples to fill a batch_size batch, add more examples to that batch
      // first.
      auto& batch = batch_queue_.back();
      size_t current_count = batch.batch_data.size();
      if (current_count < batch_size_) {
        auto example_count =
            std::min(remaining_size, batch_size_ - current_count);
        fill_batch(example_count, batch.batch_data);
      }
    }

    // If we still have data remaining after filling the last pushed batch, add
    // it to the queue too.
    while (remaining_size > 0) {
      UnwrappedBatchType current_batch;

      // Allocate the batch memory ahead of time.
      current_batch.reserve(batch_size_);

      auto example_count = std::min(remaining_size, batch_size_);
      fill_batch(example_count, current_batch);
      batch_queue_.emplace(std::move(current_batch));
    }
    total_example_count_in_queue_ += data_size;
    lock.unlock();
    cv_read_.notify_all();
  }

  /// Push exceptions thrown during preloading into the batch queue. Called
  /// from ChunkDataset worker threads.
  void add_chunk_data(std::exception_ptr e_ptr) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // Stop loading if we have preloaded enough data.
      return (
          this->total_example_count_in_queue_ < this->queue_capacity_ ||
          this->stop_);
    });
    if (stop_) {
      // When stop_ is true, the current thread needs to be torn down and the
      // batch buffer will be discarded, so there is no need to enqueue any new
      // exceptions.
      return;
    }

    batch_queue_.emplace(e_ptr);
    lock.unlock();
    cv_read_.notify_all();
  }

  void stop() {
    {
      // Hold the lock before changing stop_ to prevent a race condition that
      // can cause a deadlock.
      // To be more specific, the condition variable cv_write_ waits on the
      // predicate stop_ in add_chunk_data(). The wait happens in two steps:
      // 1) while still holding the lock, check if the predicate is true;
      // 2) if it is, proceed; otherwise, release the lock and wait until
      // notified. Without holding the lock, cv_write_'s notification can
      // arrive between steps 1) and 2). In that case cv_write_ is not yet
      // waiting, so the notification is lost and cv_write_ sleeps forever.
      // By taking the lock before changing the predicate stop_, updating and
      // evaluating stop_ are guaranteed to happen in a synchronized way.
      std::lock_guard<std::mutex> lock(queue_mutex_);
      stop_ = true;
    }

    // Notify all writers, waking them from their wait so they exit the
    // current method.
    cv_write_.notify_all();
    // Notify all readers too.
    cv_read_.notify_all();
  }
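
  // The locking discipline in stop() above is the standard condition-variable
  // shutdown pattern. A minimal standalone sketch of the same pattern, with
  // hypothetical names (requires <mutex> and <condition_variable>):
  //
  //   std::mutex m;
  //   std::condition_variable cv;
  //   bool done = false;
  //
  //   void waiter() {
  //     std::unique_lock<std::mutex> lock(m);
  //     cv.wait(lock, [] { return done; }); // predicate re-checked under lock
  //   }
  //
  //   void stopper() {
  //     {
  //       // Taking the lock here serializes with waiter's predicate check, so
  //       // the notification below can never fall between the check and the
  //       // sleep.
  //       std::lock_guard<std::mutex> lock(m);
  //       done = true;
  //     }
  //     cv.notify_all(); // safe to notify after releasing the lock
  //   }
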
  // The batch size used to slice loaded chunk data into batches.
  size_t batch_size_ = 0;

  // Count of total examples stored in the queue.
  size_t total_example_count_in_queue_ = 0;

  /// struct that contains a raw unwrapped batch unit. An unwrapped batch unit
  /// is the raw data without the 'optional' wrapper.
  struct UnwrappedBatchData {
    explicit UnwrappedBatchData(UnwrappedBatchType data)
        : batch_data(std::move(data)) {}

    explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}

    /// batch data to return
    UnwrappedBatchType batch_data;

    /// exception pointer which captures any abnormal exceptions while creating
    /// the batch.
    std::exception_ptr exception;
  };

  // Local cache to store example batches from loaded chunks.
  std::queue<UnwrappedBatchData> batch_queue_;

  // Mutex to sync batch_queue_ updates.
  std::mutex queue_mutex_;

  std::condition_variable cv_read_;
  std::condition_variable cv_write_;

  ExampleSampler& example_sampler_;

  // Configurable maximum number of elements the queue can hold at one time.
  size_t queue_capacity_;

  // When set to true, it wakes the writer threads from their wait so they exit
  // the current function call. This is needed when ChunkDataset.reset() is
  // called while the previous epoch is not yet exhausted. While ChunkDataset
  // waits for its preloaders to finish previous work before tearing down the
  // threads, a preloader could still be waiting on the condition variable,
  // causing the program to hang. This boolean is used to break that waiting
  // condition.
  bool stop_ = false;
};
} // namespace detail

/// Options to configure a ChunkDataset.
struct ChunkDatasetOptions {
  ChunkDatasetOptions() = delete;
  ChunkDatasetOptions(
      size_t preloader_count,
      size_t batch_size,
      size_t cache_size = 2048)
      : preloader_count_(preloader_count),
        batch_size_(batch_size),
        cache_size_(cache_size) {
    AT_CHECK(
        preloader_count_ > 0,
        "Preloader count is 0. At least one preloader needs to be specified.");
    AT_CHECK(
        batch_size_ > 0,
        "Batch size is 0. A positive batch size needs to be specified.");
    AT_CHECK(
        cache_size_ > 0,
        "Cache size is 0. A positive cache size needs to be specified.");
    AT_CHECK(
        cache_size_ >= batch_size_,
        "Cache size is less than batch size. Cache needs to be large enough to "
        "hold at least one batch.");
  }

  // The number of worker threads that preload chunk data.
  TORCH_ARG(size_t, preloader_count);

  // The size of each batch.
  TORCH_ARG(size_t, batch_size);

  // The capacity of the queue for batch caching.
  TORCH_ARG(size_t, cache_size) = 2048;
};

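// For illustration, a hypothetical configuration (values chosen arbitrarily;
// note that cache_size must be at least batch_size):
//
//   // 2 preloader threads, batches of 64 examples, and a cache that may hold
//   // up to 512 examples at a time.
//   datasets::ChunkDatasetOptions options(
//       /*preloader_count=*/2, /*batch_size=*/64, /*cache_size=*/512);
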
/// A stateful dataset that supports hierarchical sampling and prefetching of
/// entire chunks.
///
/// Unlike a regular dataset, a chunk dataset requires two samplers to operate
/// and keeps internal state. `ChunkSampler` selects which chunk to load next,
/// while the `ExampleSampler` determines the order of Examples that are
/// returned in each `get_batch` call. The hierarchical sampling approach used
/// here is inspired by this paper:
/// http://martin.zinkevich.org/publications/nips2010.pdf
template <
    typename ChunkReader,
    typename ChunkSampler = samplers::RandomSampler,
    typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset final
    : public StatefulDataset<
          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
          typename ChunkReader::BatchType,
          size_t> {
 public:
  using BatchType = torch::optional<typename ChunkReader::BatchType>;
  using UnwrappedBatchType = typename ChunkReader::BatchType;
  using BatchRequestType = size_t;
  using ChunkSamplerType = ChunkSampler;
  using ExampleSamplerType = ExampleSampler;

  ChunkDataset(
      ChunkReader chunk_reader,
      ChunkSampler chunk_sampler,
      ExampleSampler example_sampler,
      ChunkDatasetOptions options)
      : chunk_reader_(std::move(chunk_reader)),
        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
        options_(std::move(options)),
        quit_worker_(false),
        running_preloaders_(0) {}

  virtual ~ChunkDataset() {
    // Stop the batch buffer first.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
  }

  /// Default get_batch method of BatchDataset. This method returns a batch
  /// created from the preloaded chunks.
  BatchType get_batch(size_t batch_size) override {
    AT_CHECK(
        batch_buffer_ != nullptr,
        "Dataset needs to call reset() before calling get_batch().");

    AT_CHECK(
        batch_size == options_.batch_size_,
        "The requested batch size does not match the initialized batch size.\n"
        " The requested batch size is ", batch_size,
        ", while the dataset is created with batch size equal to ",
        options_.batch_size_);

    return batch_buffer_->get_batch();
  }

  /// This will clear any internal state and start the internal prefetching
  /// mechanism for the chunk dataset.
  void reset() override {
    // We need this to support partial data reads via the dataloader iterator.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    // Free workers from the previous reset, if there are any.
    free_workers();
    preload_threads_.clear();

    chunk_reader_.reset();

    chunk_sampler_.reset(chunk_reader_.chunk_count());

    // Throw out any existing cached batches in the buffer and re-create a new
    // chunk buffer.
    batch_buffer_ = torch::make_unique<
        detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
        options_.batch_size_,
        example_sampler_,
        options_.cache_size_);

    // Create new workers for this new epoch.
    quit_worker_ = false;

    AT_ASSERT(running_preloaders_ == 0);
    running_preloaders_ = options_.preloader_count_;
    for (size_t i = 0; i < options_.preloader_count_; ++i) {
      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
    }
  }

  /// size is not used for chunk dataset.
  optional<size_t> size() const override {
    return torch::nullopt;
  }

  // Provides a reference to the chunk sampler. Used mainly in distributed data
  // loading to set the epoch number for the sampler.
  ChunkSamplerType& chunk_sampler() {
    return chunk_sampler_;
  }

 private:
  // Runs on the worker threads to preload chunk data.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        size_t chunk_id = 0;
        {
          std::lock_guard<std::mutex> lock(chunk_index_guard_);
          if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
            chunk_id = chunk_sampler_result.value()[0];
          } else {
            break;
          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_id);
        if (!data.empty()) { // skip empty chunks.
          batch_buffer_->add_chunk_data(std::move(data));
        }
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
    AT_ASSERT(running_preloaders_.load() > 0);
    --running_preloaders_;
    if (running_preloaders_.load() == 0) {
      // All preloaders have completed, so we can notify the batch_buffer.
      batch_buffer_->stop();
    }
  }

  // Blocks the current thread until the worker threads finish execution and
  // exit.
  void free_workers() {
    if (!quit_worker_.load()) {
      quit_worker_ = true;
      for (auto& worker_thread : preload_threads_) {
        worker_thread.join();
      }
    }
  }

 private:
  // Templated class that defines what a chunk is and how to read chunk data.
  // When a chunk is returned by chunk_reader_, ChunkDataset splits it into
  // batches and caches them in batch_buffer_.
  ChunkReader chunk_reader_;

  // Chunk sampler to shuffle different chunks.
  ChunkSamplerType chunk_sampler_;

  // Example sampler to shuffle examples within a specific chunk.
  ExampleSamplerType example_sampler_;

  // Batch data buffer which holds chunk data from the preloading threads.
  std::shared_ptr<detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
      batch_buffer_;

  // Worker thread pool.
  std::vector<std::thread> preload_threads_;

  // Options to configure this dataset.
  const ChunkDatasetOptions options_;

  // Indicates whether the worker threads can be torn down.
  std::atomic<bool> quit_worker_;

  // Keeps track of running preloaders to notify the batch buffer. A value of 0
  // indicates that chunk loading is complete.
  std::atomic<size_t> running_preloaders_;

  // Mutex to synchronize chunk sampler next() calls.
  std::mutex chunk_index_guard_;
};
} // namespace datasets
} // namespace data
} // namespace torch
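
As a usage sketch (not part of this header), the pieces above compose as
follows. InMemoryChunkReader is the hypothetical reader sketched earlier; all
other names come from this header and the torch::data samplers.

  using namespace torch::data;

  InMemoryChunkReader reader;
  samplers::RandomSampler chunk_sampler(0);    // re-sized by ChunkDataset::reset()
  samplers::RandomSampler example_sampler(0);  // re-sized for each loaded chunk

  datasets::ChunkDataset<InMemoryChunkReader> dataset(
      std::move(reader),
      std::move(chunk_sampler),
      std::move(example_sampler),
      datasets::ChunkDatasetOptions(
          /*preloader_count=*/2, /*batch_size=*/64, /*cache_size=*/512));

  dataset.reset();  // starts the preloader threads for this epoch
  while (auto batch = dataset.get_batch(64)) {
    // *batch is a std::vector<Example<>> holding up to 64 shuffled examples.
  }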