dataloader.cpp
#include <gtest/gtest.h>

#include <torch/data.h>
#include <torch/data/detail/sequencers.h>
#include <torch/serialize.h>
#include <torch/types.h>

#include <test/cpp/api/support.h>

#include <c10/util/ArrayRef.h>

#include <algorithm>
#include <chrono>
#include <future>
#include <iostream>
#include <iterator>
#include <limits>
#include <mutex>
#include <numeric>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

using namespace torch::data; // NOLINT

const std::chrono::milliseconds kMillisecond(1);

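// A trivial dataset of configurable size whose example at position i is
// i + 1. Used throughout the tests below.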
struct DummyDataset : datasets::Dataset<DummyDataset, int> {
  explicit DummyDataset(size_t size = 100) : size_(size) {}

  int get(size_t index) override {
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return size_;
  }

  size_t size_;
};

TEST(DataTest, DatasetCallsGetCorrectly) {
  DummyDataset d;
  std::vector<int> batch = d.get_batch({0, 1, 2, 3, 4});
  std::vector<int> expected = {1, 2, 3, 4, 5};
  ASSERT_EQ(batch, expected);
}

TEST(DataTest, TransformCallsGetApplyCorrectly) {
  struct T : transforms::Transform<int, std::string> {
    std::string apply(int input) override {
      return std::to_string(input);
    }
  };

  auto d = DummyDataset{}.map(T{});
  std::vector<std::string> batch = d.get_batch({0, 1, 2, 3, 4});
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(batch, expected);
}

// A dummy chunk data reader with 3 chunks and 35 examples in total. The
// chunks contain 10, 5, and 20 examples, respectively.
struct DummyChunkDataReader
    : public datasets::ChunkDataReader<std::vector<int>> {
 public:
  using BatchType = std::vector<int>;

  // Read an entire chunk.
  BatchType read_chunk(size_t chunk_index) override {
    BatchType batch_data;
    int start_index = chunk_index == 0
        ? 0
        : std::accumulate(chunk_sizes, chunk_sizes + chunk_index, 0);

    batch_data.resize(chunk_sizes[chunk_index]);

    std::iota(batch_data.begin(), batch_data.end(), start_index);

    return batch_data;
  }

  size_t chunk_count() override {
    return chunk_count_;
  };

  void reset() override{};

  const static size_t chunk_count_ = 3;
  size_t chunk_sizes[chunk_count_] = {10, 5, 20};
};

TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  auto initialization_function =
      [&](size_t preloader_count, size_t batch_size, size_t cache_size) {
        datasets::SharedBatchDataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>
            dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
                DummyChunkDataReader,
                samplers::SequentialSampler,
                samplers::SequentialSampler>>(
                data_reader,
                sampler,
                sampler,
                datasets::ChunkDatasetOptions(
                    preloader_count, batch_size, cache_size));
      };

  ASSERT_THROWS_WITH(
      initialization_function(0, 1, 1),
      "Preloader count is 0. At least one preloader needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 0, 1),
      "Batch size is 0. A positive batch size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 1, 0),
      "Cache size is 0. A positive cache size needs to be specified.");

  ASSERT_THROWS_WITH(
      initialization_function(1, 10, 5),
      "Cache size is less than batch size. Cache needs to be large enough to "
      "hold at least one batch.");
}

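// A size-less stream dataset that hands out consecutive integers forever.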
struct InfiniteStreamDataset
    : datasets::StreamDataset<InfiniteStreamDataset, std::vector<int>> {
  std::vector<int> get_batch(size_t batch_size) override {
    std::vector<int> batch(batch_size);
    for (auto& i : batch) {
      i = counter++;
    }
    return batch;
  }

  torch::optional<size_t> size() const override {
    return torch::nullopt;
  }

  size_t counter = 0;
};

TEST(DataTest, InfiniteStreamDataset) {
  const size_t kBatchSize = 13;

  auto dataset = InfiniteStreamDataset().map(
      transforms::Lambda<int>([](int x) { return x + 1; }));

  auto data_loader = torch::data::make_data_loader(
      std::move(dataset),
      samplers::StreamSampler(/*epoch_size=*/39),
      kBatchSize);

  size_t batch_index = 0;
  for (auto& batch : *data_loader) {
    ASSERT_LT(batch_index, 3);
    ASSERT_EQ(batch.size(), kBatchSize);
    for (size_t j = 0; j < kBatchSize; ++j) {
      ASSERT_EQ(batch.at(j), 1 + (batch_index * kBatchSize) + j);
    }
    batch_index += 1;
  }
  ASSERT_EQ(batch_index, 3);
}

TEST(DataTest, NoSequencerIsIdentity) {
  using namespace torch::data::detail::sequencers; // NOLINT
  NoSequencer<int> no_sequencer;
  const auto value = no_sequencer.next([] { return 5; }).value();
  ASSERT_EQ(value, 5);
}

TEST(DataTest, OrderedSequencerIsSetUpWell) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);
  ASSERT_EQ(sequencer.next_sequence_number_, 0);
  ASSERT_EQ(sequencer.buffer_.size(), kMaxJobs);
}

TEST(DataTest, OrderedSequencerReOrdersValues) {
  using namespace torch::data::detail::sequencers; // NOLINT
  struct S {
    size_t sequence_number;
  };
  const size_t kMaxJobs = 5;
  OrderedSequencer<S> sequencer(kMaxJobs);

  std::vector<size_t> v = {0, 2, 4, 3, 1};
  size_t index = 0;
  auto getter = [&v, &index]() { return S{v.at(index++)}; };

  // When the sequence number matches for the first batch, it should be
  // returned immediately.
  const auto batch = sequencer.next(getter);
  ASSERT_EQ(batch.value().sequence_number, 0);
  ASSERT_EQ(index, 1);

  // Now it should call the getter until it gets the next value.
  ASSERT_EQ(1, sequencer.next(getter).value().sequence_number);
  ASSERT_EQ(index, 5);

  // The next three should come in order.
  for (size_t i = 2; i <= 4; ++i) {
    // New value doesn't matter. In fact, it shouldn't be accessed.
    ASSERT_EQ(i, sequencer.next(getter).value().sequence_number);
    // The index doesn't change.
    ASSERT_EQ(index, 5);
  }
}

TEST(DataTest, BatchLambdaAppliesFunctionToBatch) {
  using InputBatch = std::vector<int>;
  using OutputBatch = std::string;
  DummyDataset d;
  auto e = d.map(transforms::BatchLambda<InputBatch, OutputBatch>(
      [](std::vector<int> input) {
        return std::to_string(std::accumulate(input.begin(), input.end(), 0));
      }));
  ASSERT_EQ(e.get_batch({1, 2, 3, 4, 5}), std::string("20"));
}

TEST(DataTest, LambdaAppliesFunctionToExample) {
  auto d = DummyDataset().map(transforms::Lambda<int, std::string>(
      static_cast<std::string (*)(int)>(std::to_string)));
  std::vector<std::string> expected = {"1", "2", "3", "4", "5"};
  ASSERT_EQ(d.get_batch({0, 1, 2, 3, 4}), expected);
}

TEST(DataTest, CollateReducesBatch) {
  auto d =
      DummyDataset().map(transforms::Collate<int>([](std::vector<int> input) {
        return std::accumulate(input.begin(), input.end(), 0);
      }));
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}

TEST(DataTest, CollationReducesBatch) {
  struct Summer : transforms::Collation<int> {
    int apply_batch(std::vector<int> input) override {
      return std::accumulate(input.begin(), input.end(), 0);
    }
  };
  auto d = DummyDataset().map(Summer{});
  ASSERT_EQ(d.get_batch({1, 2, 3, 4, 5}), 20);
}

TEST(DataTest, SequentialSamplerReturnsIndicesInOrder) {
  samplers::SequentialSampler sampler(10);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({3, 4, 5, 6, 7}));
  ASSERT_EQ(sampler.next(2).value(), std::vector<size_t>({8, 9}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerReturnsLessValuesForLastBatch) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_EQ(sampler.next(100).value(), std::vector<size_t>({3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerResetsWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SequentialSamplerResetsWithNewSizeWell) {
  samplers::SequentialSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value(), std::vector<size_t>({0, 1, 2, 3, 4}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(
      sampler.next(7).value(), std::vector<size_t>({0, 1, 2, 3, 4, 5, 6}));
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value(), std::vector<size_t>({0, 1, 2}));
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, CanSaveAndLoadSequentialSampler) {
  {
    samplers::SequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::SequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::SequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}

TEST(DataTest, RandomSamplerReturnsIndicesInCorrectRange) {
  samplers::RandomSampler sampler(10);

  std::vector<size_t> indices = sampler.next(3).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  indices = sampler.next(5).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  indices = sampler.next(2).value();
  for (auto i : indices) {
    ASSERT_GE(i, 0);
    ASSERT_LT(i, 10);
  }

  ASSERT_FALSE(sampler.next(10).has_value());
}

TEST(DataTest, RandomSamplerReturnsLessValuesForLastBatch) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_EQ(sampler.next(100).value().size(), 2);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, RandomSamplerResetsWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, RandomSamplerResetsWithNewSizeWell) {
  samplers::RandomSampler sampler(5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, SavingAndLoadingRandomSamplerYieldsSameSequence) {
  {
    samplers::RandomSampler a(10);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);

    ASSERT_EQ(a.next(10).value(), b.next(10).value());
  }
  {
    samplers::RandomSampler a(10);
    a.next(3);
    ASSERT_EQ(a.index(), 3);

    std::stringstream stream;
    torch::save(a, stream);

    samplers::RandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 3);

    auto b_sequence = b.next(10).value();
    ASSERT_EQ(b_sequence.size(), 7);
    ASSERT_EQ(a.next(10).value(), b_sequence);
  }
}

TEST(DataTest, StreamSamplerReturnsTheBatchSizeAndThenRemainder) {
  samplers::StreamSampler sampler(/*epoch_size=*/100);
  ASSERT_EQ(sampler.next(10).value(), 10);
  ASSERT_EQ(sampler.next(2).value(), 2);
  ASSERT_EQ(sampler.next(85).value(), 85);
  ASSERT_EQ(sampler.next(123).value(), 3);
  ASSERT_FALSE(sampler.next(1).has_value());
}

TEST(DataTest, StreamSamplerResetsWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset();
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, StreamSamplerResetsWithNewSizeWell) {
  samplers::StreamSampler sampler(/*epoch_size=*/5);
  ASSERT_EQ(sampler.next(5).value().size(), 5);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(7);
  ASSERT_EQ(sampler.next(7).value().size(), 7);
  ASSERT_FALSE(sampler.next(2).has_value());
  sampler.reset(3);
  ASSERT_EQ(sampler.next(3).value().size(), 3);
  ASSERT_FALSE(sampler.next(2).has_value());
}

TEST(DataTest, TensorDatasetConstructsFromSingleTensor) {
  datasets::TensorDataset dataset(torch::eye(5));
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}

TEST(DataTest, TensorDatasetConstructsFromInitializerListOfTensors) {
  std::vector<torch::Tensor> vector = torch::eye(5).chunk(5);
  datasets::TensorDataset dataset(vector);
  ASSERT_TRUE(
      torch::tensor({0, 0, 1, 0, 0}, torch::kFloat32).allclose(dataset.get(2)));
}

TEST(DataTest, StackTransformWorksForExample) {
  struct D : public datasets::Dataset<D> {
    Example<> get(size_t index) override {
      return {tensor[index], 1 + tensor[index]};
    }

    torch::optional<size_t> size() const override {
      return tensor.size(0);
    }

    torch::Tensor tensor{torch::eye(4)};
  };

  auto d = D().map(transforms::Stack<Example<>>());

  Example<> batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));
  ASSERT_TRUE(batch.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 0, 2)));

  Example<> second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
  ASSERT_TRUE(second.target.allclose(1 + torch::eye(4).slice(/*dim=*/0, 2, 4)));
}

TEST(DataTest, StackTransformWorksForTensorExample) {
  auto d = datasets::TensorDataset(torch::eye(4))
               .map(transforms::Stack<TensorExample>());

  TensorExample batch = d.get_batch({0, 1});
  ASSERT_TRUE(batch.data.allclose(torch::eye(4).slice(/*dim=*/0, 0, 2)));

  TensorExample second = d.get_batch({2, 3});
  ASSERT_TRUE(second.data.allclose(torch::eye(4).slice(/*dim=*/0, 2, 4)));
}

// Template classes cannot be nested in functions.
template <typename Target>
struct T : transforms::TensorTransform<Target> {
  torch::Tensor operator()(torch::Tensor input) override {
    return input * 2;
  }
};

struct TensorStringDataset
    : datasets::
          Dataset<TensorStringDataset, Example<torch::Tensor, std::string>> {
  Example<torch::Tensor, std::string> get(size_t index) override {
    return {torch::tensor(static_cast<double>(index)), std::to_string(index)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, TensorTransformWorksForAnyTargetType) {
  auto d = TensorStringDataset().map(T<std::string>{});
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}

TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
  auto d = TensorStringDataset().map(transforms::TensorLambda<std::string>(
      [](torch::Tensor input) { return input * 2; }));
  std::vector<Example<torch::Tensor, std::string>> batch = d.get_batch({1, 2});

  ASSERT_EQ(batch.size(), 2);
  ASSERT_TRUE(batch[0].data.allclose(torch::tensor(2.0)));
  ASSERT_EQ(batch[0].target, "1");

  ASSERT_TRUE(batch[1].data.allclose(torch::tensor(4.0)));
  ASSERT_EQ(batch[1].target, "2");
}

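// Yields an example with `index` channels (a plain 4x4 matrix for index 0),
// which the Normalize test below uses to exercise per-channel moments.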
struct DummyTensorDataset
    : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
  Example<torch::Tensor, int> get(size_t index) override {
    const auto channels = static_cast<int64_t>(index);
    torch::Tensor tensor =
        (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
    return {tensor, static_cast<int>(channels)};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, NormalizeTransform) {
  auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));

  // Works for zero (one implicit) channels
  std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
  ASSERT_EQ(output.size(), 1);
  // (1 - 0.5) / 0.1 = 5
  ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
      << output[0].data;

  // Works for one explicit channel
  output = dataset.get_batch(1);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 1);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;

  // Works for two channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
  output = dataset.get_batch(2);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 2);
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with one moment value
  dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
      << output[0].data;

  // Works for three channels with different moments
  dataset = DummyTensorDataset().map(
      transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
  output = dataset.get_batch(3);
  ASSERT_EQ(output.size(), 1);
  ASSERT_EQ(output[0].data.size(0), 3);
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
                  .allclose(torch::ones({1, 4, 4}) * 5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
                  .allclose(torch::ones({1, 4, 4}) * -2.5))
      << output[0].data;
  ASSERT_TRUE(output[0]
                  .data.slice(/*dim=*/0, /*start=*/2)
                  .allclose(torch::ones({1, 4, 4}) * 12.5))
      << output[0].data;
}

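// Movable but not copyable, to verify that map() moves the dataset rather
// than copying it.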
struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
  UnCopyableDataset() = default;

  UnCopyableDataset(const UnCopyableDataset&) = delete;
  UnCopyableDataset& operator=(const UnCopyableDataset&) = delete;

  UnCopyableDataset(UnCopyableDataset&&) = default;
  UnCopyableDataset& operator=(UnCopyableDataset&&) = default;

  ~UnCopyableDataset() = default;

  Example<> get(size_t index) override {
    return {torch::tensor(static_cast<int64_t>(index)),
            torch::tensor(static_cast<int64_t>(index))};
  }

  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, MapDoesNotCopy) {
  auto dataset = UnCopyableDataset()
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 1; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 2; }))
                     .map(transforms::TensorLambda<>(
                         [](torch::Tensor tensor) { return tensor + 3; }));

  auto data = dataset.get_batch(1).at(0).data;
  ASSERT_EQ(data.numel(), 1);
  ASSERT_EQ(data[0].item<float>(), 7);
}

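// Tests for the data loader's internal synchronized queue.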
TEST(DataTest, QueuePushAndPopFromSameThread) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  ASSERT_EQ(queue.pop(), 1);
  ASSERT_EQ(queue.pop(), 2);
}

TEST(DataTest, QueuePopWithTimeoutThrowsUponTimeout) {
  torch::data::detail::Queue<int> queue;
  ASSERT_THROWS_WITH(
      queue.pop(10 * kMillisecond),
      "Timeout in DataLoader queue while waiting for next batch "
      "(timeout was 10 ms)");
}

TEST(DataTest, QueuePushAndPopFromDifferentThreads) {
  using torch::data::detail::Queue;

  // First test: push batch and then pop in a separate thread.
  {
    Queue<int> queue;
    queue.push(1);
    auto future =
        std::async(std::launch::async, [&queue] { return queue.pop(); });
    ASSERT_EQ(future.get(), 1);
  }

  // Second test: attempt to pop batch (and block), then push.
  {
    Queue<int> queue;
    std::thread thread([&queue] {
      std::this_thread::sleep_for(20 * kMillisecond);
      queue.push(123);
    });
    ASSERT_EQ(queue.pop(), 123);
    thread.join();
  }
}

TEST(DataTest, QueueClearEmptiesTheQueue) {
  torch::data::detail::Queue<int> queue;
  queue.push(1);
  queue.push(2);
  queue.push(3);
  ASSERT_EQ(queue.clear(), 3);
  ASSERT_THROWS_WITH(queue.pop(1 * kMillisecond), "Timeout");
}

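// Tests for the DataShuttle, which moves jobs to worker threads and results
// back to the main thread.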
TEST(DataTest, DataShuttleCanPushAndPopJob) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_job(2);
  ASSERT_EQ(shuttle.pop_job(), 1);
  ASSERT_EQ(shuttle.pop_job(), 2);
}

TEST(DataTest, DataShuttleCanPushAndPopResult) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  // pop_result() will only attempt to pop if there was a push_job() before.
  shuttle.push_job(1);
  shuttle.push_job(2);

  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);

  shuttle.pop_job();
  shuttle.push_result(2);
  ASSERT_EQ(shuttle.pop_result().value(), 2);
}

TEST(DataTest, DataShuttlePopResultReturnsNulloptWhenNoJobsInFlight) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  ASSERT_FALSE(shuttle.pop_result().has_value());
  shuttle.push_job(1);
  shuttle.pop_job();
  shuttle.push_result(1);
  ASSERT_EQ(shuttle.pop_result().value(), 1);
  ASSERT_FALSE(shuttle.pop_result().has_value());
  ASSERT_FALSE(shuttle.pop_result().has_value());
}

TEST(DataTest, DataShuttleDrainMeansPopResultReturnsNullopt) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  shuttle.push_result(1);
  shuttle.drain();
  ASSERT_FALSE(shuttle.pop_result().has_value());
}

TEST(DataTest, DataShuttlePopResultTimesOut) {
  torch::data::detail::DataShuttle<int, int> shuttle;
  shuttle.push_job(1);
  ASSERT_THROWS_WITH(shuttle.pop_result(10 * kMillisecond), "Timeout");
}

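// Movable but not copyable, to verify that make_shared_dataset() does not
// copy the dataset it wraps.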
struct UncopyableDataset : datasets::Dataset<UncopyableDataset, int> {
  UncopyableDataset(const std::string& /* unused */) {}

  UncopyableDataset(UncopyableDataset&&) = default;
  UncopyableDataset& operator=(UncopyableDataset&&) = default;

  UncopyableDataset(const UncopyableDataset&) = delete;
  UncopyableDataset& operator=(const UncopyableDataset&) = delete;

  int get(size_t index) override {
    return 1 + index;
  }
  torch::optional<size_t> size() const override {
    return 100;
  }
};

TEST(DataTest, SharedBatchDatasetReallyIsShared) {
  // This test will only compile if we really are not making any copies.
  // There is otherwise no logic to test, and because it is not deterministic
  // how often and when worker threads access the shared dataset, we make no
  // additional assertions here.

  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          "uncopyable");

  auto data_loader = torch::data::make_data_loader(
      shared_dataset, torch::data::DataLoaderOptions().workers(3));

  for (auto batch : *data_loader) {
    /* exhaust */
  }
}

TEST(DataTest, SharedBatchDatasetDoesNotIncurCopyWhenPassedDatasetObject) {
  // This will not compile if a copy is made.
  auto shared_dataset =
      torch::data::datasets::make_shared_dataset<UncopyableDataset>(
          UncopyableDataset("uncopyable"));
  ASSERT_EQ(shared_dataset.size().value(), 100);
}

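// A custom batch request type plus a matching dataset and sampler, used to
// show that the data loader works with arbitrary index types.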
struct TestIndex : public torch::data::samplers::CustomBatchRequest {
  explicit TestIndex(size_t offset, std::vector<size_t> index)
      : offset(offset), index(std::move(index)) {}
  size_t size() const override {
    return index.size();
  }
  size_t offset;
  std::vector<size_t> index;
};

struct TestIndexDataset
    : datasets::BatchDataset<TestIndexDataset, std::vector<int>, TestIndex> {
  explicit TestIndexDataset(size_t size) : data(size) {
    std::iota(data.begin(), data.end(), size_t(0));
  }
  std::vector<int> get_batch(TestIndex index) override {
    std::vector<int> batch;
    for (auto i : index.index) {
      batch.push_back(index.offset + data.at(i));
    }
    return batch;
  }
  torch::optional<size_t> size() const override {
    return data.size();
  }
  std::vector<int> data;
};

struct TestIndexSampler : public samplers::Sampler<TestIndex> {
  explicit TestIndexSampler(size_t size) : size_(size) {}
  void reset(torch::optional<size_t> new_size = torch::nullopt) override {}
  torch::optional<TestIndex> next(size_t batch_size) override {
    if (index_ >= size_) {
      return torch::nullopt;
    }
    std::vector<size_t> indices(batch_size);
    std::iota(indices.begin(), indices.end(), size_t(0));
    index_ += batch_size;
    return TestIndex(batch_size, std::move(indices));
  }
  void save(torch::serialize::OutputArchive& archive) const override {}
  void load(torch::serialize::InputArchive& archive) override {}
  size_t index_ = 0;
  size_t size_;
};

TEST(DataTest, CanUseCustomTypeAsIndexType) {
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader(
      TestIndexDataset(23), TestIndexSampler(23), kBatchSize);

  size_t i = 0;
  for (auto batch : *data_loader) {
    for (int j = 0; j < kBatchSize; ++j) {
      ASSERT_EQ(batch.at(j), 10 + j);
    }
    i += 1;
  }
}

TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  samplers::DistributedRandomSampler drs(sample_count);

  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = drs.next(3)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  std::sort(res.begin(), res.end());
  for (size_t i = 0; i < res.size(); ++i) {
    ASSERT_EQ(res[i], i);
  }
}

TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;

    for (size_t i = 0; i < num_replicas; ++i) {
      samplers.emplace_back(
          torch::make_unique<samplers::DistributedRandomSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (size_t i = 0; i < num_replicas; ++i) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}

TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
  {
    samplers::DistributedRandomSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::DistributedRandomSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
  {
    samplers::DistributedRandomSampler a(10);
    a.set_epoch(3);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedRandomSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.epoch(), 3);
  }
}

TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t batch_size = 3;
  samplers::DistributedSequentialSampler dss(sample_count);

  std::vector<size_t> res;
  torch::optional<std::vector<size_t>> idx;
  while ((idx = dss.next(batch_size)).has_value()) {
    res.insert(std::end(res), std::begin(*idx), std::end(*idx));
  }

  ASSERT_EQ(res.size(), sample_count);

  std::sort(res.begin(), res.end());
  for (size_t i = 0; i < res.size(); ++i) {
    ASSERT_EQ(res[i], i);
  }
}

TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
  size_t sample_count = 10;
  size_t num_replicas = 3;

  auto test_function = [&](bool allow_duplicates,
                           size_t local_sample_count,
                           std::vector<size_t>& output,
                           size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
        samplers;

    for (size_t i = 0; i < num_replicas; ++i) {
      samplers.emplace_back(
          torch::make_unique<samplers::DistributedSequentialSampler>(
              sample_count, num_replicas, i, allow_duplicates));
    }

    std::vector<size_t> res;
    for (size_t i = 0; i < num_replicas; ++i) {
      (*samplers[i]).reset();
      torch::optional<std::vector<size_t>> idx;
      while ((idx = (*samplers[i]).next(batch_size)).has_value()) {
        res.insert(std::end(res), std::begin(*idx), std::end(*idx));
      }
      ASSERT_EQ(res.size(), local_sample_count * (i + 1));
    }
    std::sort(res.begin(), res.end());
    ASSERT_EQ(res, output);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    test_function(true, local_sample_count, output1, batch_size);

    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
    test_function(false, local_sample_count, output2, batch_size);
  }
}

TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
  {
    samplers::DistributedSequentialSampler a(10);
    ASSERT_EQ(a.index(), 0);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedSequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 0);
  }
  {
    samplers::DistributedSequentialSampler a(10);
    a.next(3);
    a.next(4);
    ASSERT_EQ(a.index(), 7);
    std::stringstream stream;
    torch::save(a, stream);

    samplers::DistributedSequentialSampler b(10);
    torch::load(b, stream);
    ASSERT_EQ(b.index(), 7);
  }
}

TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
  DataLoaderOptions partial_options;
  FullDataLoaderOptions full_options(partial_options);
  ASSERT_EQ(full_options.batch_size, 1);
  ASSERT_FALSE(full_options.drop_last);
  ASSERT_EQ(full_options.workers, 0);
  ASSERT_EQ(full_options.max_jobs, 0);
  ASSERT_FALSE(full_options.timeout.has_value());
  ASSERT_TRUE(full_options.enforce_ordering);
}

TEST(DataLoaderTest, DataLoaderOptionsCoalesceOptionalValues) {
  auto partial_options = DataLoaderOptions(32).workers(10);
  FullDataLoaderOptions full_options(partial_options);
  ASSERT_EQ(full_options.batch_size, 32);
  ASSERT_EQ(full_options.max_jobs, 2 * 10);
}

TEST(DataLoaderTest, MakeDataLoaderDefaultsAsExpected) {
  auto data_loader = torch::data::make_data_loader(
      DummyDataset().map(transforms::Lambda<int>([](int x) { return x + 1; })));
  ASSERT_EQ(data_loader->options().batch_size, 1);
}

struct UnsizedDataset : public datasets::Dataset<UnsizedDataset> {
  torch::data::Example<> get(size_t i) {
    return {torch::ones(i), torch::ones(i)};
  }
  torch::optional<size_t> size() const noexcept {
    return torch::nullopt;
  }
};

TEST(
    DataLoaderTest,
    MakeDataLoaderThrowsWhenConstructingSamplerWithUnsizedDataset) {
  ASSERT_THROWS_WITH(
      torch::data::make_data_loader(UnsizedDataset{}),
      "Expected the dataset to be sized in order to construct the Sampler");
}

TEST(DataLoaderTest, IteratorsCompareEqualToThemselves) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto begin = data_loader->begin();
  ASSERT_EQ(begin, begin);
  auto end = data_loader->end();
  ASSERT_EQ(end, end);
}

TEST(DataLoaderTest, ValidIteratorsCompareUnequalToEachOther) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto i = data_loader->begin();
  auto j = data_loader->begin();
  ASSERT_NE(i, j);
  ++j;
  ASSERT_NE(i, j);
}

TEST(DataLoaderTest, SentinelIteratorsCompareEqualToEachOther) {
  auto data_loader = torch::data::make_data_loader(DummyDataset(), 32);
  auto i = data_loader->end();
  auto j = data_loader->end();
  ASSERT_EQ(i, j);
}

TEST(DataLoaderTest, IteratorsCompareEqualToSentinelWhenExhausted) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 4);
  auto i = data_loader->begin();
  auto end = data_loader->end();
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_NE(i, end);
  ++i;
  ASSERT_EQ(i, end);
}

TEST(DataLoaderTest, IteratorsShareState) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto i = data_loader->begin();
  auto j = i;
  auto end = data_loader->end();
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  ++i;
  ASSERT_NE(i, end);
  ASSERT_NE(j, end);
  ++j;
  ASSERT_EQ(i, end);
  ASSERT_EQ(j, end);
}

TEST(DataLoaderTest, CanDereferenceIteratorMultipleTimes) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset,
          /*batch_size=*/1);
  auto iterator = data_loader->begin();
  std::vector<int> expected = {1};
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
  ++iterator;
  expected[0] = 2;
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
  ++iterator;
  expected[0] = 3;
  ASSERT_EQ(*iterator, expected);
  ASSERT_EQ(*iterator, expected);
}

TEST(DataLoaderTest, CanUseIteratorAlgorithms) {
  struct D : datasets::BatchDataset<D, int> {
    int get_batch(torch::ArrayRef<size_t> indices) override {
      return 1 + indices.front();
    }
    torch::optional<size_t> size() const override {
      return 10;
    }
  };

  D dataset;
  auto data_loader =
      torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
          dataset, 1);
  std::vector<int> values;
  std::copy(
      data_loader->begin(), data_loader->end(), std::back_inserter(values));
  std::vector<int> expected(dataset.size().value());
  std::iota(expected.begin(), expected.end(), size_t(1));
  ASSERT_EQ(values, expected);
}

TEST(DataLoaderTest, CallingBeginWhileOtherIteratorIsInFlightThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, DataLoaderOptions(1).workers(2));
  auto i = data_loader->begin();
  ASSERT_THROWS_WITH(
      data_loader->begin(),
      "Attempted to get a new DataLoader iterator "
      "while another iterator is not yet exhausted");
}

TEST(DataLoaderTest, IncrementingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->begin();
  ASSERT_NO_THROW(++i);
  ASSERT_THROWS_WITH(++i, "Attempted to increment iterator past the end");
}

TEST(DataLoaderTest, DereferencingExhaustedValidIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->begin();
  ASSERT_NO_THROW(++i);
  ASSERT_THROWS_WITH(
      *i, "Attempted to dereference iterator that was past the end");
}

TEST(DataLoaderTest, IncrementingSentinelIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->end();
  ASSERT_THROWS_WITH(
      ++i,
      "Incrementing the DataLoader's past-the-end iterator is not allowed");
}

TEST(DataLoaderTest, DereferencingSentinelIteratorThrows) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value());
  auto i = data_loader->end();
  ASSERT_THROWS_WITH(
      *i,
      "Dereferencing the DataLoader's past-the-end iterator is not allowed");
}

TEST(DataLoaderTest, YieldsCorrectBatchSize) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(dataset, 25);
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ((++iterator)->size(), 25);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(
    DataLoaderTest,
    ReturnsLastBatchWhenSmallerThanBatchSizeWhenDropLastIsFalse) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(false));
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 1);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(
    DataLoaderTest,
    DoesNotReturnLastBatchWhenSmallerThanBatchSizeWhenDropLastIsTrue) {
  DummyDataset dataset;
  auto data_loader = torch::data::make_data_loader(
      dataset, DataLoaderOptions(33).drop_last(true));
  auto iterator = data_loader->begin();
  ASSERT_EQ(iterator->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ((++iterator)->size(), 33);
  ASSERT_EQ(++iterator, data_loader->end());
}

TEST(DataLoaderTest, RespectsTimeout) {
  struct Baton {
    std::condition_variable cv;
    std::mutex mutex;
  };

  struct D : datasets::Dataset<DummyDataset, int> {
    D(std::shared_ptr<Baton> b) : baton(std::move(b)) {}
    int get(size_t index) override {
      std::unique_lock<std::mutex> lock(baton->mutex);
      baton->cv.wait_for(lock, 1000 * kMillisecond);
      return 0;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    std::shared_ptr<Baton> baton;
  };

  auto baton = std::make_shared<Baton>();

  auto data_loader = torch::data::make_data_loader(
      D{baton}, DataLoaderOptions().workers(1).timeout(10 * kMillisecond));

  auto start = std::chrono::system_clock::now();

  ASSERT_THROWS_WITH(*data_loader->begin(), "Timeout");
  baton->cv.notify_one();

  auto end = std::chrono::system_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
  ASSERT_LT(duration.count(), 1);
}

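// A simple one-shot barrier: wait() blocks until `target` threads have
// arrived.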
// stackoverflow.com/questions/24465533/implementing-boostbarrier-in-c11
struct Barrier {
  explicit Barrier(size_t target) : counter_(target) {}
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (--counter_ == 0) {
      cv_.notify_all();
    } else {
      cv_.wait(lock, [this] { return this->counter_ == 0; });
    }
  }

  size_t counter_;
  std::condition_variable cv_;
  std::mutex mutex_;
};

// On the OrderingTest: This test is intended to verify that the
// `enforce_ordering` option of the dataloader works correctly. The reason this
// flag exists is because when the dataloader has multiple workers (threads)
// enabled and this flag is not set, the order in which worker threads finish
// loading their respective batch and push it back to the dataloader's main
// thread (for outside consumption) is not deterministic. Imagine the sampler
// is a SequentialSampler with indices 0, 1, 2, 3. With batch size 1, each
// index will be a single "job". Inside the dataloader, worker threads block
// until a job is available. It is not deterministic which worker thread wakes
// up first to dequeue a particular batch. Further, some worker threads may
// take longer than others to read the data for their index. As such, it could
// be that worker thread 2 finishes before all other threads and returns its
// batch to the main thread. In that case, the dataloader iterator would
// return the datum at index 2 first, and afterwards the datum from whatever
// thread finishes next. As such, the user may see data from indices
// 2, 0, 3, 1. On another run of the same dataloader on the same data, threads
// may be scheduled differently and return in order 0, 2, 3, 1. To force this
// ordering to deterministically be 0, 1, 2, 3, the `enforce_ordering` flag
// can be set to true. In that case, the dataloader will use a *sequencer*
// internally which keeps track of which datum is expected next, and buffers
// any other results until that next expected value arrives. For example,
// workers 1, 2, 3 may finish before worker 0. If `enforce_ordering` is true,
// the sequencer will internally buffer the results from 1, 2, 3 until worker
// 0 finishes. Only then does the dataloader return the datum from worker 0 to
// the user (and then datum 1 the next time, then 2 and so on).
//
// The way the test works is that we start `kNumberOfWorkers` workers in the
// dataloader, which each get an index from a `SequentialSampler` in the range
// `0...kNumberOfWorkers-1`. Each worker thread has a copy of the dataset, and
// thus `get_batch()` is called on the thread-local copy in each worker. We
// want to simulate out-of-order completion of these threads. For this, we set
// a barrier in the `get_batch()` method to make sure every worker has some
// index to fetch assigned. Further, each worker thread has a unique ID in
// `0...kNumberOfWorkers-1`. There is a hard-coded ordering,
// `kOrderInWhichWorkersReturnTheirBatch`, in which we want the worker threads
// to return. For this, an iterator into this order is maintained. When the
// dereferenced iterator (the current order index) matches the thread ID of a
// worker, it knows it can now return its index as well as progress the
// iterator. Inside the dataloader, the sequencer should buffer these indices
// such that they are ultimately returned in order.

namespace ordering_test {
namespace {
const size_t kNumberOfWorkers = 10;
const std::vector<size_t> kOrderInWhichWorkersReturnTheirBatch =
    {3, 7, 0, 5, 4, 8, 2, 1, 9, 6};
} // namespace

struct Dataset : datasets::BatchDataset<Dataset, size_t> {
  Dataset() = default;

  // This copy constructor will be called when we copy the dataset into a
  // particular thread.
  Dataset(const Dataset& other) {
    static std::atomic<size_t> counter{0};
    thread_id_ = counter.fetch_add(1);
  }

  Dataset(Dataset&& other) noexcept = default;
  Dataset& operator=(const Dataset& other) = delete;
  Dataset& operator=(Dataset&& other) noexcept = delete;

  size_t get_batch(torch::ArrayRef<size_t> indices) override {
    static Barrier barrier(kNumberOfWorkers);
    static auto order_iterator = kOrderInWhichWorkersReturnTheirBatch.begin();
    static std::condition_variable cv;
    static std::mutex mutex;

    // Wait for all threads to get an index batch and arrive here.
    barrier.wait();

    std::unique_lock<std::mutex> lock(mutex);
    cv.wait(lock, [this] { return *order_iterator == this->thread_id_; });
    ++order_iterator;
    lock.unlock();
    cv.notify_all();

    return indices.front();
  }

  torch::optional<size_t> size() const override {
    return kNumberOfWorkers;
  }

  size_t thread_id_ = 0;
};

} // namespace ordering_test

TEST(DataLoaderTest, EnforcesOrderingAmongThreadsWhenConfigured) {
  auto data_loader = torch::data::make_data_loader(
      ordering_test::Dataset{},
      torch::data::samplers::SequentialSampler(ordering_test::kNumberOfWorkers),
      DataLoaderOptions()
          .batch_size(1)
          .workers(ordering_test::kNumberOfWorkers)
          .enforce_ordering(true));
  std::vector<size_t> output;
  for (size_t value : *data_loader) {
    output.push_back(value);
  }
  std::vector<size_t> expected(ordering_test::kNumberOfWorkers);
  std::iota(expected.begin(), expected.end(), size_t(0));
  ASSERT_EQ(expected, output);
}

TEST(DataLoaderTest, Reset) {
  DummyDataset dataset;
  auto data_loader =
      torch::data::make_data_loader(dataset, dataset.size().value() / 2);
  auto end = data_loader->end();

  auto iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);

  iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);

  iterator = data_loader->begin();
  ASSERT_NE(iterator, end);
  ASSERT_NE(++iterator, end);
  ASSERT_EQ(++iterator, end);
}

TEST(DataLoaderTest, TestExceptionsArePropagatedFromWorkers) {
  struct D : datasets::Dataset<DummyDataset, int> {
    int get(size_t index) override {
      throw std::invalid_argument("badness");
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
  };

  auto data_loader = torch::data::make_data_loader(
      D{}, samplers::RandomSampler(100), DataLoaderOptions().workers(2));
  auto iterator = data_loader->begin();

  try {
    (void)*iterator;
  } catch (torch::data::WorkerException& e) {
    ASSERT_EQ(
        e.what(),
        std::string("Caught exception in DataLoader worker thread. "
                    "Original message: badness"));
    ASSERT_THROW(
        std::rethrow_exception(e.original_exception), std::invalid_argument);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(D{});

  for (size_t i = 0; i < 10; ++i) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;
  const int kNumberOfWorkers = 4;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      std::lock_guard<std::mutex> lock(mutex);
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    int counter = 0;
    std::mutex mutex;
  };

  auto data_loader = torch::data::make_data_loader(
      torch::data::datasets::make_shared_dataset<D>(),
      DataLoaderOptions().workers(kNumberOfWorkers));

  for (size_t i = 0; i < 10; ++i) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const int i : *data_loader) {
    ASSERT_LT(i, kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithMap) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D, int, size_t> {
    torch::optional<int> get_batch(size_t) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        return counter++;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    int counter = 0;
  };

  auto data_loader = torch::data::make_data_loader(
      D().map(transforms::BatchLambda<int, std::string>(
               [](int x) { return std::to_string(x); }))
          .map(transforms::BatchLambda<std::string, torch::Tensor>(
              [](const std::string& x) {
                return torch::tensor(static_cast<int64_t>(std::stoi(x)));
              })),
      DataLoaderOptions{});

  for (size_t i = 0; i < 10; ++i) {
    const auto number_of_iterations =
        std::distance(data_loader->begin(), data_loader->end());
    ASSERT_EQ(
        number_of_iterations, kNumberOfExamplesAfterWhichTheDatasetExhausts)
        << "epoch " << i;
  }

  for (const torch::Tensor& t : *data_loader) {
    ASSERT_LT(t.item<int64_t>(), kNumberOfExamplesAfterWhichTheDatasetExhausts);
  }
}

TEST(DataLoaderTest, StatefulDatasetWithCollate) {
  const int kNumberOfExamplesAfterWhichTheDatasetExhausts = 10;

  struct D : datasets::StatefulDataset<D> {
    torch::optional<std::vector<Example<>>> get_batch(
        size_t batch_size) override {
      if (counter < kNumberOfExamplesAfterWhichTheDatasetExhausts) {
        counter += batch_size;
        std::vector<Example<>> batch(
            /*count=*/batch_size,
            Example<>{torch::ones(batch_size + 1),
                      torch::zeros(batch_size - 1)});
        return batch;
      }
      return torch::nullopt;
    }
    torch::optional<size_t> size() const override {
      return 100;
    }
    void reset() override {
      counter = 0;
    }
    int counter = 0;
  };

  auto d = D().map(transforms::Stack<Example<>>());

  const size_t kBatchSize = 5;

  // Notice that the `get_batch()` of the dataset returns a vector<Example>,
  // but the `Stack` collation stacks the tensors into one.
  torch::optional<Example<>> batch = d.get_batch(kBatchSize);
  ASSERT_TRUE(batch.has_value());
  ASSERT_EQ(batch->data.size(0), kBatchSize);
  ASSERT_EQ(batch->data.size(1), kBatchSize + 1);
  ASSERT_EQ(batch->target.size(0), kBatchSize);
  ASSERT_EQ(batch->target.size(1), kBatchSize - 1);

  ASSERT_TRUE(batch->data[0].allclose(torch::ones(kBatchSize + 1)));
  ASSERT_TRUE(batch->target[0].allclose(torch::zeros(kBatchSize - 1)));
}

// This test exercises the core logic for iterating through a chunk dataset.
// It contains test cases with different parameter combinations (for example,
// different prefetch counts, batch sizes, and data loader worker counts). It
// verifies the size and content of the returned batches when the order is
// deterministic.
TEST(DataLoaderTest, ChunkDataSetGetBatch) {
  // different prefetch counts for testing.
  const size_t prefetch_counts[] = {1, 2, 3, 4};

  // different batch sizes for testing.
  const size_t batch_sizes[] = {5, 7};

  // test with/without worker threads
  const size_t dataloader_worker_counts[] = {0, 2};

  const size_t total_example_count = 35;
  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  // test functionality across epoch boundary
  const int epoch_count = 2;

  for (auto prefetch_count : prefetch_counts) {
    for (auto batch_size : batch_sizes) {
      for (auto dataloader_worker_count : dataloader_worker_counts) {
        datasets::SharedBatchDataset<datasets::ChunkDataset<
            DummyChunkDataReader,
            samplers::SequentialSampler,
            samplers::SequentialSampler>>
            dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
                DummyChunkDataReader,
                samplers::SequentialSampler,
                samplers::SequentialSampler>>(
                data_reader,
                sampler,
                sampler,
                datasets::ChunkDatasetOptions(prefetch_count, batch_size));

        auto data_loader = torch::data::make_data_loader(
            dataset,
            DataLoaderOptions(batch_size).workers(dataloader_worker_count));

        for (int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
          std::vector<bool> result(total_example_count, false);
          int iteration_count = 0;
          for (auto iterator = data_loader->begin();
               iterator != data_loader->end();
               ++iterator, ++iteration_count) {
            std::vector<int>& batch = *iterator;
            ASSERT_EQ(batch.size(), batch_size);

            // When prefetch_count is equal to 1 and there is no worker thread,
            // the batch order is deterministic. So we can verify the elements
            // in each batch.
            if (prefetch_count == 1 && dataloader_worker_count == 0) {
              for (size_t j = 0; j < batch_size; ++j) {
                ASSERT_EQ(batch[j], iteration_count * batch_size + j);
              }
            }
            for (size_t j = 0; j < batch_size; ++j) {
              result[batch[j]] = true;
            }
          }

          for (auto data : result) {
            ASSERT_EQ(data, true);
          }
        }
      }
    }
  }
}

TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
  const size_t prefetch_count = 1;
  const size_t batch_size = 5;
  const size_t requested_batch_size = 6;

  DummyChunkDataReader data_reader;
  samplers::SequentialSampler sampler(0);

  datasets::SharedBatchDataset<datasets::ChunkDataset<
      DummyChunkDataReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
          sampler,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));

  auto data_loader = torch::data::make_data_loader(
      dataset,
      DataLoaderOptions(requested_batch_size).workers(0));

  std::string exception_msg =
      "The requested batch size does not match with the initialized batch "
      "size.\n The requested batch size is 6, while the dataset is created"
      " with batch size equal to 5";

  ASSERT_THROWS_WITH(*(data_loader->begin()), exception_msg);
}
1709 
1710 TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
1711  struct DummyEmptyChunkDataReader
1712  : datasets::ChunkDataReader<std::vector<int>> {
1713  public:
1714  using BatchType = std::vector<int>;
1715 
1716  BatchType read_chunk(size_t chunk_index) override {
1717  return {};
1718  }
1719 
1720  size_t chunk_count() override {
1721  return 1;
1722  };
1723 
1724  void reset() override{};
1725  };
1726 
1727  const size_t prefetch_count = 1;
1728  const size_t batch_size = 5;
1729  DummyEmptyChunkDataReader data_reader;
1730  samplers::SequentialSampler sampler(0);
1731 
1733  DummyEmptyChunkDataReader,
1735  samplers::SequentialSampler>>
1736  dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1737  DummyEmptyChunkDataReader,
1738  samplers::SequentialSampler,
1739  samplers::SequentialSampler>>(
1740  data_reader,
1741  sampler,
1742  sampler,
1743  datasets::ChunkDatasetOptions(prefetch_count, batch_size));
1744 
1745  auto data_loader = torch::data::make_data_loader(
1746  dataset, DataLoaderOptions(batch_size).workers(0));
1747 
1748  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
1749  ++iterator) {
1750  ASSERT_EQ(iterator->size(), 0);
1751  }
1752 }
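// An empty chunk is not an error: the loader yields an empty batch and
// iteration still terminates once every chunk is consumed. A reader that
// mixes empty and non-empty chunks behaves the same way; a hedged sketch
// (SparseChunkDataReader is illustrative, not part of the test suite):
namespace {
struct SparseChunkDataReader : datasets::ChunkDataReader<std::vector<int>> {
  using BatchType = std::vector<int>;

  BatchType read_chunk(size_t chunk_index) override {
    // Odd-numbered chunks are empty; even-numbered chunks hold five zeros.
    return (chunk_index % 2 == 1) ? BatchType{} : BatchType(5, 0);
  }

  size_t chunk_count() override {
    return 4;
  }

  void reset() override {}
};
} // namespace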
1753 
1754 TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
1755  struct D : public datasets::ChunkDataReader<std::vector<int>> {
1756  public:
1757  using BatchType = std::vector<int>;
1758 
1759  BatchType read_chunk(size_t chunk_index) override {
1760  BatchType batch_data(10, 0);
1761  return batch_data;
1762  }
1763 
1764  size_t chunk_count() override {
1765  return 2;
1766  };
1767 
1768  void reset() override{};
1769  };
1770 
1771  const size_t batch_sizes[] = {17, 30};
1772  D data_reader;
1773  samplers::SequentialSampler sampler(0);
1774 
1775  for (auto batch_size : batch_sizes) {
1776  datasets::SharedBatchDataset<datasets::ChunkDataset<
1777  D,
1778  samplers::SequentialSampler,
1779  samplers::SequentialSampler>>
1780  dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1781  D,
1782  samplers::SequentialSampler,
1783  samplers::SequentialSampler>>(
1784  data_reader,
1785  sampler,
1786  sampler,
1787  datasets::ChunkDatasetOptions(1, batch_size));
1788 
1789  auto data_loader = torch::data::make_data_loader(
1790  dataset, DataLoaderOptions(batch_size).workers(0));
1791 
1792  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
1793  ++iterator) {
1794  std::vector<int> batch = *iterator;
1795  // The final batch of an epoch may be smaller than the requested size.
1796  if (batch_size == 17) {
1797  ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
1798  }
1799  if (batch_size == 30) {
1800  ASSERT_TRUE(batch.size() == 20);
1801  }
1802  }
1803  }
1804 }
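// Worked numbers for the test above: the reader supplies 2 chunks x 10
// examples = 20 examples per epoch. A requested batch size of 17 therefore
// yields batches of 17 and 3 (20 = 17 + 3), while 30 yields one batch of 20,
// since a batch is capped by the examples actually available:
//
//   full batches  = 20 / batch_size   (20 / 17 == 1, 20 / 30 == 0)
//   partial batch = 20 % batch_size   (20 % 17 == 3, 20 % 30 == 20)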
1805 
1806 TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
1807  const size_t prefetch_count = 2;
1808  const size_t batch_size = 5;
1809 
1810  DummyChunkDataReader data_reader;
1811  samplers::SequentialSampler sampler(0);
1812  datasets::SharedBatchDataset<datasets::ChunkDataset<
1813  DummyChunkDataReader,
1814  samplers::SequentialSampler,
1815  samplers::SequentialSampler>>
1816  dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1817  DummyChunkDataReader,
1818  samplers::SequentialSampler,
1819  samplers::SequentialSampler>>(
1820  data_reader,
1821  sampler,
1822  sampler,
1823  datasets::ChunkDatasetOptions(prefetch_count, batch_size));
1824 
1825  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
1826 
1827  auto data_loader = torch::data::make_data_loader(
1828  dataset.map(transforms::BatchLambda<std::vector<int>, int>(
1829  [](std::vector<int> batch) {
1830  return std::accumulate(batch.begin(), batch.end(), 0);
1831  })),
1832  DataLoaderOptions(batch_size).workers(0));
1833 
1834  // Before we start iterating, the chunk sampler's index should be 0.
1835  ASSERT_EQ(chunk_sampler.index(), 0);
1836 
1837  size_t sum = 0;
1838  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
1839  ++iterator) {
1840  sum += *iterator;
1841  }
1842  ASSERT_EQ(sum, 595); // sum([0, 35))
1843  // There are 3 chunks; on exhaustion the index has already been incremented.
1844  ASSERT_EQ(chunk_sampler.index(), 3);
1845 }
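// The BatchLambda above folds each batch into its sum, so one pass over the
// loader computes a whole-epoch checksum: sum of [0, 35) = 34 * 35 / 2 = 595.
// The same map pattern supports any batch-to-scalar reduction; a hedged sketch
// counting examples per batch instead (reusing `dataset` and `batch_size`):
//
//   auto counting_loader = torch::data::make_data_loader(
//       dataset.map(transforms::BatchLambda<std::vector<int>, int>(
//           [](std::vector<int> batch) {
//             return static_cast<int>(batch.size());
//           })),
//       DataLoaderOptions(batch_size).workers(0));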
1846 
1847 TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
1848  const size_t prefetch_count = 2;
1849  const size_t batch_size = 5;
1850  // The small cache makes the preloaders block, waiting on `get_batch()` calls.
1851  const size_t cache_size = 10;
1852 
1853  DummyChunkDataReader data_reader;
1854  samplers::SequentialSampler sampler(0);
1855  datasets::SharedBatchDataset<datasets::ChunkDataset<
1856  DummyChunkDataReader,
1857  samplers::SequentialSampler,
1858  samplers::SequentialSampler>>
1859  dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1860  DummyChunkDataReader,
1861  samplers::SequentialSampler,
1862  samplers::SequentialSampler>>(
1863  data_reader,
1864  sampler,
1865  sampler,
1866  datasets::ChunkDatasetOptions(
1867  prefetch_count, batch_size, cache_size));
1868 
1869  samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
1870 
1871  auto data_loader = torch::data::make_data_loader(
1872  dataset.map(transforms::BatchLambda<std::vector<int>, int>(
1873  [](std::vector<int> batch) {
1874  return std::accumulate(batch.begin(), batch.end(), 0);
1875  })),
1876  DataLoaderOptions(batch_size).workers(0));
1877  // Simply create the iterator without iterating: the chunk preloaders are
1878  // waiting to fill the batch buffer, but nothing drains it. We still need
1879  // to exit cleanly.
1880  auto iterator = data_loader->begin();
1881 }
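// What the test above guarantees: with nothing draining the batch buffer, the
// preloader threads block on the bounded cache, yet destroying the loader and
// the dataset must still unblock and join them. A sketch that only adds an
// explicit scope to make the destruction point visible (assumes the same
// `dataset` and `batch_size` as above):
//
//   {
//     auto loader = torch::data::make_data_loader(
//         dataset, DataLoaderOptions(batch_size).workers(0));
//     auto it = loader->begin(); // kicks off preloading; nothing is consumed
//   } // destructors must stop the preloaders and return promptly here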