1 #include <gtest/gtest.h>     3 #include <torch/data.h>     4 #include <torch/data/detail/sequencers.h>     5 #include <torch/serialize.h>     6 #include <torch/types.h>     8 #include <test/cpp/api/support.h>    10 #include <c10/util/ArrayRef.h>    23 #include <unordered_set>    28 const std::chrono::milliseconds kMillisecond(1);
    31   explicit DummyDataset(
size_t size = 100) : size_(size) {}
    33   int get(
size_t index) 
override {
    43 TEST(DataTest, DatasetCallsGetCorrectly) {
    45   std::vector<int> batch = d.
get_batch({0, 1, 2, 3, 4});
    46   std::vector<int> expected = {1, 2, 3, 4, 5};
    47   ASSERT_EQ(batch, expected);
    50 TEST(DataTest, TransformCallsGetApplyCorrectly) {
    52     std::string apply(
int input)
 override {
    53       return std::to_string(input);
    58   std::vector<std::string> batch = d.get_batch({0, 1, 2, 3, 4});
    59   std::vector<std::string> expected = {
"1", 
"2", 
"3", 
"4", 
"5"};
    60   ASSERT_EQ(batch, expected);
    68   using BatchType = std::vector<int>;
    73     int start_index = chunk_index == 0
    75         : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0);
    77     batch_data.resize(chunk_sizes[chunk_index]);
    79     std::iota(batch_data.begin(), batch_data.end(), start_index);
    90   const static size_t chunk_count_ = 3;
    91   size_t chunk_sizes[chunk_count_] = {10, 5, 20};
    94 TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
    98   auto initialization_function =
    99       [&](
size_t preloader_count, 
size_t batch_size, 
size_t cache_size) {
   103             samplers::SequentialSampler>>
   105                 DummyChunkDataReader,
   106                 samplers::SequentialSampler,
   107                 samplers::SequentialSampler>>(
   112                     preloader_count, batch_size, cache_size));
   116       initialization_function(0, 1, 1),
   117       "Preloader count is 0. At least one preloader needs to be specified.");
   120       initialization_function(1, 0, 1),
   121       "Batch size is 0. A positive batch size needs to be specified.");
   124       initialization_function(1, 1, 0),
   125       "Cache size is 0. A positive cache size needs to be specified.");
   128       initialization_function(1, 10, 5),
   129       "Cache size is less than batch size. Cache needs to be large enough to "   130       "hold at least one batch.");
   135   std::vector<int> get_batch(
size_t batch_size)
 override {
   136     std::vector<int> batch(batch_size);
   137     for (
auto& i : batch) {
   144     return torch::nullopt;
   151   const size_t kBatchSize = 13;
   156   auto data_loader = torch::data::make_data_loader(
   161   size_t batch_index = 0;
   162   for (
auto& batch : *data_loader) {
   163     ASSERT_LT(batch_index, 3);
   164     ASSERT_EQ(batch.size(), kBatchSize);
   165     for (
size_t j = 0; j < kBatchSize; ++j) {
   166       ASSERT_EQ(batch.at(j), 1 + (batch_index * kBatchSize) + j);
   170   ASSERT_EQ(batch_index, 3);
   173 TEST(DataTest, NoSequencerIsIdentity) {
   176   const auto value = no_sequencer.next([] { 
return 5; }).value();
   180 TEST(DataTest, OrderedSequencerIsSetUpWell) {
   183     size_t sequence_number;
   185   const size_t kMaxJobs = 5;
   188   ASSERT_EQ(sequencer.
buffer_.size(), kMaxJobs);
   191 TEST(DataTest, OrderedSequencerReOrdersValues) {
   194     size_t sequence_number;
   196   const size_t kMaxJobs = 5;
   199   std::vector<size_t> v = {0, 2, 4, 3, 1};
   201   auto getter = [&v, &index]() { 
return S{v.at(index++)}; };
   205   const auto batch = sequencer.
next(getter);
   206   ASSERT_EQ(batch.value().sequence_number, 0);
   210   ASSERT_EQ(1, sequencer.
next(getter).value().sequence_number);
   214   for (
size_t i = 2; i <= 4; ++i) {
   216     ASSERT_EQ(i, sequencer.
next(getter).value().sequence_number);
   222 TEST(DataTest, BatchLambdaAppliesFunctionToBatch) {
   223   using InputBatch = std::vector<int>;
   224   using OutputBatch = std::string;
   227       [](std::vector<int> input) {
   228         return std::to_string(std::accumulate(input.begin(), input.end(), 0));
   230   ASSERT_EQ(e.get_batch({1, 2, 3, 4, 5}), std::string(
"20"));
   233 TEST(DataTest, LambdaAppliesFunctionToExample) {
   235       static_cast<std::string (*)(
int)
>(std::to_string)));
   236   std::vector<std::string> expected = {
"1", 
"2", 
"3", 
"4", 
"5"};
   237   ASSERT_EQ(d.get_batch({0, 1, 2, 3, 4}), expected);
   240 TEST(DataTest, CollateReducesBatch) {
   243         return std::accumulate(input.begin(), input.end(), 0);
   245   ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
   248 TEST(DataTest, CollationReducesBatch) {
   250     int apply_batch(std::vector<int> input)
 override {
   251       return std::accumulate(input.begin(), input.end(), 0);
   255   ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
   258 TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
   260   ASSERT_EQ(sampler.
next(3).value(), std::vector<size_t>({0, 1, 2}));
   261   ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({3, 4, 5, 6, 7}));
   262   ASSERT_EQ(sampler.
next(2).value(), std::vector<size_t>({8, 9}));
   263   ASSERT_FALSE(sampler.
next(2).has_value());
   266 TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
   268   ASSERT_EQ(sampler.
next(3).value(), std::vector<size_t>({0, 1, 2}));
   269   ASSERT_EQ(sampler.
next(100).value(), std::vector<size_t>({3, 4}));
   270   ASSERT_FALSE(sampler.
next(2).has_value());
   273 TEST(DataTest, SequentialSamplerResetsWell) {
   275   ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
   276   ASSERT_FALSE(sampler.
next(2).has_value());
   278   ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
   279   ASSERT_FALSE(sampler.
next(2).has_value());
   282 TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
   284   ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
   285   ASSERT_FALSE(sampler.
next(2).has_value());
   288       sampler.
next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
   289   ASSERT_FALSE(sampler.
next(2).has_value());
   291   ASSERT_EQ(sampler.
next(3).value(), std::vector<size_t>({0, 1, 2}));
   292   ASSERT_FALSE(sampler.
next(2).has_value());
   295 TEST(DataTest, CanSaveAndLoadSequentialSampler) {
   298     ASSERT_EQ(a.
index(), 0);
   299     std::stringstream stream;
   300     torch::save(a, stream);
   303     torch::load(b, stream);
   304     ASSERT_EQ(b.
index(), 0);
   310     ASSERT_EQ(a.
index(), 7);
   311     std::stringstream stream;
   312     torch::save(a, stream);
   315     torch::load(b, stream);
   316     ASSERT_EQ(b.
index(), 7);
   320 TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) {
   323   std::vector<size_t> indices = sampler.
next(3).value();
   324   for (
auto i : indices) {
   329   indices = sampler.
next(5).value();
   330   for (
auto i : indices) {
   335   indices = sampler.
next(2).value();
   336   for (
auto i : indices) {
   341   ASSERT_FALSE(sampler.
next(10).has_value());
   344 TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
   346   ASSERT_EQ(sampler.
next(3).value().size(), 3);
   347   ASSERT_EQ(sampler.
next(100).value().size(), 2);
   348   ASSERT_FALSE(sampler.
next(2).has_value());
   351 TEST(DataTest, RandomSamplerResetsWell) {
   353   ASSERT_EQ(sampler.
next(5).value().size(), 5);
   354   ASSERT_FALSE(sampler.
next(2).has_value());
   356   ASSERT_EQ(sampler.
next(5).value().size(), 5);
   357   ASSERT_FALSE(sampler.
next(2).has_value());
   360 TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
   362   ASSERT_EQ(sampler.
next(5).value().size(), 5);
   363   ASSERT_FALSE(sampler.
next(2).has_value());
   365   ASSERT_EQ(sampler.
next(7).value().size(), 7);
   366   ASSERT_FALSE(sampler.
next(2).has_value());
   368   ASSERT_EQ(sampler.
next(3).value().size(), 3);
   369   ASSERT_FALSE(sampler.
next(2).has_value());
   372 TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
   376     std::stringstream stream;
   377     torch::save(a, stream);
   380     torch::load(b, stream);
   382     ASSERT_EQ(a.
next(10).value(), b.
next(10).value());
   387     ASSERT_EQ(a.
index(), 3);
   389     std::stringstream stream;
   390     torch::save(a, stream);
   393     torch::load(b, stream);
   394     ASSERT_EQ(b.
index(), 3);
   396     auto b_sequence = b.
next(10).value();
   397     ASSERT_EQ(b_sequence.size(), 7);
   398     ASSERT_EQ(a.
next(10).value(), b_sequence);
   402 TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
   404   ASSERT_EQ(sampler.
next(10).value(), 10);
   405   ASSERT_EQ(sampler.
next(2).value(), 2);
   406   ASSERT_EQ(sampler.
next(85).value(), 85);
   407   ASSERT_EQ(sampler.
next(123).value(), 3);
   408   ASSERT_FALSE(sampler.
next(1).has_value());
   411 TEST(DataTest, StreamSamplerResetsWell) {
   413   ASSERT_EQ(sampler.
next(5).value().size(), 5);
   414   ASSERT_FALSE(sampler.
next(2).has_value());
   416   ASSERT_EQ(sampler.
next(5).value().size(), 5);
   417   ASSERT_FALSE(sampler.
next(2).has_value());
   420 TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
   422   ASSERT_EQ(sampler.
next(5).value().size(), 5);
   423   ASSERT_FALSE(sampler.
next(2).has_value());
   425   ASSERT_EQ(sampler.
next(7).value().size(), 7);
   426   ASSERT_FALSE(sampler.
next(2).has_value());
   428   ASSERT_EQ(sampler.
next(3).value().size(), 3);
   429   ASSERT_FALSE(sampler.
next(2).has_value());
   432 TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
   435       torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.
get(2)));
   438 TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) {
   439   std::vector<torch::Tensor> vector = torch::eye(5).chunk(5);
   442       torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.
get(2)));
   445 TEST(DataTest, StackTransformWorksForExample) {
   448       return {tensor[index], 1 + tensor[index]};
   452       return tensor.size(0);
   461   ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(0, 0, 2)));
   462   ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(0, 0, 2)));
   465   ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(0, 2, 4)));
   466   ASSERT_TRUE(second.target.allclose(1 + torch::eye(4).slice(0, 2, 4)));
   469 TEST(DataTest, StackTransformWorksForTensorExample) {
   474   ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(0, 0, 2)));
   477   ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(0, 2, 4)));
   481 template <
typename Target>
   490           Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
   492     return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
   500 TEST(DataTest, TensorTransformWorksForAnyTargetType) {
   502   std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});
   504   ASSERT_EQ(batch.size(), 2);
   505   ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
   506   ASSERT_EQ(batch[0].target, 
"1");
   508   ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
   509   ASSERT_EQ(batch[1].target, 
"2");
   512 TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
   515   std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});
   517   ASSERT_EQ(batch.size(), 2);
   518   ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
   519   ASSERT_EQ(batch[0].target, 
"1");
   521   ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
   522   ASSERT_EQ(batch[1].target, 
"2");
   528     const auto channels = 
static_cast<int64_t
>(index);
   530         (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
   531     return {tensor, 
static_cast<int>(channels)};
   539 TEST(DataTest, NormalizeTransform) {
   543   std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
   544   ASSERT_EQ(output.size(), 1);
   546   ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
   550   output = dataset.get_batch(1);
   551   ASSERT_EQ(output.size(), 1);
   552   ASSERT_EQ(output[0].data.size(0), 1);
   553   ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
   559   output = dataset.get_batch(2);
   560   ASSERT_EQ(output.size(), 1);
   561   ASSERT_EQ(output[0].data.size(0), 2);
   562   ASSERT_TRUE(output[0]
   564                   .allclose(torch::ones({1, 4, 4}) * 5))
   566   ASSERT_TRUE(output[0]
   568                   .allclose(torch::ones({1, 4, 4}) * -2.5))
   573   output = dataset.get_batch(3);
   574   ASSERT_EQ(output.size(), 1);
   575   ASSERT_EQ(output[0].data.size(0), 3);
   576   ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
   582   output = dataset.get_batch(3);
   583   ASSERT_EQ(output.size(), 1);
   584   ASSERT_EQ(output[0].data.size(0), 3);
   585   ASSERT_TRUE(output[0]
   587                   .allclose(torch::ones({1, 4, 4}) * 5))
   589   ASSERT_TRUE(output[0]
   591                   .allclose(torch::ones({1, 4, 4}) * -2.5))
   593   ASSERT_TRUE(output[0]
   595                   .allclose(torch::ones({1, 4, 4}) * 12.5))
   611     return {torch::tensor(static_cast<int64_t>(index)),
   612             torch::tensor(static_cast<int64_t>(index))};
   620 TEST(DataTest, MapDoesNotCopy) {
   629   auto data = dataset.get_batch(1).at(0).data;
   630   ASSERT_EQ(data.numel(), 1);
   631   ASSERT_EQ(data[0].item<float>(), 7);
   634 TEST(DataTest, QueuePushAndPopFromSameThread) {
   638   ASSERT_EQ(queue.
pop(), 1);
   639   ASSERT_EQ(queue.
pop(), 2);
   642 TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
   645       queue.
pop(10 * kMillisecond),
   646       "Timeout in DataLoader queue while waiting for next batch "   647       "(timeout was 10 ms)");
   650 TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
   658         std::async(std::launch::async, [&queue] { 
return queue.pop(); });
   659     ASSERT_EQ(future.get(), 1);
   665     std::thread thread([&queue] {
   666       std::this_thread::sleep_for(20 * kMillisecond);
   669     ASSERT_EQ(queue.pop(), 123);
   674 TEST(DataTest, QueueClearEmptiesTheQueue) {
   679   ASSERT_EQ(queue.
clear(), 3);
   680   ASSERT_THROWS_WITH(queue.
pop(1 * kMillisecond), 
"Timeout");
   683 TEST(DataTest, DataShuttleCanPushAndPopJob) {
   687   ASSERT_EQ(shuttle.
pop_job(), 1);
   688   ASSERT_EQ(shuttle.
pop_job(), 2);
   691 TEST(DataTest, DataShuttleCanPushAndPopResult) {
   706 TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
   708   ASSERT_FALSE(shuttle.
pop_result().has_value());
   713   ASSERT_FALSE(shuttle.
pop_result().has_value());
   714   ASSERT_FALSE(shuttle.
pop_result().has_value());
   717 TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
   722   ASSERT_FALSE(shuttle.
pop_result().has_value());
   725 TEST(DataTest, DataShuttlePopResultTimesOut) {
   728   ASSERT_THROWS_WITH(shuttle.
pop_result(10 * kMillisecond), 
"Timeout");
   740   int get(
size_t index) 
override {
   748 TEST(DataTest, SharedBatchDatasetReallyIsShared) {
   754   auto shared_dataset =
   755       torch::data::datasets::make_shared_dataset<UncopyableDataset>(
   758   auto data_loader = torch::data::make_data_loader(
   761   for (
auto batch : *data_loader) {
   766 TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
   768   auto shared_dataset =
   769       torch::data::datasets::make_shared_dataset<UncopyableDataset>(
   771   ASSERT_EQ(shared_dataset.size().value(), 100);
   775   explicit TestIndex(
size_t offset, std::vector<size_t> index)
   776       : offset(offset), index(std::move(index)) {}
   781   std::vector<size_t> index;
   787     std::iota(data.begin(), data.end(), size_t(0));
   790     std::vector<int> batch;
   791     for (
auto i : index.index) {
   792       batch.push_back(index.offset + data.at(i));
   799   std::vector<int> data;
   806     if (index_ >= size_) {
   807       return torch::nullopt;
   809     std::vector<size_t> indices(batch_size);
   810     std::iota(indices.begin(), indices.end(), size_t(0));
   811     index_ += batch_size;
   812     return TestIndex(batch_size, std::move(indices));
   820 TEST(DataTest, CanUseCustomTypeAsIndexType) {
   821   const int kBatchSize = 10;
   822   auto data_loader = torch::data::make_data_loader(
   826   for (
auto batch : *data_loader) {
   827     for (
int j = 0; j < kBatchSize; ++j) {
   828       ASSERT_EQ(batch.at(j), 10 + j);
   834 TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
   835   size_t sample_count = 10;
   838   std::vector<size_t> res;
   840   while ((idx = drs.
next(3)).has_value()) {
   841     res.insert(std::end(res), std::begin(*idx), std::end(*idx));
   844   ASSERT_EQ(res.size(), sample_count);
   846   std::sort(res.begin(), res.end());
   847   for (
size_t i = 0; i < res.size(); ++i) {
   848     ASSERT_EQ(res[i], i);
   852 TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
   853   size_t sample_count = 10;
   854   size_t num_replicas = 3;
   856   auto test_function = [&](
bool allow_duplicates,
   857                            size_t local_sample_count,
   858                            std::vector<size_t>& output,
   860     std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;
   862     for (
size_t i = 0; i < num_replicas; ++i) {
   863       samplers.emplace_back(
   864           torch::make_unique<samplers::DistributedRandomSampler>(
   865               sample_count, num_replicas, i, allow_duplicates));
   868     std::vector<size_t> res;
   869     for (
size_t i = 0; i < num_replicas; ++i) {
   870       (*samplers[i]).reset();
   872       while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
   873         res.insert(std::end(res), std::begin(*idx), std::end(*idx));
   875       ASSERT_EQ(res.size(), local_sample_count * (i + 1));
   877     std::sort(res.begin(), res.end());
   878     ASSERT_EQ(res, output);
   881   for (
size_t batch_size = 1; batch_size <= 3; ++batch_size) {
   882     size_t local_sample_count =
   883         static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
   884     std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
   885     test_function(
true, local_sample_count, output1, batch_size);
   888         static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
   889     std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
   890     test_function(
false, local_sample_count, output2, batch_size);
   894 TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
   897     ASSERT_EQ(a.
index(), 0);
   898     std::stringstream stream;
   899     torch::save(a, stream);
   902     torch::load(b, stream);
   903     ASSERT_EQ(b.
index(), 0);
   909     ASSERT_EQ(a.
index(), 7);
   910     std::stringstream stream;
   911     torch::save(a, stream);
   914     torch::load(b, stream);
   915     ASSERT_EQ(b.
index(), 7);
   920     std::stringstream stream;
   921     torch::save(a, stream);
   924     torch::load(b, stream);
   925     ASSERT_EQ(b.epoch(), 3);
   929 TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
   930   size_t sample_count = 10;
   931   size_t batch_size = 3;
   934   std::vector<size_t> res;
   936   while ((idx = dss.
next(batch_size)).has_value()) {
   937     res.insert(std::end(res), std::begin(*idx), std::end(*idx));
   940   ASSERT_EQ(res.size(), sample_count);
   942   std::sort(res.begin(), res.end());
   943   for (
size_t i = 0; i < res.size(); ++i) {
   944     ASSERT_EQ(res[i], i);
   948 TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
   949   size_t sample_count = 10;
   950   size_t num_replicas = 3;
   952   auto test_function = [&](
bool allow_duplicates,
   953                            size_t local_sample_count,
   954                            std::vector<size_t>& output,
   956     std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
   959     for (
size_t i = 0; i < num_replicas; ++i) {
   960       samplers.emplace_back(
   961           torch::make_unique<samplers::DistributedSequentialSampler>(
   962               sample_count, num_replicas, i, allow_duplicates));
   965     std::vector<size_t> res;
   966     for (
size_t i = 0; i < num_replicas; ++i) {
   967       (*samplers[i]).reset();
   969       while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
   970         res.insert(std::end(res), std::begin(*idx), std::end(*idx));
   972       ASSERT_EQ(res.size(), local_sample_count * (i + 1));
   974     std::sort(res.begin(), res.end());
   975     ASSERT_EQ(res, output);
   978   for (
size_t batch_size = 1; batch_size <= 3; ++batch_size) {
   979     size_t local_sample_count =
   980         static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
   981     std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
   982     test_function(
true, local_sample_count, output1, batch_size);
   985         static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
   986     std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
   987     test_function(
false, local_sample_count, output2, batch_size);
   991 TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
   994     ASSERT_EQ(a.
index(), 0);
   995     std::stringstream stream;
   996     torch::save(a, stream);
   999     torch::load(b, stream);
  1000     ASSERT_EQ(b.
index(), 0);
  1006     ASSERT_EQ(a.
index(), 7);
  1007     std::stringstream stream;
  1008     torch::save(a, stream);
  1011     torch::load(b, stream);
  1012     ASSERT_EQ(b.
index(), 7);
  1016 TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
  1019   ASSERT_EQ(full_options.batch_size, 1);
  1020   ASSERT_FALSE(full_options.drop_last);
  1021   ASSERT_EQ(full_options.workers, 0);
  1022   ASSERT_EQ(full_options.max_jobs, 0);
  1023   ASSERT_FALSE(full_options.timeout.has_value());
  1024   ASSERT_TRUE(full_options.enforce_ordering);
  1027 TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) {
  1030   ASSERT_EQ(full_options.batch_size, 32);
  1031   ASSERT_EQ(full_options.max_jobs, 2 * 10);
  1034 TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) {
  1035   auto data_loader = torch::data::make_data_loader(
  1037   ASSERT_EQ(data_loader->options().batch_size, 1);
  1042     return {torch::ones(i), torch::ones(i)};
  1045     return torch::nullopt;
  1051     MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) {
  1054       "Expected the dataset to be sized in order to construct the Sampler");
  1057 TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) {
  1058   auto data_loader = torch::data::make_data_loader(
DummyDataset(), 32);
  1059   auto begin = data_loader->begin();
  1060   ASSERT_EQ(begin, begin);
  1061   auto end = data_loader->end();
  1062   ASSERT_EQ(end, end);
  1065 TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) {
  1066   auto data_loader = torch::data::make_data_loader(
DummyDataset(), 32);
  1067   auto i = data_loader->begin();
  1068   auto j = data_loader->begin();
  1074 TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) {
  1075   auto data_loader = torch::data::make_data_loader(
DummyDataset(), 32);
  1076   auto i = data_loader->end();
  1077   auto j = data_loader->end();
  1081 TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) {
  1084       torch::data::make_data_loader(dataset, dataset.
size().value() / 4);
  1085   auto i = data_loader->begin();
  1086   auto end = data_loader->end();
  1098 TEST(DataLoaderTest, IteratorsShareState) {
  1101       torch::data::make_data_loader(dataset, dataset.
size().value() / 2);
  1102   auto i = data_loader->begin();
  1104   auto end = data_loader->end();
  1115 TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) {
  1118       torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
  1121   auto iterator = data_loader->begin();
  1122   std::vector<int> expected = {1};
  1123   ASSERT_EQ(*iterator, expected);
  1124   ASSERT_EQ(*iterator, expected);
  1127   ASSERT_EQ(*iterator, expected);
  1128   ASSERT_EQ(*iterator, expected);
  1131   ASSERT_EQ(*iterator, expected);
  1132   ASSERT_EQ(*iterator, expected);
  1135 TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
  1138       return 1 + indices.
front();
  1147       torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
  1149   std::vector<int> values;
  1151       data_loader->begin(), data_loader->end(), std::back_inserter(values));
  1152   std::vector<int> expected(dataset.size().value());
  1153   std::iota(expected.begin(), expected.end(), size_t(1));
  1154   ASSERT_EQ(values, expected);
  1157 TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) {
  1161   auto i = data_loader->begin();
  1163       data_loader->begin(),
  1164       "Attempted to get a new DataLoader iterator "  1165       "while another iterator is not yet exhausted");
  1168 TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) {
  1171       torch::data::make_data_loader(dataset, dataset.
size().value());
  1172   auto i = data_loader->begin();
  1173   ASSERT_NO_THROW(++i);
  1174   ASSERT_THROWS_WITH(++i, 
"Attempted to increment iterator past the end");
  1177 TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) {
  1180       torch::data::make_data_loader(dataset, dataset.
size().value());
  1181   auto i = data_loader->begin();
  1182   ASSERT_NO_THROW(++i);
  1184       *i, 
"Attempted to dereference iterator that was past the end");
  1187 TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) {
  1190       torch::data::make_data_loader(dataset, dataset.
size().value());
  1191   auto i = data_loader->end();
  1194       "Incrementing the DataLoader's past-the-end iterator is not allowed");
  1197 TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) {
  1200       torch::data::make_data_loader(dataset, dataset.
size().value());
  1201   auto i = data_loader->end();
  1204       "Dereferencing the DataLoader's past-the-end iterator is not allowed");
  1207 TEST(DataLoaderTest, YieldsCorrectBatchSize) {
  1209   auto data_loader = torch::data::make_data_loader(dataset, 25);
  1210   auto iterator = data_loader->begin();
  1211   ASSERT_EQ(iterator->size(), 25);
  1212   ASSERT_EQ((++iterator)->size(), 25);
  1213   ASSERT_EQ((++iterator)->size(), 25);
  1214   ASSERT_EQ((++iterator)->size(), 25);
  1215   ASSERT_EQ(++iterator, data_loader->end());
  1220     ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) {
  1222   auto data_loader = torch::data::make_data_loader(
  1224   auto iterator = data_loader->begin();
  1225   ASSERT_EQ(iterator->size(), 33);
  1226   ASSERT_EQ((++iterator)->size(), 33);
  1227   ASSERT_EQ((++iterator)->size(), 33);
  1228   ASSERT_EQ((++iterator)->size(), 1);
  1229   ASSERT_EQ(++iterator, data_loader->end());
  1234     DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) {
  1236   auto data_loader = torch::data::make_data_loader(
  1238   auto iterator = data_loader->begin();
  1239   ASSERT_EQ(iterator->size(), 33);
  1240   ASSERT_EQ((++iterator)->size(), 33);
  1241   ASSERT_EQ((++iterator)->size(), 33);
  1242   ASSERT_EQ(++iterator, data_loader->end());
  1245 TEST(DataLoaderTest, RespectsTimeout) {
  1247     std::condition_variable cv;
  1252     D(std::shared_ptr<Baton> b) : baton(std::move(b)) {}
  1253     int get(
size_t index) 
override {
  1254       std::unique_lock<std::mutex> lock(baton->mutex);
  1255       baton->cv.wait_for(lock, 1000 * kMillisecond);
  1261     std::shared_ptr<Baton> baton;
  1264   auto baton = std::make_shared<Baton>();
  1266   auto data_loader = torch::data::make_data_loader(
  1269   auto start = std::chrono::system_clock::now();
  1271   ASSERT_THROWS_WITH(*data_loader->begin(), 
"Timeout");
  1272   baton->cv.notify_one();
  1274   auto end = std::chrono::system_clock::now();
  1275   auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
  1276   ASSERT_LT(duration.count(), 1);
  1281   explicit Barrier(
size_t target) : counter_(target) {}
  1283     std::unique_lock<std::mutex> lock(mutex_);
  1284     if (--counter_ == 0) {
  1287       cv_.wait(lock, [
this] { 
return this->counter_ == 0; });
  1292   std::condition_variable cv_;
  1339 const size_t kNumberOfWorkers = 10;
  1340 const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
  1341     {3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
  1350     static std::atomic<size_t> counter{0};
  1351     thread_id_ = counter.fetch_add(1);
  1359     static Barrier barrier(kNumberOfWorkers);
  1360     static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
  1361     static std::condition_variable cv;
  1362     static std::mutex mutex;
  1367     std::unique_lock<std::mutex> lock(mutex);
  1368     cv.wait(lock, [
this] { 
return *order_iterator == this->thread_id_; });
  1373     return indices.
front();
  1377     return kNumberOfWorkers;
  1380   size_t thread_id_ = 0;
  1385 TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
  1386   auto data_loader = torch::data::make_data_loader(
  1391           .workers(ordering_test::kNumberOfWorkers)
  1392           .enforce_ordering(
true));
  1393   std::vector<size_t> output;
  1394   for (
size_t value : *data_loader) {
  1395     output.push_back(value);
  1397   std::vector<size_t> expected(ordering_test::kNumberOfWorkers);
  1398   std::iota(expected.begin(), expected.end(), size_t(0));
  1399   ASSERT_EQ(expected, output);
  1402 TEST(DataLoaderTest, Reset) {
  1405       torch::data::make_data_loader(dataset, dataset.
size().value() / 2);
  1406   auto end = data_loader->end();
  1408   auto iterator = data_loader->begin();
  1409   ASSERT_NE(iterator, end);
  1410   ASSERT_NE(++iterator, end);
  1411   ASSERT_EQ(++iterator, end);
  1413   iterator = data_loader->begin();
  1414   ASSERT_NE(iterator, end);
  1415   ASSERT_NE(++iterator, end);
  1416   ASSERT_EQ(++iterator, end);
  1418   iterator = data_loader->begin();
  1419   ASSERT_NE(iterator, end);
  1420   ASSERT_NE(++iterator, end);
  1421   ASSERT_EQ(++iterator, end);
  1424 TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
  1426     int get(
size_t index) 
override {
  1427       throw std::invalid_argument(
"badness");
  1434   auto data_loader = torch::data::make_data_loader(
  1436   auto iterator = data_loader->begin();
  1443         std::string(
"Caught exception in DataLoader worker thread. "  1444                     "Original message: badness"));
  1450 TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
  1451   const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  1455       if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
  1458       return torch::nullopt;
  1463     void reset()
 override {
  1469   auto data_loader = torch::data::make_data_loader(
D{});
  1471   for (
size_t i = 0; i < 10; ++i) {
  1472     const auto number_of_iterations =
  1473         std::distance(data_loader->begin(), data_loader->end());
  1475         number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
  1479   for (
const int i : *data_loader) {
  1480     ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  1484 TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
  1485   const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  1486   const int kNumberOfWorkers = 4;
  1490       std::lock_guard<std::mutex> lock(mutex);
  1491       if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
  1494       return torch::nullopt;
  1499     void reset()
 override {
  1506   auto data_loader = torch::data::make_data_loader(
  1507       torch::data::datasets::make_shared_dataset<D>(),
  1510   for (
size_t i = 0; i < 10; ++i) {
  1511     const auto number_of_iterations =
  1512         std::distance(data_loader->begin(), data_loader->end());
  1514         number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
  1518   for (
const int i : *data_loader) {
  1519     ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  1523 TEST(DataLoaderTest, StatefulDatasetWithMap) {
  1524   const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  1528       if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
  1531       return torch::nullopt;
  1536     void reset()
 override {
  1542   auto data_loader = torch::data::make_data_loader(
  1544                   [](
int x) { 
return std::to_string(x); }))
  1546               [](
const std::string& x) {
  1547                 return torch::tensor(static_cast<int64_t>(std::stoi(x)));
  1551   for (
size_t i = 0; i < 10; ++i) {
  1552     const auto number_of_iterations =
  1553         std::distance(data_loader->begin(), data_loader->end());
  1555         number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
  1560     ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
  1564 TEST(DataLoaderTest, StatefulDatasetWithCollate) {
  1565   const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  1569         size_t batch_size)
 override {
  1570       if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
  1571         counter += batch_size;
  1572         std::vector<Example<>> batch(
  1575                       torch::zeros(batch_size - 1)});
  1578       return torch::nullopt;
  1583     void reset()
 override {
  1591   const size_t kBatchSize = 5;
  1596   ASSERT_TRUE(batch.has_value());
  1597   ASSERT_EQ(batch->data.size(0), kBatchSize);
  1598   ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
  1599   ASSERT_EQ(batch->target.size(0), kBatchSize);
  1600   ASSERT_EQ(batch->target.size(1), kBatchSize - 1);
  1602   ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
  1603   ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
  1610 TEST(DataLoaderTest, ChunkDataSetGetBatch) {
  1612   const size_t prefetch_counts[] = {1, 2, 3, 4};
  1615   const size_t batch_sizes[] = {5, 7};
  1618   const size_t dataloader_worker_counts[] = {0, 2};
  1620   const size_t total_example_count = 35;
  1625   const int epoch_count = 2;
  1627   for (
auto prefetch_count : prefetch_counts) {
  1628     for (
auto batch_size : batch_sizes) {
  1629       for (
auto dataloader_worker_count : dataloader_worker_counts) {
  1633             samplers::SequentialSampler>>
  1635                 DummyChunkDataReader,
  1636                 samplers::SequentialSampler,
  1637                 samplers::SequentialSampler>>(
  1643         auto data_loader = torch::data::make_data_loader(
  1647         for (
int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
  1648           std::vector<bool> result(total_example_count, 
false);
  1649           int iteration_count = 0;
  1650           for (
auto iterator = data_loader->begin();
  1651                iterator != data_loader->end();
  1652                ++iterator, ++iteration_count) {
  1653             std::vector<int>& batch = *iterator;
  1654             ASSERT_EQ(batch.size(), batch_size);
  1658             if (prefetch_count == 1 && dataloader_worker_count == 0) {
  1659               for (
size_t j = 0; j < batch_size; ++j) {
  1660                 ASSERT_EQ(batch[j], iteration_count * batch_size + j);
  1663             for (
size_t j = 0; j < batch_size; ++j) {
  1664               result[batch[j]] = 
true;
  1668           for (
auto data : result) {
  1669             ASSERT_EQ(data, 
true);
  1677 TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
  1678   const size_t prefetch_count = 1;
  1679   const size_t batch_size = 5;
  1680   const size_t requested_batch_size = 6;
  1688       samplers::SequentialSampler>>
  1690           DummyChunkDataReader,
  1691           samplers::SequentialSampler,
  1692           samplers::SequentialSampler>>(
  1698   auto data_loader = torch::data::make_data_loader(
  1702   std::string exception_msg =
  1703       "The requested batch size does not match with the initialized batch "  1704       "size.\n The requested batch size is 6, while the dataset is created"  1705       " with batch size equal to 5";
  1707   ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg);
  1710 TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
  1711   struct DummyEmptyChunkDataReader
  1714     using BatchType = std::vector<int>;
  1716     BatchType read_chunk(
size_t chunk_index)
 override {
  1720     size_t chunk_count()
 override {
  1724     void reset()
 override{};
  1727   const size_t prefetch_count = 1;
  1728   const size_t batch_size = 5;
  1729   DummyEmptyChunkDataReader data_reader;
  1733       DummyEmptyChunkDataReader,
  1735       samplers::SequentialSampler>>
  1737           DummyEmptyChunkDataReader,
  1738           samplers::SequentialSampler,
  1739           samplers::SequentialSampler>>(
  1745   auto data_loader = torch::data::make_data_loader(
  1748   for (
auto iterator = data_loader->begin(); iterator != data_loader->end();
  1750     ASSERT_EQ(iterator->size(), 0);
  1754 TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
  1757     using BatchType = std::vector<int>;
  1759     BatchType read_chunk(
size_t chunk_index)
 override {
  1760       BatchType batch_data(10, 0);
  1764     size_t chunk_count()
 override {
  1768     void reset()
 override{};
  1771   const size_t batch_sizes[] = {17, 30};
  1775   for (
auto batch_size : batch_sizes) {
  1779         samplers::SequentialSampler>>
  1782             samplers::SequentialSampler,
  1783             samplers::SequentialSampler>>(
  1789     auto data_loader = torch::data::make_data_loader(
  1792     for (
auto iterator = data_loader->begin(); iterator != data_loader->end();
  1794       std::vector<int> batch = *iterator;
  1795       auto batch_size = batch.size();
  1796       if (batch_size == 17) {
  1797         ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
  1799       if (batch_size == 30) {
  1800         ASSERT_TRUE(batch.size() == 20);
  1806 TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
  1807   const size_t prefetch_count = 2;
  1808   const size_t batch_size = 5;
  1815       samplers::SequentialSampler>>
  1817           DummyChunkDataReader,
  1818           samplers::SequentialSampler,
  1819           samplers::SequentialSampler>>(
  1825   samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
  1827   auto data_loader = torch::data::make_data_loader(
  1829           [](std::vector<int> batch) {
  1830             return std::accumulate(batch.begin(), batch.end(), 0);
  1835   ASSERT_EQ(chunk_sampler.index(), 0);
  1838   for (
auto iterator = data_loader->begin(); iterator != data_loader->end();
  1842   ASSERT_EQ(sum, 595); 
  1844   ASSERT_EQ(chunk_sampler.index(), 3);
  1847 TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
  1848   const size_t prefetch_count = 2;
  1849   const size_t batch_size = 5;
  1851   const size_t cache_size = 10;
  1858       samplers::SequentialSampler>>
  1860           DummyChunkDataReader,
  1861           samplers::SequentialSampler,
  1862           samplers::SequentialSampler>>(
  1867               prefetch_count, batch_size, cache_size));
  1869   samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
  1871   auto data_loader = torch::data::make_data_loader(
  1873           [](std::vector<int> batch) {
  1874             return std::accumulate(batch.begin(), batch.end(), 0);
  1880   auto iterator = data_loader->begin();
 size_t next_sequence_number_
The monotonically increasing sequence number we expect. 
 
Interface for chunk reader, which performs data chunking and reading of entire chunks. 
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
void save(torch::serialize::OutputArchive &archive) const override
Serializes the Sampler to the archive. 
 
A stateful dataset that supports hierarchical sampling and prefetching of entire chunks. 
 
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the RandomSampler to a new set of indices. 
 
void load(torch::serialize::InputArchive &archive) override
Deserializes the Sampler from the archive. 
 
TORCH_API size_t index() const noexcept
Returns the current index of the DistributedSequentialSampler. 
 
TensorExample get(size_t index) override
Returns a single TensorExample. 
 
void reset() override
This clears any internal state associated with this reader. 
 
AT_CPP14_CONSTEXPR const T & front() const 
front - Gets the first element. 
 
MapDataset< Self, TransformType > map(TransformType transform)&
Creates a MapDataset that applies the given transform to this dataset. 
 
BatchType read_chunk(size_t chunk_index) override
Read an entire chunk. 
 
TORCH_API size_t index() const noexcept
Returns the current index of the DistributedRandomSampler. 
 
TORCH_API optional< BatchSize > next(size_t batch_size) override
Returns a BatchSize object with the number of elements to fetch in the next batch. 
 
Like DataLoaderOptions, but without any unconfigured state. 
 
Job pop_job()
Returns the next job, blocking until there is one available. 
 
An Example from a dataset. 
 
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices. 
 
A Sequencer that does not enforce any ordering. 
 
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices. 
 
Options to configure a ChunkDataset. 
 
std::vector< ExampleType > get_batch(ArrayRef< size_t > indices) override
Returns a batch of data. 
 
void set_epoch(size_t epoch)
Sets the epoch for the current enumeration. 
 
A dataset that can yield data only in batches. 
 
torch::Tensor operator()(torch::Tensor input) override
Transforms a single input tensor to an output tensor. 
 
optional< Result > next(ResultProducer next_result) override
Buffers results until the next one in the expected order is received. 
 
size_t chunk_count() override
Returns the number of chunks available in this reader. 
 
size_t clear()
Empties the queue and returns the number of elements that were present at the start of the function...
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
torch::optional< size_t > size() const noexcept
Returns the size of the dataset, or an empty optional if it is unsized. 
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
A Sequencer that buffers results and returns them in order of their sequence number. 
 
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices. 
 
Options to configure a DataLoader. 
 
A Sampler that returns random indices. 
 
std::vector< int > get_batch(TestIndex index) override
Returns a batch of data given an index. 
 
A Sampler is an object that yields an index with which to access a dataset. 
 
optional< Result > pop_result(optional< std::chrono::milliseconds > timeout=nullopt)
Returns the result of a job, or nullopt if all jobs were exhausted. 
 
A dataset that can yield data in batches, or as individual examples. 
 
void push_job(Job job)
Pushes a new job. Called by the main thread. 
 
Encapsulates the full life cycle of DataLoader jobs. 
 
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the internal state of the sampler. 
 
A stateful dataset is a dataset that maintains some internal state, which will be reset() at the begi...
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
A dataset that wraps another dataset in a shared pointer and implements the BatchDataset API...
 
A basic locked, blocking MPMC queue. 
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
An exception thrown when a DataLoader's worker thread throws an exception, which is caught...
 
void push(T value)
Pushes a new value to the back of the Queue and notifies one thread on the waiting side about this ev...
 
Select samples sequentially. 
 
void push_result(Result result)
Pushes the result of a job. Called by worker threads. 
 
ArrayRef - Represents a constant reference to an array (0 or more elements consecutively in memory)...
 
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the SequentialSampler to zero. 
 
A Sampler that returns indices sequentially. 
 
TORCH_API size_t index() const noexcept
Returns the current index of the RandomSampler. 
 
A sampler for (potentially infinite) streams of data. 
 
std::exception_ptr original_exception
The original exception thrown in the worker thread. 
 
std::vector< optional< Result > > buffer_
A fixed-size buffer (after construction). 
 
size_t size() const override
The number of elements accessed by this index. 
 
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized. 
 
T pop(optional< std::chrono::milliseconds > timeout=nullopt)
Blocks until at least one element is ready to be popped from the front of the queue. 
 
A base class for custom index types. 
 
void drain()
Discards any jobs that are not yet in flight, and waits for all in-flight jobs to finish...
 
torch::optional< TestIndex > next(size_t batch_size) override
Returns the next index if possible, or an empty optional if the sampler is exhausted for this epoch...
 
TORCH_API size_t index() const noexcept
Returns the current index of the SequentialSampler. 
 
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices. 
 
void reset(torch::optional< size_t > new_size=torch::nullopt) override
Resets the Sampler's internal state.