Caffe2 - C++ API
A deep learning, cross platform ML framework
stateless.h
1 #pragma once
2 
3 #include <torch/data/dataloader/base.h>
4 #include <torch/data/worker_exception.h>
5 
6 #include <torch/csrc/utils/memory.h>
7 
8 #include <c10/util/Exception.h>
9 
10 #include <cstddef>
11 #include <thread>
12 #include <utility>
13 
14 namespace torch {
15 namespace data {
16 
24 template <typename Dataset, typename Sampler>
26  Dataset,
27  typename Dataset::BatchType,
28  typename Sampler::BatchRequestType> {
29  public:
30  using super = DataLoaderBase<
31  Dataset,
32  typename Dataset::BatchType,
33  typename Sampler::BatchRequestType>;
34  using typename super::BatchRequestType;
35 
39  Dataset dataset,
40  Sampler sampler,
42  : super(std::move(options)), sampler_(std::move(sampler)) {
43  for (size_t w = 0; w < this->options_.workers; ++w) {
44  // Here we copy the dataset into the worker thread closure. Each worker
45  // has its own copy of the dataset. This means the dataset must be
46  // trivially copyable, or else we don't expect more than one worker to
47  // be in use.
48  this->workers_.emplace_back(
49  [this, dataset]() mutable { this->worker_thread(dataset); });
50  }
51  if (this->options_.workers == 0) {
52  this->main_thread_dataset_ =
53  torch::make_unique<Dataset>(std::move(dataset));
54  }
55  }
56 
57  private:
/// Resets the sampler, then delegates to the base class `reset()`.
/// The order matters: the sampler must be rewound before the base class
/// starts requesting new batches.
59  void reset() override {
// Rewind the sampler first so that any batch requested below starts fresh.
60  sampler_.reset();
61  // Call the base class method last because it calls `prefetch()`
62  super::reset();
63  }
64 
/// Asks the sampler for up to `options_.batch_size` indices and returns them
/// as the next batch request. Returns `nullopt` when the sampler yields
/// nothing, or when the batch is smaller than `batch_size` and `drop_last`
/// is set.
67  optional<BatchRequestType> get_batch_request() override {
68  auto indices = sampler_.next(this->options_.batch_size);
// No indices at all, or a short batch that `drop_last` says to discard.
69  if (!indices ||
70  (indices->size() < this->options_.batch_size &&
71  this->options_.drop_last)) {
72  return nullopt;
73  }
// A batch that is kept must contain at least one index.
74  AT_ASSERT(indices->size() > 0);
75  return indices;
76  }
77 
79  Sampler sampler_;
80 };
81 } // namespace data
82 } // namespace torch
A dataloader for stateless datasets.
Definition: stateless.h:25
const FullDataLoaderOptions & options() const noexcept
Returns the options with which the DataLoader was configured.
Definition: base.h:95
const FullDataLoaderOptions options_
The options the DataLoader was configured with.
Definition: base.h:221
Options to configure a DataLoader.
virtual void reset()
Resets the internal state of the DataLoader, optionally pre-fetching new jobs.
Definition: base.h:138
Definition: jit_type.h:17
std::vector< std::thread > workers_
The worker threads, running the worker_thread() method.
Definition: base.h:234
StatelessDataLoader(Dataset dataset, Sampler sampler, DataLoaderOptions options)
Constructs the StatelessDataLoader from a dataset, a sampler and some options.
Definition: stateless.h:38
std::unique_ptr< Dataset > main_thread_dataset_
The dataset for the main thread, only has a value if the number of worker threads was configured as zero.
Definition: base.h:227
void worker_thread(Dataset &dataset)
The function that worker threads run.
Definition: base.h:182