Caffe2 - C++ API
A deep-learning, cross-platform ML framework
dataloader.h
1 #pragma once
2 
3 #include <torch/data/dataloader/stateful.h>
4 #include <torch/data/dataloader/stateless.h>
5 
6 #include <torch/csrc/utils/memory.h>
7 #include <torch/csrc/utils/variadic.h>
8 
9 #include <c10/util/Exception.h>
10 
11 #include <cstddef>
12 #include <memory>
13 #include <type_traits>
14 #include <utility>
15 
16 namespace torch {
17 namespace data {
18 
21 template <typename Dataset, typename Sampler>
22 torch::disable_if_t<
23  Dataset::is_stateful,
24  std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
25 make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
26  return torch::make_unique<StatelessDataLoader<Dataset, Sampler>>(
27  std::move(dataset), std::move(sampler), std::move(options));
28 }
29 
33 template <typename Sampler = samplers::RandomSampler, typename Dataset>
34 torch::disable_if_t<
35  Dataset::is_stateful || !std::is_constructible<Sampler, size_t>::value,
36  std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
37 make_data_loader(
38  Dataset dataset,
39  DataLoaderOptions options = DataLoaderOptions()) {
40  const optional<size_t> size = dataset.size();
41  AT_CHECK(
42  size.has_value(),
43  "Expected the dataset to be sized in "
44  "order to construct the Sampler");
45  return make_data_loader(
46  std::move(dataset), Sampler(*size), std::move(options));
47 }
48 
50 template <typename Dataset, typename = torch::enable_if_t<Dataset::is_stateful>>
51 std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
52  Dataset dataset,
53  DataLoaderOptions options = DataLoaderOptions()) {
54  return torch::make_unique<StatefulDataLoader<Dataset>>(
55  std::move(dataset), std::move(options));
56 }
57 } // namespace data
58 } // namespace torch
Definition: jit_type.h:17