1 #include <gtest/gtest.h> 3 #include <torch/data.h> 4 #include <torch/data/detail/sequencers.h> 5 #include <torch/serialize.h> 6 #include <torch/types.h> 8 #include <test/cpp/api/support.h> 10 #include <c10/util/ArrayRef.h> 23 #include <unordered_set> 28 const std::chrono::milliseconds kMillisecond(1);
31 explicit DummyDataset(
size_t size = 100) : size_(size) {}
33 int get(
size_t index)
override {
43 TEST(DataTest, DatasetCallsGetCorrectly) {
45 std::vector<int> batch = d.
get_batch({0, 1, 2, 3, 4});
46 std::vector<int> expected = {1, 2, 3, 4, 5};
47 ASSERT_EQ(batch, expected);
50 TEST(DataTest, TransformCallsGetApplyCorrectly) {
52 std::string apply(
int input)
override {
53 return std::to_string(input);
58 std::vector<std::string> batch = d.get_batch({0, 1, 2, 3, 4});
59 std::vector<std::string> expected = {
"1",
"2",
"3",
"4",
"5"};
60 ASSERT_EQ(batch, expected);
68 using BatchType = std::vector<int>;
73 int start_index = chunk_index == 0
75 : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0);
77 batch_data.resize(chunk_sizes[chunk_index]);
79 std::iota(batch_data.begin(), batch_data.end(), start_index);
90 const static size_t chunk_count_ = 3;
91 size_t chunk_sizes[chunk_count_] = {10, 5, 20};
94 TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
98 auto initialization_function =
99 [&](
size_t preloader_count,
size_t batch_size,
size_t cache_size) {
103 samplers::SequentialSampler>>
105 DummyChunkDataReader,
106 samplers::SequentialSampler,
107 samplers::SequentialSampler>>(
112 preloader_count, batch_size, cache_size));
116 initialization_function(0, 1, 1),
117 "Preloader count is 0. At least one preloader needs to be specified.");
120 initialization_function(1, 0, 1),
121 "Batch size is 0. A positive batch size needs to be specified.");
124 initialization_function(1, 1, 0),
125 "Cache size is 0. A positive cache size needs to be specified.");
128 initialization_function(1, 10, 5),
129 "Cache size is less than batch size. Cache needs to be large enough to " 130 "hold at least one batch.");
135 std::vector<int> get_batch(
size_t batch_size)
override {
136 std::vector<int> batch(batch_size);
137 for (
auto& i : batch) {
144 return torch::nullopt;
151 const size_t kBatchSize = 13;
156 auto data_loader = torch::data::make_data_loader(
161 size_t batch_index = 0;
162 for (
auto& batch : *data_loader) {
163 ASSERT_LT(batch_index, 3);
164 ASSERT_EQ(batch.size(), kBatchSize);
165 for (
size_t j = 0; j < kBatchSize; ++j) {
166 ASSERT_EQ(batch.at(j), 1 + (batch_index * kBatchSize) + j);
170 ASSERT_EQ(batch_index, 3);
173 TEST(DataTest, NoSequencerIsIdentity) {
176 const auto value = no_sequencer.next([] {
return 5; }).value();
180 TEST(DataTest, OrderedSequencerIsSetUpWell) {
183 size_t sequence_number;
185 const size_t kMaxJobs = 5;
188 ASSERT_EQ(sequencer.
buffer_.size(), kMaxJobs);
191 TEST(DataTest, OrderedSequencerReOrdersValues) {
194 size_t sequence_number;
196 const size_t kMaxJobs = 5;
199 std::vector<size_t> v = {0, 2, 4, 3, 1};
201 auto getter = [&v, &index]() {
return S{v.at(index++)}; };
205 const auto batch = sequencer.
next(getter);
206 ASSERT_EQ(batch.value().sequence_number, 0);
210 ASSERT_EQ(1, sequencer.
next(getter).value().sequence_number);
214 for (
size_t i = 2; i <= 4; ++i) {
216 ASSERT_EQ(i, sequencer.
next(getter).value().sequence_number);
222 TEST(DataTest, BatchLambdaAppliesFunctionToBatch) {
223 using InputBatch = std::vector<int>;
224 using OutputBatch = std::string;
227 [](std::vector<int> input) {
228 return std::to_string(std::accumulate(input.begin(), input.end(), 0));
230 ASSERT_EQ(e.get_batch({1, 2, 3, 4, 5}), std::string(
"20"));
233 TEST(DataTest, LambdaAppliesFunctionToExample) {
235 static_cast<std::string (*)(
int)
>(std::to_string)));
236 std::vector<std::string> expected = {
"1",
"2",
"3",
"4",
"5"};
237 ASSERT_EQ(d.get_batch({0, 1, 2, 3, 4}), expected);
240 TEST(DataTest, CollateReducesBatch) {
243 return std::accumulate(input.begin(), input.end(), 0);
245 ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
248 TEST(DataTest, CollationReducesBatch) {
250 int apply_batch(std::vector<int> input)
override {
251 return std::accumulate(input.begin(), input.end(), 0);
255 ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
258 TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
260 ASSERT_EQ(sampler.
next(3).value(), std::vector<size_t>({0, 1, 2}));
261 ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({3, 4, 5, 6, 7}));
262 ASSERT_EQ(sampler.
next(2).value(), std::vector<size_t>({8, 9}));
263 ASSERT_FALSE(sampler.
next(2).has_value());
266 TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
268 ASSERT_EQ(sampler.
next(3).value(), std::vector<size_t>({0, 1, 2}));
269 ASSERT_EQ(sampler.
next(100).value(), std::vector<size_t>({3, 4}));
270 ASSERT_FALSE(sampler.
next(2).has_value());
273 TEST(DataTest, SequentialSamplerResetsWell) {
275 ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
276 ASSERT_FALSE(sampler.
next(2).has_value());
278 ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
279 ASSERT_FALSE(sampler.
next(2).has_value());
282 TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
284 ASSERT_EQ(sampler.
next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
285 ASSERT_FALSE(sampler.
next(2).has_value());
288 sampler.
next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
289 ASSERT_FALSE(sampler.
next(2).has_value());
291 ASSERT_EQ(sampler.
next(3).value(), std::vector<size_t>({0, 1, 2}));
292 ASSERT_FALSE(sampler.
next(2).has_value());
295 TEST(DataTest, CanSaveAndLoadSequentialSampler) {
298 ASSERT_EQ(a.
index(), 0);
299 std::stringstream stream;
300 torch::save(a, stream);
303 torch::load(b, stream);
304 ASSERT_EQ(b.
index(), 0);
310 ASSERT_EQ(a.
index(), 7);
311 std::stringstream stream;
312 torch::save(a, stream);
315 torch::load(b, stream);
316 ASSERT_EQ(b.
index(), 7);
320 TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) {
323 std::vector<size_t> indices = sampler.
next(3).value();
324 for (
auto i : indices) {
329 indices = sampler.
next(5).value();
330 for (
auto i : indices) {
335 indices = sampler.
next(2).value();
336 for (
auto i : indices) {
341 ASSERT_FALSE(sampler.
next(10).has_value());
344 TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
346 ASSERT_EQ(sampler.
next(3).value().size(), 3);
347 ASSERT_EQ(sampler.
next(100).value().size(), 2);
348 ASSERT_FALSE(sampler.
next(2).has_value());
351 TEST(DataTest, RandomSamplerResetsWell) {
353 ASSERT_EQ(sampler.
next(5).value().size(), 5);
354 ASSERT_FALSE(sampler.
next(2).has_value());
356 ASSERT_EQ(sampler.
next(5).value().size(), 5);
357 ASSERT_FALSE(sampler.
next(2).has_value());
360 TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
362 ASSERT_EQ(sampler.
next(5).value().size(), 5);
363 ASSERT_FALSE(sampler.
next(2).has_value());
365 ASSERT_EQ(sampler.
next(7).value().size(), 7);
366 ASSERT_FALSE(sampler.
next(2).has_value());
368 ASSERT_EQ(sampler.
next(3).value().size(), 3);
369 ASSERT_FALSE(sampler.
next(2).has_value());
372 TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
376 std::stringstream stream;
377 torch::save(a, stream);
380 torch::load(b, stream);
382 ASSERT_EQ(a.
next(10).value(), b.
next(10).value());
387 ASSERT_EQ(a.
index(), 3);
389 std::stringstream stream;
390 torch::save(a, stream);
393 torch::load(b, stream);
394 ASSERT_EQ(b.
index(), 3);
396 auto b_sequence = b.
next(10).value();
397 ASSERT_EQ(b_sequence.size(), 7);
398 ASSERT_EQ(a.
next(10).value(), b_sequence);
402 TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
404 ASSERT_EQ(sampler.
next(10).value(), 10);
405 ASSERT_EQ(sampler.
next(2).value(), 2);
406 ASSERT_EQ(sampler.
next(85).value(), 85);
407 ASSERT_EQ(sampler.
next(123).value(), 3);
408 ASSERT_FALSE(sampler.
next(1).has_value());
411 TEST(DataTest, StreamSamplerResetsWell) {
413 ASSERT_EQ(sampler.
next(5).value().size(), 5);
414 ASSERT_FALSE(sampler.
next(2).has_value());
416 ASSERT_EQ(sampler.
next(5).value().size(), 5);
417 ASSERT_FALSE(sampler.
next(2).has_value());
420 TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
422 ASSERT_EQ(sampler.
next(5).value().size(), 5);
423 ASSERT_FALSE(sampler.
next(2).has_value());
425 ASSERT_EQ(sampler.
next(7).value().size(), 7);
426 ASSERT_FALSE(sampler.
next(2).has_value());
428 ASSERT_EQ(sampler.
next(3).value().size(), 3);
429 ASSERT_FALSE(sampler.
next(2).has_value());
432 TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
435 torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.
get(2)));
438 TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) {
439 std::vector<torch::Tensor> vector = torch::eye(5).chunk(5);
442 torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.
get(2)));
445 TEST(DataTest, StackTransformWorksForExample) {
448 return {tensor[index], 1 + tensor[index]};
452 return tensor.size(0);
461 ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(0, 0, 2)));
462 ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(0, 0, 2)));
465 ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(0, 2, 4)));
466 ASSERT_TRUE(second.target.allclose(1 + torch::eye(4).slice(0, 2, 4)));
469 TEST(DataTest, StackTransformWorksForTensorExample) {
474 ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(0, 0, 2)));
477 ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(0, 2, 4)));
481 template <
typename Target>
490 Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
492 return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
500 TEST(DataTest, TensorTransformWorksForAnyTargetType) {
502 std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});
504 ASSERT_EQ(batch.size(), 2);
505 ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
506 ASSERT_EQ(batch[0].target,
"1");
508 ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
509 ASSERT_EQ(batch[1].target,
"2");
512 TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
515 std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});
517 ASSERT_EQ(batch.size(), 2);
518 ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
519 ASSERT_EQ(batch[0].target,
"1");
521 ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
522 ASSERT_EQ(batch[1].target,
"2");
528 const auto channels =
static_cast<int64_t
>(index);
530 (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
531 return {tensor,
static_cast<int>(channels)};
539 TEST(DataTest, NormalizeTransform) {
543 std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
544 ASSERT_EQ(output.size(), 1);
546 ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
550 output = dataset.get_batch(1);
551 ASSERT_EQ(output.size(), 1);
552 ASSERT_EQ(output[0].data.size(0), 1);
553 ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
559 output = dataset.get_batch(2);
560 ASSERT_EQ(output.size(), 1);
561 ASSERT_EQ(output[0].data.size(0), 2);
562 ASSERT_TRUE(output[0]
564 .allclose(torch::ones({1, 4, 4}) * 5))
566 ASSERT_TRUE(output[0]
568 .allclose(torch::ones({1, 4, 4}) * -2.5))
573 output = dataset.get_batch(3);
574 ASSERT_EQ(output.size(), 1);
575 ASSERT_EQ(output[0].data.size(0), 3);
576 ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
582 output = dataset.get_batch(3);
583 ASSERT_EQ(output.size(), 1);
584 ASSERT_EQ(output[0].data.size(0), 3);
585 ASSERT_TRUE(output[0]
587 .allclose(torch::ones({1, 4, 4}) * 5))
589 ASSERT_TRUE(output[0]
591 .allclose(torch::ones({1, 4, 4}) * -2.5))
593 ASSERT_TRUE(output[0]
595 .allclose(torch::ones({1, 4, 4}) * 12.5))
611 return {torch::tensor(static_cast<int64_t>(index)),
612 torch::tensor(static_cast<int64_t>(index))};
620 TEST(DataTest, MapDoesNotCopy) {
629 auto data = dataset.get_batch(1).at(0).data;
630 ASSERT_EQ(data.numel(), 1);
631 ASSERT_EQ(data[0].item<float>(), 7);
634 TEST(DataTest, QueuePushAndPopFromSameThread) {
638 ASSERT_EQ(queue.
pop(), 1);
639 ASSERT_EQ(queue.
pop(), 2);
642 TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
645 queue.
pop(10 * kMillisecond),
646 "Timeout in DataLoader queue while waiting for next batch " 647 "(timeout was 10 ms)");
650 TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
658 std::async(std::launch::async, [&queue] {
return queue.pop(); });
659 ASSERT_EQ(future.get(), 1);
665 std::thread thread([&queue] {
666 std::this_thread::sleep_for(20 * kMillisecond);
669 ASSERT_EQ(queue.pop(), 123);
674 TEST(DataTest, QueueClearEmptiesTheQueue) {
679 ASSERT_EQ(queue.
clear(), 3);
680 ASSERT_THROWS_WITH(queue.
pop(1 * kMillisecond),
"Timeout");
683 TEST(DataTest, DataShuttleCanPushAndPopJob) {
687 ASSERT_EQ(shuttle.
pop_job(), 1);
688 ASSERT_EQ(shuttle.
pop_job(), 2);
691 TEST(DataTest, DataShuttleCanPushAndPopResult) {
706 TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
708 ASSERT_FALSE(shuttle.
pop_result().has_value());
713 ASSERT_FALSE(shuttle.
pop_result().has_value());
714 ASSERT_FALSE(shuttle.
pop_result().has_value());
717 TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
722 ASSERT_FALSE(shuttle.
pop_result().has_value());
725 TEST(DataTest, DataShuttlePopResultTimesOut) {
728 ASSERT_THROWS_WITH(shuttle.
pop_result(10 * kMillisecond),
"Timeout");
740 int get(
size_t index)
override {
748 TEST(DataTest, SharedBatchDatasetReallyIsShared) {
754 auto shared_dataset =
755 torch::data::datasets::make_shared_dataset<UncopyableDataset>(
758 auto data_loader = torch::data::make_data_loader(
761 for (
auto batch : *data_loader) {
766 TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
768 auto shared_dataset =
769 torch::data::datasets::make_shared_dataset<UncopyableDataset>(
771 ASSERT_EQ(shared_dataset.size().value(), 100);
775 explicit TestIndex(
size_t offset, std::vector<size_t> index)
776 : offset(offset), index(std::move(index)) {}
781 std::vector<size_t> index;
787 std::iota(data.begin(), data.end(), size_t(0));
790 std::vector<int> batch;
791 for (
auto i : index.index) {
792 batch.push_back(index.offset + data.at(i));
799 std::vector<int> data;
806 if (index_ >= size_) {
807 return torch::nullopt;
809 std::vector<size_t> indices(batch_size);
810 std::iota(indices.begin(), indices.end(), size_t(0));
811 index_ += batch_size;
812 return TestIndex(batch_size, std::move(indices));
820 TEST(DataTest, CanUseCustomTypeAsIndexType) {
821 const int kBatchSize = 10;
822 auto data_loader = torch::data::make_data_loader(
826 for (
auto batch : *data_loader) {
827 for (
int j = 0; j < kBatchSize; ++j) {
828 ASSERT_EQ(batch.at(j), 10 + j);
834 TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
835 size_t sample_count = 10;
838 std::vector<size_t> res;
840 while ((idx = drs.
next(3)).has_value()) {
841 res.insert(std::end(res), std::begin(*idx), std::end(*idx));
844 ASSERT_EQ(res.size(), sample_count);
846 std::sort(res.begin(), res.end());
847 for (
size_t i = 0; i < res.size(); ++i) {
848 ASSERT_EQ(res[i], i);
852 TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
853 size_t sample_count = 10;
854 size_t num_replicas = 3;
856 auto test_function = [&](
bool allow_duplicates,
857 size_t local_sample_count,
858 std::vector<size_t>& output,
860 std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;
862 for (
size_t i = 0; i < num_replicas; ++i) {
863 samplers.emplace_back(
864 torch::make_unique<samplers::DistributedRandomSampler>(
865 sample_count, num_replicas, i, allow_duplicates));
868 std::vector<size_t> res;
869 for (
size_t i = 0; i < num_replicas; ++i) {
870 (*samplers[i]).reset();
872 while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
873 res.insert(std::end(res), std::begin(*idx), std::end(*idx));
875 ASSERT_EQ(res.size(), local_sample_count * (i + 1));
877 std::sort(res.begin(), res.end());
878 ASSERT_EQ(res, output);
881 for (
size_t batch_size = 1; batch_size <= 3; ++batch_size) {
882 size_t local_sample_count =
883 static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
884 std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
885 test_function(
true, local_sample_count, output1, batch_size);
888 static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
889 std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
890 test_function(
false, local_sample_count, output2, batch_size);
894 TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
897 ASSERT_EQ(a.
index(), 0);
898 std::stringstream stream;
899 torch::save(a, stream);
902 torch::load(b, stream);
903 ASSERT_EQ(b.
index(), 0);
909 ASSERT_EQ(a.
index(), 7);
910 std::stringstream stream;
911 torch::save(a, stream);
914 torch::load(b, stream);
915 ASSERT_EQ(b.
index(), 7);
920 std::stringstream stream;
921 torch::save(a, stream);
924 torch::load(b, stream);
925 ASSERT_EQ(b.epoch(), 3);
929 TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
930 size_t sample_count = 10;
931 size_t batch_size = 3;
934 std::vector<size_t> res;
936 while ((idx = dss.
next(batch_size)).has_value()) {
937 res.insert(std::end(res), std::begin(*idx), std::end(*idx));
940 ASSERT_EQ(res.size(), sample_count);
942 std::sort(res.begin(), res.end());
943 for (
size_t i = 0; i < res.size(); ++i) {
944 ASSERT_EQ(res[i], i);
948 TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
949 size_t sample_count = 10;
950 size_t num_replicas = 3;
952 auto test_function = [&](
bool allow_duplicates,
953 size_t local_sample_count,
954 std::vector<size_t>& output,
956 std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
959 for (
size_t i = 0; i < num_replicas; ++i) {
960 samplers.emplace_back(
961 torch::make_unique<samplers::DistributedSequentialSampler>(
962 sample_count, num_replicas, i, allow_duplicates));
965 std::vector<size_t> res;
966 for (
size_t i = 0; i < num_replicas; ++i) {
967 (*samplers[i]).reset();
969 while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
970 res.insert(std::end(res), std::begin(*idx), std::end(*idx));
972 ASSERT_EQ(res.size(), local_sample_count * (i + 1));
974 std::sort(res.begin(), res.end());
975 ASSERT_EQ(res, output);
978 for (
size_t batch_size = 1; batch_size <= 3; ++batch_size) {
979 size_t local_sample_count =
980 static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
981 std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
982 test_function(
true, local_sample_count, output1, batch_size);
985 static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
986 std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
987 test_function(
false, local_sample_count, output2, batch_size);
991 TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
994 ASSERT_EQ(a.
index(), 0);
995 std::stringstream stream;
996 torch::save(a, stream);
999 torch::load(b, stream);
1000 ASSERT_EQ(b.
index(), 0);
1006 ASSERT_EQ(a.
index(), 7);
1007 std::stringstream stream;
1008 torch::save(a, stream);
1011 torch::load(b, stream);
1012 ASSERT_EQ(b.
index(), 7);
1016 TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
1019 ASSERT_EQ(full_options.batch_size, 1);
1020 ASSERT_FALSE(full_options.drop_last);
1021 ASSERT_EQ(full_options.workers, 0);
1022 ASSERT_EQ(full_options.max_jobs, 0);
1023 ASSERT_FALSE(full_options.timeout.has_value());
1024 ASSERT_TRUE(full_options.enforce_ordering);
1027 TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) {
1030 ASSERT_EQ(full_options.batch_size, 32);
1031 ASSERT_EQ(full_options.max_jobs, 2 * 10);
1034 TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) {
1035 auto data_loader = torch::data::make_data_loader(
1037 ASSERT_EQ(data_loader->options().batch_size, 1);
1042 return {torch::ones(i), torch::ones(i)};
1045 return torch::nullopt;
1051 MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) {
1054 "Expected the dataset to be sized in order to construct the Sampler");
1057 TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) {
1058 auto data_loader = torch::data::make_data_loader(
DummyDataset(), 32);
1059 auto begin = data_loader->begin();
1060 ASSERT_EQ(begin, begin);
1061 auto end = data_loader->end();
1062 ASSERT_EQ(end, end);
1065 TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) {
1066 auto data_loader = torch::data::make_data_loader(
DummyDataset(), 32);
1067 auto i = data_loader->begin();
1068 auto j = data_loader->begin();
1074 TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) {
1075 auto data_loader = torch::data::make_data_loader(
DummyDataset(), 32);
1076 auto i = data_loader->end();
1077 auto j = data_loader->end();
1081 TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) {
1084 torch::data::make_data_loader(dataset, dataset.
size().value() / 4);
1085 auto i = data_loader->begin();
1086 auto end = data_loader->end();
1098 TEST(DataLoaderTest, IteratorsShareState) {
1101 torch::data::make_data_loader(dataset, dataset.
size().value() / 2);
1102 auto i = data_loader->begin();
1104 auto end = data_loader->end();
1115 TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) {
1118 torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
1121 auto iterator = data_loader->begin();
1122 std::vector<int> expected = {1};
1123 ASSERT_EQ(*iterator, expected);
1124 ASSERT_EQ(*iterator, expected);
1127 ASSERT_EQ(*iterator, expected);
1128 ASSERT_EQ(*iterator, expected);
1131 ASSERT_EQ(*iterator, expected);
1132 ASSERT_EQ(*iterator, expected);
1135 TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
1138 return 1 + indices.
front();
1147 torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
1149 std::vector<int> values;
1151 data_loader->begin(), data_loader->end(), std::back_inserter(values));
1152 std::vector<int> expected(dataset.size().value());
1153 std::iota(expected.begin(), expected.end(), size_t(1));
1154 ASSERT_EQ(values, expected);
1157 TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) {
1161 auto i = data_loader->begin();
1163 data_loader->begin(),
1164 "Attempted to get a new DataLoader iterator " 1165 "while another iterator is not yet exhausted");
1168 TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) {
1171 torch::data::make_data_loader(dataset, dataset.
size().value());
1172 auto i = data_loader->begin();
1173 ASSERT_NO_THROW(++i);
1174 ASSERT_THROWS_WITH(++i,
"Attempted to increment iterator past the end");
1177 TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) {
1180 torch::data::make_data_loader(dataset, dataset.
size().value());
1181 auto i = data_loader->begin();
1182 ASSERT_NO_THROW(++i);
1184 *i,
"Attempted to dereference iterator that was past the end");
1187 TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) {
1190 torch::data::make_data_loader(dataset, dataset.
size().value());
1191 auto i = data_loader->end();
1194 "Incrementing the DataLoader's past-the-end iterator is not allowed");
1197 TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) {
1200 torch::data::make_data_loader(dataset, dataset.
size().value());
1201 auto i = data_loader->end();
1204 "Dereferencing the DataLoader's past-the-end iterator is not allowed");
1207 TEST(DataLoaderTest, YieldsCorrectBatchSize) {
1209 auto data_loader = torch::data::make_data_loader(dataset, 25);
1210 auto iterator = data_loader->begin();
1211 ASSERT_EQ(iterator->size(), 25);
1212 ASSERT_EQ((++iterator)->size(), 25);
1213 ASSERT_EQ((++iterator)->size(), 25);
1214 ASSERT_EQ((++iterator)->size(), 25);
1215 ASSERT_EQ(++iterator, data_loader->end());
1220 ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) {
1222 auto data_loader = torch::data::make_data_loader(
1224 auto iterator = data_loader->begin();
1225 ASSERT_EQ(iterator->size(), 33);
1226 ASSERT_EQ((++iterator)->size(), 33);
1227 ASSERT_EQ((++iterator)->size(), 33);
1228 ASSERT_EQ((++iterator)->size(), 1);
1229 ASSERT_EQ(++iterator, data_loader->end());
1234 DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) {
1236 auto data_loader = torch::data::make_data_loader(
1238 auto iterator = data_loader->begin();
1239 ASSERT_EQ(iterator->size(), 33);
1240 ASSERT_EQ((++iterator)->size(), 33);
1241 ASSERT_EQ((++iterator)->size(), 33);
1242 ASSERT_EQ(++iterator, data_loader->end());
1245 TEST(DataLoaderTest, RespectsTimeout) {
1247 std::condition_variable cv;
1252 D(std::shared_ptr<Baton> b) : baton(std::move(b)) {}
1253 int get(
size_t index)
override {
1254 std::unique_lock<std::mutex> lock(baton->mutex);
1255 baton->cv.wait_for(lock, 1000 * kMillisecond);
1261 std::shared_ptr<Baton> baton;
1264 auto baton = std::make_shared<Baton>();
1266 auto data_loader = torch::data::make_data_loader(
1269 auto start = std::chrono::system_clock::now();
1271 ASSERT_THROWS_WITH(*data_loader->begin(),
"Timeout");
1272 baton->cv.notify_one();
1274 auto end = std::chrono::system_clock::now();
1275 auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
1276 ASSERT_LT(duration.count(), 1);
1281 explicit Barrier(
size_t target) : counter_(target) {}
1283 std::unique_lock<std::mutex> lock(mutex_);
1284 if (--counter_ == 0) {
1287 cv_.wait(lock, [
this] {
return this->counter_ == 0; });
1292 std::condition_variable cv_;
1339 const size_t kNumberOfWorkers = 10;
1340 const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
1341 {3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
1350 static std::atomic<size_t> counter{0};
1351 thread_id_ = counter.fetch_add(1);
1359 static Barrier barrier(kNumberOfWorkers);
1360 static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
1361 static std::condition_variable cv;
1362 static std::mutex mutex;
1367 std::unique_lock<std::mutex> lock(mutex);
1368 cv.wait(lock, [
this] {
return *order_iterator == this->thread_id_; });
1373 return indices.
front();
1377 return kNumberOfWorkers;
1380 size_t thread_id_ = 0;
1385 TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
1386 auto data_loader = torch::data::make_data_loader(
1391 .workers(ordering_test::kNumberOfWorkers)
1392 .enforce_ordering(
true));
1393 std::vector<size_t> output;
1394 for (
size_t value : *data_loader) {
1395 output.push_back(value);
1397 std::vector<size_t> expected(ordering_test::kNumberOfWorkers);
1398 std::iota(expected.begin(), expected.end(), size_t(0));
1399 ASSERT_EQ(expected, output);
1402 TEST(DataLoaderTest, Reset) {
1405 torch::data::make_data_loader(dataset, dataset.
size().value() / 2);
1406 auto end = data_loader->end();
1408 auto iterator = data_loader->begin();
1409 ASSERT_NE(iterator, end);
1410 ASSERT_NE(++iterator, end);
1411 ASSERT_EQ(++iterator, end);
1413 iterator = data_loader->begin();
1414 ASSERT_NE(iterator, end);
1415 ASSERT_NE(++iterator, end);
1416 ASSERT_EQ(++iterator, end);
1418 iterator = data_loader->begin();
1419 ASSERT_NE(iterator, end);
1420 ASSERT_NE(++iterator, end);
1421 ASSERT_EQ(++iterator, end);
1424 TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
1426 int get(
size_t index)
override {
1427 throw std::invalid_argument(
"badness");
1434 auto data_loader = torch::data::make_data_loader(
1436 auto iterator = data_loader->begin();
1443 std::string(
"Caught exception in DataLoader worker thread. " 1444 "Original message: badness"));
1450 TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
1451 const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
1455 if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
1458 return torch::nullopt;
1463 void reset()
override {
1469 auto data_loader = torch::data::make_data_loader(
D{});
1471 for (
size_t i = 0; i < 10; ++i) {
1472 const auto number_of_iterations =
1473 std::distance(data_loader->begin(), data_loader->end());
1475 number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
1479 for (
const int i : *data_loader) {
1480 ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
1484 TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
1485 const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
1486 const int kNumberOfWorkers = 4;
1490 std::lock_guard<std::mutex> lock(mutex);
1491 if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
1494 return torch::nullopt;
1499 void reset()
override {
1506 auto data_loader = torch::data::make_data_loader(
1507 torch::data::datasets::make_shared_dataset<D>(),
1510 for (
size_t i = 0; i < 10; ++i) {
1511 const auto number_of_iterations =
1512 std::distance(data_loader->begin(), data_loader->end());
1514 number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
1518 for (
const int i : *data_loader) {
1519 ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
1523 TEST(DataLoaderTest, StatefulDatasetWithMap) {
1524 const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
1528 if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
1531 return torch::nullopt;
1536 void reset()
override {
1542 auto data_loader = torch::data::make_data_loader(
1544 [](
int x) {
return std::to_string(x); }))
1546 [](
const std::string& x) {
1547 return torch::tensor(static_cast<int64_t>(std::stoi(x)));
1551 for (
size_t i = 0; i < 10; ++i) {
1552 const auto number_of_iterations =
1553 std::distance(data_loader->begin(), data_loader->end());
1555 number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
1560 ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
1564 TEST(DataLoaderTest, StatefulDatasetWithCollate) {
1565 const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
1569 size_t batch_size)
override {
1570 if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
1571 counter += batch_size;
1572 std::vector<Example<>> batch(
1575 torch::zeros(batch_size - 1)});
1578 return torch::nullopt;
1583 void reset()
override {
1591 const size_t kBatchSize = 5;
1596 ASSERT_TRUE(batch.has_value());
1597 ASSERT_EQ(batch->data.size(0), kBatchSize);
1598 ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
1599 ASSERT_EQ(batch->target.size(0), kBatchSize);
1600 ASSERT_EQ(batch->target.size(1), kBatchSize - 1);
1602 ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
1603 ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
1610 TEST(DataLoaderTest, ChunkDataSetGetBatch) {
1612 const size_t prefetch_counts[] = {1, 2, 3, 4};
1615 const size_t batch_sizes[] = {5, 7};
1618 const size_t dataloader_worker_counts[] = {0, 2};
1620 const size_t total_example_count = 35;
1625 const int epoch_count = 2;
1627 for (
auto prefetch_count : prefetch_counts) {
1628 for (
auto batch_size : batch_sizes) {
1629 for (
auto dataloader_worker_count : dataloader_worker_counts) {
1633 samplers::SequentialSampler>>
1635 DummyChunkDataReader,
1636 samplers::SequentialSampler,
1637 samplers::SequentialSampler>>(
1643 auto data_loader = torch::data::make_data_loader(
1647 for (
int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
1648 std::vector<bool> result(total_example_count,
false);
1649 int iteration_count = 0;
1650 for (
auto iterator = data_loader->begin();
1651 iterator != data_loader->end();
1652 ++iterator, ++iteration_count) {
1653 std::vector<int>& batch = *iterator;
1654 ASSERT_EQ(batch.size(), batch_size);
1658 if (prefetch_count == 1 && dataloader_worker_count == 0) {
1659 for (
size_t j = 0; j < batch_size; ++j) {
1660 ASSERT_EQ(batch[j], iteration_count * batch_size + j);
1663 for (
size_t j = 0; j < batch_size; ++j) {
1664 result[batch[j]] =
true;
1668 for (
auto data : result) {
1669 ASSERT_EQ(data,
true);
1677 TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
1678 const size_t prefetch_count = 1;
1679 const size_t batch_size = 5;
1680 const size_t requested_batch_size = 6;
1688 samplers::SequentialSampler>>
1690 DummyChunkDataReader,
1691 samplers::SequentialSampler,
1692 samplers::SequentialSampler>>(
1698 auto data_loader = torch::data::make_data_loader(
1702 std::string exception_msg =
1703 "The requested batch size does not match with the initialized batch " 1704 "size.\n The requested batch size is 6, while the dataset is created" 1705 " with batch size equal to 5";
1707 ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg);
1710 TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
1711 struct DummyEmptyChunkDataReader
1714 using BatchType = std::vector<int>;
1716 BatchType read_chunk(
size_t chunk_index)
override {
1720 size_t chunk_count()
override {
1724 void reset()
override{};
1727 const size_t prefetch_count = 1;
1728 const size_t batch_size = 5;
1729 DummyEmptyChunkDataReader data_reader;
1733 DummyEmptyChunkDataReader,
1735 samplers::SequentialSampler>>
1737 DummyEmptyChunkDataReader,
1738 samplers::SequentialSampler,
1739 samplers::SequentialSampler>>(
1745 auto data_loader = torch::data::make_data_loader(
1748 for (
auto iterator = data_loader->begin(); iterator != data_loader->end();
1750 ASSERT_EQ(iterator->size(), 0);
1754 TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
1757 using BatchType = std::vector<int>;
1759 BatchType read_chunk(
size_t chunk_index)
override {
1760 BatchType batch_data(10, 0);
1764 size_t chunk_count()
override {
1768 void reset()
override{};
1771 const size_t batch_sizes[] = {17, 30};
1775 for (
auto batch_size : batch_sizes) {
1779 samplers::SequentialSampler>>
1782 samplers::SequentialSampler,
1783 samplers::SequentialSampler>>(
1789 auto data_loader = torch::data::make_data_loader(
1792 for (
auto iterator = data_loader->begin(); iterator != data_loader->end();
1794 std::vector<int> batch = *iterator;
1795 auto batch_size = batch.size();
1796 if (batch_size == 17) {
1797 ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
1799 if (batch_size == 30) {
1800 ASSERT_TRUE(batch.size() == 20);
1806 TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
1807 const size_t prefetch_count = 2;
1808 const size_t batch_size = 5;
1815 samplers::SequentialSampler>>
1817 DummyChunkDataReader,
1818 samplers::SequentialSampler,
1819 samplers::SequentialSampler>>(
1825 samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
1827 auto data_loader = torch::data::make_data_loader(
1829 [](std::vector<int> batch) {
1830 return std::accumulate(batch.begin(), batch.end(), 0);
1835 ASSERT_EQ(chunk_sampler.index(), 0);
1838 for (
auto iterator = data_loader->begin(); iterator != data_loader->end();
1842 ASSERT_EQ(sum, 595);
1844 ASSERT_EQ(chunk_sampler.index(), 3);
1847 TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
1848 const size_t prefetch_count = 2;
1849 const size_t batch_size = 5;
1851 const size_t cache_size = 10;
1858 samplers::SequentialSampler>>
1860 DummyChunkDataReader,
1861 samplers::SequentialSampler,
1862 samplers::SequentialSampler>>(
1867 prefetch_count, batch_size, cache_size));
1869 samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
1871 auto data_loader = torch::data::make_data_loader(
1873 [](std::vector<int> batch) {
1874 return std::accumulate(batch.begin(), batch.end(), 0);
1880 auto iterator = data_loader->begin();
size_t next_sequence_number_
The monotonically increasing sequence number we expect.
Interface for chunk reader, which performs data chunking and reading of entire chunks.
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
void save(torch::serialize::OutputArchive &archive) const override
Serializes the Sampler to the archive.
A stateful dataset that supports hierarchical sampling and prefetching of entire chunks.
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the RandomSampler to a new set of indices.
void load(torch::serialize::InputArchive &archive) override
Deserializes the Sampler from the archive.
TORCH_API size_t index() const noexcept
Returns the current index of the DistributedSequentialSampler.
TensorExample get(size_t index) override
Returns a single TensorExample.
void reset() override
This will clear any internal state associated with this reader.
AT_CPP14_CONSTEXPR const T & front() const
front - Get the first element.
MapDataset< Self, TransformType > map(TransformType transform)&
Creates a MapDataset that applies the given transform to this dataset.
BatchType read_chunk(size_t chunk_index) override
Read an entire chunk.
TORCH_API size_t index() const noexcept
Returns the current index of the DistributedRandomSampler.
TORCH_API optional< BatchSize > next(size_t batch_size) override
Returns a BatchSize object with the number of elements to fetch in the next batch.
Like DataLoaderOptions, but without any unconfigured state.
Job pop_job()
Returns the next job, blocking until there is one available.
An Example from a dataset.
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices.
A Sequencer that does not enforce any ordering.
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices.
Options to configure a ChunkDataset.
std::vector< ExampleType > get_batch(ArrayRef< size_t > indices) override
Returns a batch of data.
void set_epoch(size_t epoch)
Set the epoch for the current enumeration.
A dataset that can yield data only in batches.
torch::Tensor operator()(torch::Tensor input) override
Transforms a single input tensor to an output tensor.
optional< Result > next(ResultProducer next_result) override
Buffers results until the next one in the expected order is received.
size_t chunk_count() override
Returns the number of chunks available in this reader.
size_t clear()
Empties the queue and returns the number of elements that were present at the start of the function...
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
torch::optional< size_t > size() const noexcept
Returns the size of the dataset, or an empty optional if it is unsized.
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
A Sequencer that buffers results and returns them in order of their sequence number.
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices.
Options to configure a DataLoader.
A Sampler that returns random indices.
std::vector< int > get_batch(TestIndex index) override
Returns a batch of data given an index.
A Sampler is an object that yields an index with which to access a dataset.
optional< Result > pop_result(optional< std::chrono::milliseconds > timeout=nullopt)
Returns the result of a job, or nullopt if all jobs were exhausted.
A dataset that can yield data in batches, or as individual examples.
void push_job(Job job)
Pushes a new job. Called by the main thread.
Encapsulates the full life cycle of DataLoader jobs.
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the internal state of the sampler.
A stateful dataset is a dataset that maintains some internal state, which will be reset() at the begi...
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
A dataset that wraps another dataset in a shared pointer and implements the BatchDataset API...
A basic locked, blocking MPMC queue.
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
An exception thrown when a DataLoader's worker thread throws an exception, which is caught...
void push(T value)
Pushes a new value to the back of the Queue and notifies one thread on the waiting side about this ev...
Select samples sequentially.
void push_result(Result result)
Pushes the result of a job. Called by worker threads.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
TORCH_API void reset(optional< size_t > new_size=nullopt) override
Resets the SequentialSampler to zero.
A Sampler that returns indices sequentially.
TORCH_API size_t index() const noexcept
Returns the current index of the RandomSampler.
A sampler for (potentially infinite) streams of data.
std::exception_ptr original_exception
The original exception thrown in the worker thread.
std::vector< optional< Result > > buffer_
A fixed-size buffer (after construction).
size_t size() const override
The number of elements accessed by this index.
torch::optional< size_t > size() const override
Returns the size of the dataset, or an empty optional if it is unsized.
T pop(optional< std::chrono::milliseconds > timeout=nullopt)
Blocks until at least one element is ready to be popped from the front of the queue.
A base class for custom index types.
void drain()
Discards any jobs that are not yet in flight, and waits for all in-flight jobs to finish...
torch::optional< TestIndex > next(size_t batch_size) override
Returns the next index if possible, or an empty optional if the sampler is exhausted for this epoch...
TORCH_API size_t index() const noexcept
Returns the current index of the SequentialSampler.
TORCH_API optional< std::vector< size_t > > next(size_t batch_size) override
Returns the next batch of indices.
void reset(torch::optional< size_t > new_size=torch::nullopt) override
Resets the Sampler's internal state.