Caffe2 - C++ API
A deep learning, cross platform ML framework
stateful.h
1 #pragma once
2 
3 #include <torch/data/dataloader/base.h>
4 
5 #include <cstddef>
6 #include <thread>
7 #include <utility>
8 
9 namespace torch {
10 namespace data {
11 
24 template <typename Dataset>
26  Dataset,
27  typename Dataset::BatchType::value_type,
28  typename Dataset::BatchRequestType> {
29  public:
30  using super = DataLoaderBase<
31  Dataset,
32  typename Dataset::BatchType::value_type,
33  typename Dataset::BatchRequestType>;
34  using typename super::BatchRequestType;
35 
38  : super(
39  std::move(options),
40  torch::make_unique<Dataset>(std::move(dataset))) {
41  for (size_t w = 0; w < this->options_.workers; ++w) {
42  // As opposed to the stateless case, here all worker threads access the
43  // same underlying dataset.
44  this->workers_.emplace_back(
45  [this] { this->worker_thread(*this->main_thread_dataset_); });
46  }
47  }
48 
49  private:
/// Resets the dataset that all worker threads share (the constructor hands
/// every worker `*main_thread_dataset_`), then resets the base-class state.
51  void reset() override {
52  this->main_thread_dataset_->reset();
53  // Call the base class method last because it calls `prefetch()`; presumably any prefetched requests must be served from the freshly reset dataset (confirm against base.h).
54  super::reset();
55  }
56 
/// Produces the next batch request: always the configured `batch_size`, so
/// this override never returns an empty optional. The request does not say
/// which examples to load; NOTE(review): presumably the stateful dataset
/// decides the batch contents itself — confirm against the Dataset API.
59  optional<BatchRequestType> get_batch_request() override {
60  return this->options_.batch_size;
61  }
62 };
63 } // namespace data
64 } // namespace torch
const FullDataLoaderOptions & options() const noexcept
Returns the options with which the DataLoader was configured.
Definition: base.h:95
const FullDataLoaderOptions options_
The options the DataLoader was configured with.
Definition: base.h:221
Options to configure a DataLoader.
virtual void reset()
Resets the internal state of the DataLoader, optionally pre-fetching new jobs.
Definition: base.h:138
Definition: jit_type.h:17
std::vector< std::thread > workers_
The worker threads, running the worker_thread() method.
Definition: base.h:234
StatefulDataLoader(Dataset dataset, DataLoaderOptions options)
Constructs the StatefulDataLoader from a dataset and some options.
Definition: stateful.h:37
std::unique_ptr< Dataset > main_thread_dataset_
The dataset for the main thread; only has a value if the number of worker threads was configured as zero.
Definition: base.h:227
A dataloader for stateful datasets.
Definition: stateful.h:25
void worker_thread(Dataset &dataset)
The function that worker threads run.
Definition: base.h:182