3 #include <torch/data/dataloader/stateful.h> 4 #include <torch/data/dataloader/stateless.h> 6 #include <torch/csrc/utils/memory.h> 7 #include <torch/csrc/utils/variadic.h> 9 #include <c10/util/Exception.h> 13 #include <type_traits> 21 template <
typename Dataset,
typename Sampler>
24 std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
25 make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) {
26 return torch::make_unique<StatelessDataLoader<Dataset, Sampler>>(
27 std::move(dataset), std::move(sampler), std::move(options));
33 template <
typename Sampler = samplers::RandomSampler,
typename Dataset>
35 Dataset::is_stateful || !std::is_constructible<Sampler, size_t>::value,
36 std::unique_ptr<StatelessDataLoader<Dataset, Sampler>>>
39 DataLoaderOptions options = DataLoaderOptions()) {
40 const optional<size_t> size = dataset.size();
43 "Expected the dataset to be sized in " 44 "order to construct the Sampler");
45 return make_data_loader(
46 std::move(dataset), Sampler(*size), std::move(options));
50 template <
typename Dataset,
typename = torch::enable_if_t<Dataset::is_stateful>>
51 std::unique_ptr<StatefulDataLoader<Dataset>> make_data_loader(
53 DataLoaderOptions options = DataLoaderOptions()) {
54 return torch::make_unique<StatefulDataLoader<Dataset>>(
55 std::move(dataset), std::move(options));