#include <torch/data/datasets/stateful.h>

/// Interface for chunk reader, which performs data chunking and reading of
/// entire chunks.
template <typename Chunk = std::vector<Example<>>>
class ChunkDataReader {
 public:
  using ChunkType = Chunk;

  /// Read an entire chunk.
  virtual ChunkType read_chunk(size_t chunk_index) = 0;
  /// Returns the number of chunks available in this reader.
  virtual size_t chunk_count() = 0;
  /// This will clear any internal state associated with this reader.
  virtual void reset() = 0;
};
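To illustrate the interface, here is a minimal sketch of a custom reader. DummyChunkReader, its four fixed chunks, and the int example type are invented for this example; the BatchType alias is added because ChunkDataset below refers to typename ChunkReader::BatchType.

#include <torch/data/datasets/chunk.h>
#include <numeric>
#include <vector>

// Hypothetical reader: four chunks, each holding 100 int examples.
class DummyChunkReader
    : public torch::data::datasets::ChunkDataReader<std::vector<int>> {
 public:
  using BatchType = ChunkType; // ChunkDataset expects a BatchType alias.

  ChunkType read_chunk(size_t chunk_index) override {
    ChunkType chunk(100);
    // Fill the chunk with values derived from its index, purely for illustration.
    std::iota(chunk.begin(), chunk.end(), static_cast<int>(chunk_index * 100));
    return chunk;
  }

  size_t chunk_count() override {
    return 4;
  }

  void reset() override {} // no per-epoch state to clear in this sketch
};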
/// BatchDataBuffer (in the detail namespace) manages a queue of
/// UnwrappedBatchData. After a chunk is loaded it is split into batches and
/// pushed into the queue; get_batch() pops cached batches for the data loader.
template <
    typename UnwrappedBatch = std::vector<Example<>>,
    typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
 public:
  using UnwrappedBatchType = UnwrappedBatch;
  using BatchType = torch::optional<UnwrappedBatchType>;
  using BatchRequestType = typename ExampleSampler::BatchRequestType;

  BatchDataBuffer(
      size_t batch_size,
      ExampleSampler& example_sampler,
      size_t queue_capacity)
      : batch_size_(batch_size),
        example_sampler_(example_sampler),
        queue_capacity_(queue_capacity) {}

  /// Return batch data from the queue.
  BatchType get_batch() {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_read_.wait(lock, [this] {
      // Wait until enough examples are queued or the buffer has been stopped.
      return this->total_example_count_in_queue_ >= batch_size_ || this->stop_;
    });
    if (batch_queue_.empty()) {
      // All batches have been retrieved; signal the end of the epoch.
      return torch::nullopt;
    }
    UnwrappedBatchData batch = std::move(batch_queue_.front());
    batch_queue_.pop();
    if (batch.exception) {
      throw WorkerException(batch.exception);
    }
    total_example_count_in_queue_ -= batch.batch_data.size();
    lock.unlock();
    cv_write_.notify_all();
    return batch.batch_data;
  }

  /// Push preloaded chunks to the batch queue.
  void add_chunk_data(UnwrappedBatchType data) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      // Stop loading once enough data has been preloaded.
      return this->total_example_count_in_queue_ < this->queue_capacity_ || this->stop_;
    });
    if (stop_) {
      return;
    }
    auto data_size = data.size();
    auto remaining_size = data_size;
    example_sampler_.reset(data_size);

    auto fill_batch = [&](size_t example_count, UnwrappedBatchType& batch) {
      auto batch_example_indices = this->example_sampler_.next(example_count);
      AT_ASSERT(
          batch_example_indices &&
          batch_example_indices.value().size() == example_count);
      BatchRequestType& indices = batch_example_indices.value();
      for (size_t i : indices) {
        AT_CHECK(i < data_size, "Index out of range");
        batch.emplace_back(std::move(data[i]));
      }
      remaining_size -= example_count;
    };

    if (!batch_queue_.empty()) {
      // Top up the last cached batch first if it is not full yet.
      auto& batch = batch_queue_.back();
      size_t current_count = batch.batch_data.size();
      if (current_count < batch_size_) {
        auto example_count = std::min(remaining_size, batch_size_ - current_count);
        fill_batch(example_count, batch.batch_data);
      }
    }
    // Split the remaining examples into batches of at most batch_size_.
    while (remaining_size > 0) {
      UnwrappedBatchType current_batch;
      current_batch.reserve(batch_size_);
      auto example_count = std::min(remaining_size, batch_size_);
      fill_batch(example_count, current_batch);
      batch_queue_.emplace(std::move(current_batch));
    }
    total_example_count_in_queue_ += data_size;
    lock.unlock();
    cv_read_.notify_all();
  }

  /// Push exceptions thrown during preloading into the batch queue.
  void add_chunk_data(std::exception_ptr e_ptr) {
    std::unique_lock<std::mutex> lock(queue_mutex_);
    cv_write_.wait(lock, [this] {
      return this->total_example_count_in_queue_ < this->queue_capacity_ || this->stop_;
    });
    if (stop_) {
      return;
    }
    batch_queue_.emplace(e_ptr);
    lock.unlock();
    cv_read_.notify_all();
  }

  void stop() {
    {
      // Hold the lock while setting stop_ to avoid racing with the waits above.
      std::lock_guard<std::mutex> lock(queue_mutex_);
      stop_ = true;
    }
    // Wake producers and consumers so they observe stop_ and exit.
    cv_write_.notify_all();
    cv_read_.notify_all();
  }

  /// struct that contains a raw unwrapped batch unit.
  struct UnwrappedBatchData {
    explicit UnwrappedBatchData(UnwrappedBatchType data) : batch_data(std::move(data)) {}
    explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {}
    UnwrappedBatchType batch_data; // batch data to return
    std::exception_ptr exception;  // exception captured while creating the batch
  };

  size_t batch_size_ = 0;
  size_t total_example_count_in_queue_ = 0;    // count of examples stored in the queue
  std::queue<UnwrappedBatchData> batch_queue_; // local cache of example batches
  std::mutex queue_mutex_;
  std::condition_variable cv_read_;
  std::condition_variable cv_write_;
  ExampleSampler& example_sampler_;
  size_t queue_capacity_; // maximum number of examples the queue may hold
  bool stop_ = false;     // when true, waiting threads wake up and exit
};
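BatchDataBuffer is an internal helper in the detail namespace and is normally driven by ChunkDataset's preloader threads and its get_batch() call; it is not a public API. The single-threaded sketch below only illustrates that producer/consumer flow, using std::vector<int> examples and arbitrary sizes:

#include <torch/data/datasets/chunk.h>
#include <torch/data/samplers.h>
#include <numeric>
#include <vector>

void illustrate_buffer() {
  using Buffer = torch::data::datasets::detail::BatchDataBuffer<
      std::vector<int>, torch::data::samplers::RandomSampler>;

  torch::data::samplers::RandomSampler example_sampler(0);
  Buffer buffer(/*batch_size=*/32, example_sampler, /*queue_capacity=*/256);

  // Producer side (normally a preloader thread): push one chunk, then stop.
  std::vector<int> chunk(100);
  std::iota(chunk.begin(), chunk.end(), 0);
  buffer.add_chunk_data(std::move(chunk));
  buffer.stop();

  // Consumer side (normally ChunkDataset::get_batch): drain the cached batches.
  while (auto batch = buffer.get_batch()) {
    // batch->size() is at most 32; the last batch of the chunk may be smaller.
  }
}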
/// Options to configure a ChunkDataset.
struct ChunkDatasetOptions {
  ChunkDatasetOptions() = delete;
  ChunkDatasetOptions(
      size_t preloader_count,
      size_t batch_size,
      size_t cache_size = 2048)
      : preloader_count_(preloader_count),
        batch_size_(batch_size),
        cache_size_(cache_size) {
    AT_CHECK(
        preloader_count_ > 0,
        "Preloader count is 0. At least one preloader needs to be specified.");
    AT_CHECK(
        batch_size_ > 0,
        "Batch size is 0. A positive batch size needs to be specified.");
    AT_CHECK(
        cache_size_ > 0,
        "Cache size is 0. A positive cache size needs to be specified.");
    AT_CHECK(
        cache_size_ >= batch_size_,
        "Cache size is less than batch size. Cache needs to be large enough to "
        "hold at least one batch.");
  }

  /// The number of worker threads used to preload chunk data.
  TORCH_ARG(size_t, preloader_count);

  /// The size of each batch.
  TORCH_ARG(size_t, batch_size);

  /// The capacity of the queue used to cache batches.
  TORCH_ARG(size_t, cache_size) = 2048;
};
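For example (the values are arbitrary), the following configures two preloader threads, batches of 64 examples, and a 512-example cache; a cache_size smaller than batch_size would trip the last check above:

torch::data::datasets::ChunkDatasetOptions options(
    /*preloader_count=*/2, /*batch_size=*/64, /*cache_size=*/512);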
/// A stateful dataset that supports hierarchical sampling and prefetching of
/// entire chunks. The ChunkSampler selects which chunk to load next, while the
/// ExampleSampler determines the order of the examples returned from each chunk.
template <
    typename ChunkReader,
    typename ChunkSampler = samplers::RandomSampler,
    typename ExampleSampler = samplers::RandomSampler>
class ChunkDataset final
    : public StatefulDataset<
          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
          typename ChunkReader::BatchType,
          size_t> {
 public:
  using BatchType = torch::optional<typename ChunkReader::BatchType>;
  using UnwrappedBatchType = typename ChunkReader::BatchType;
  using BatchRequestType = size_t;
  using ChunkSamplerType = ChunkSampler;
  using ExampleSamplerType = ExampleSampler;

  ChunkDataset(
      ChunkReader chunk_reader,
      ChunkSampler chunk_sampler,
      ExampleSampler example_sampler,
      ChunkDatasetOptions options)
      : chunk_reader_(std::move(chunk_reader)),
        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
        options_(std::move(options)),
        quit_worker_(false),
        running_preloaders_(0) {}

  virtual ~ChunkDataset() {
    // Stop the batch buffer first, then tear down the worker threads.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
  }

  /// Default get_batch method of BatchDataset. Returns batches created from
  /// the preloaded chunks.
  BatchType get_batch(size_t batch_size) override {
    AT_CHECK(
        batch_buffer_ != nullptr,
        "Dataset needs to call reset() before calling get_batch().");
    AT_CHECK(
        batch_size == options_.batch_size_,
        "The requested batch size does not match with the initialized batch size.\n"
        " The requested batch size is ", batch_size,
        ", while the dataset is created with batch size equal to ",
        options_.batch_size_);
    return batch_buffer_->get_batch();
  }

  /// This will clear any internal state and start the internal prefetching
  /// mechanism for the chunk dataset.
  void reset() override {
    // Stop the old buffer and tear down workers from a previous epoch, if any.
    if (batch_buffer_) {
      batch_buffer_->stop();
    }
    free_workers();
    preload_threads_.clear();

    chunk_reader_.reset();
    chunk_sampler_.reset(chunk_reader_.chunk_count());

    // Throw out any cached batches and create a fresh buffer for the new epoch.
    batch_buffer_ = torch::make_unique<
        detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>(
        options_.batch_size_,
        example_sampler_,
        options_.cache_size_);

    // Spawn the preloader threads for this epoch.
    quit_worker_ = false;
    AT_ASSERT(running_preloaders_ == 0);
    running_preloaders_ = options_.preloader_count_;
    for (size_t i = 0; i < options_.preloader_count_; ++i) {
      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
    }
  }

  /// size() is not used for the chunk dataset.
  optional<size_t> size() const override {
    return torch::nullopt;
  }

  /// Provides a reference to the chunk sampler, used e.g. in distributed data
  /// loading to set the epoch number for the sampler.
  ChunkSamplerType& chunk_sampler() {
    return chunk_sampler_;
  }

 private:
  /// Runs on a worker thread to preload chunk data.
  void preloader(size_t id) {
    while (!quit_worker_.load()) {
      try {
        size_t chunk_id = 0;
        {
          std::lock_guard<std::mutex> lock(chunk_index_guard_);
          if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
            chunk_id = chunk_sampler_result.value()[0];
          } else {
            break; // all chunks have been sampled for this epoch
          }
        }
        UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_id);
        batch_buffer_->add_chunk_data(std::move(data));
      } catch (...) {
        batch_buffer_->add_chunk_data(std::current_exception());
      }
    }
    AT_ASSERT(running_preloaders_.load() > 0);
    --running_preloaders_;
    if (running_preloaders_.load() == 0) {
      // All preloaders are done; let the buffer know that no more data is coming.
      batch_buffer_->stop();
    }
  }

  /// Blocks until the worker threads finish execution and exit.
  void free_workers() {
    if (!quit_worker_.load()) {
      quit_worker_ = true;
      for (auto& worker_thread : preload_threads_) {
        worker_thread.join();
      }
    }
  }

  ChunkReader chunk_reader_;
  ChunkSamplerType chunk_sampler_;      // shuffles the chunks
  ExampleSamplerType example_sampler_;  // shuffles examples within a chunk
  std::shared_ptr<detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
      batch_buffer_;
  std::vector<std::thread> preload_threads_;
  const ChunkDatasetOptions options_;
  std::atomic<bool> quit_worker_;           // signals workers to tear down
  std::atomic<size_t> running_preloaders_;  // 0 means chunk loading is complete
  std::mutex chunk_index_guard_;            // synchronizes chunk_sampler_.next() calls
};
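Putting the pieces together, here is a rough end-to-end sketch that reuses the hypothetical DummyChunkReader from the earlier example; the sampler choices and sizes are arbitrary. make_shared_dataset wraps the chunk dataset so that the stateful data loader can share it:

#include <torch/torch.h>

void run_one_epoch() {
  using namespace torch::data;

  auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
      DummyChunkReader,
      samplers::RandomSampler,
      samplers::RandomSampler>>(
      DummyChunkReader{},
      samplers::RandomSampler(0),
      samplers::RandomSampler(0),
      datasets::ChunkDatasetOptions(
          /*preloader_count=*/2, /*batch_size=*/64, /*cache_size=*/512));

  // The stateful data loader calls reset() at the start of the epoch and keeps
  // requesting batches until the chunk sampler is exhausted.
  auto data_loader = make_data_loader(
      dataset, DataLoaderOptions(/*batch_size=*/64).workers(0));

  for (auto& batch : *data_loader) {
    // In this sketch each batch is a std::vector<int> with at most 64 examples.
  }
}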
 Interface for chunk reader, which performs data chunking and reading of entire chunks. 
 
optional< size_t > size() const override
size() is not used for the chunk dataset. 
 
A stateful dataset that supports hierarchical sampling and prefetching of entire chunks. 
 
BatchDataBuffer manages a queue of UnwrappedBatchData. 
 
virtual ChunkType read_chunk(size_t chunk_index)=0
Read an entire chunk. 
 
void reset() override
This will clear any internal state and start the internal prefetching mechanism for the chunk dataset. 
 
Options to configure a ChunkDataset. 
 
virtual void reset()=0
This will clear any internal state associated with this reader. 
 
std::exception_ptr exception
Exception pointer that captures any exception thrown while creating the batch. 
 
A Sampler that returns random indices. 
 
virtual size_t chunk_count()=0
Returns the number of chunks available in this reader. 
 
void add_chunk_data(UnwrappedBatchType data)
Push preloaded chunks to batch queue. 
 
UnwrappedBatchType batch_data
batch data to return 
 
BatchType get_batch()
Return batch data from the queue. 
 
A stateful dataset is a dataset that maintains some internal state, which will be reset() at the beginning of each epoch (a minimal usage sketch follows at the end of this list). 
 
An exception thrown when a DataLoader's worker thread throws an exception, which is caught and retransmitted to the main thread. 
 
BatchType get_batch(size_t batch_size) override
Default get_batch method of BatchDataset. 
 
struct that contains a raw unwrapped batch unit. 
 
void add_chunk_data(std::exception_ptr e_ptr)
Push exceptions thrown during preloading into batch queue. 
 
std::queue< UnwrappedBatchData > batch_queue_
Local cache to store example batches from loaded chunks. 
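To make the reset() / get_batch() protocol above concrete, this last sketch drives a ChunkDataset by hand instead of through a data loader. It again assumes the hypothetical DummyChunkReader, and the sizes are arbitrary:

void run_epoch_manually() {
  using namespace torch::data;

  datasets::ChunkDataset<DummyChunkReader,
                         samplers::RandomSampler,
                         samplers::RandomSampler>
      dataset(
          DummyChunkReader{},
          samplers::RandomSampler(0),
          samplers::RandomSampler(0),
          datasets::ChunkDatasetOptions(
              /*preloader_count=*/2, /*batch_size=*/64, /*cache_size=*/512));

  dataset.reset(); // clears internal state and starts the prefetching threads
  while (auto batch = dataset.get_batch(/*batch_size=*/64)) {
    // batch.value() holds up to 64 examples; an empty optional ends the epoch.
  }
}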