3 #include <torch/data/datasets/stateful.h> 15 template <
typename Chunk = std::vector<Example<>>>
18 using ChunkType = Chunk;
21 virtual ChunkType
read_chunk(
size_t chunk_index) = 0;
27 virtual void reset() = 0;
37 typename UnwrappedBatch = std::vector<Example<>>,
41 using UnwrappedBatchType = UnwrappedBatch;
43 using BatchRequestType =
typename ExampleSampler::BatchRequestType;
47 ExampleSampler& example_sampler,
48 size_t queue_capacity)
49 : batch_size_(batch_size),
50 example_sampler_(example_sampler),
51 queue_capacity_(queue_capacity) {}
56 std::unique_lock<std::mutex> lock(queue_mutex_);
57 cv_read_.wait(lock, [
this] {
61 this->total_example_count_in_queue_ >= batch_size_ ||
64 if (batch_queue_.empty()) {
76 total_example_count_in_queue_ -= batch.
batch_data.size();
78 cv_write_.notify_all();
86 std::unique_lock<std::mutex> lock(queue_mutex_);
87 cv_write_.wait(lock, [
this] {
89 return this->total_example_count_in_queue_ < this->queue_capacity_ ||
98 auto data_size = data.size();
99 auto remaining_size = data_size;
100 example_sampler_.reset(data_size);
102 auto fill_batch = [&](
size_t example_count, UnwrappedBatchType& batch) {
103 auto batch_example_indices = this->example_sampler_.next(example_count);
105 batch_example_indices &&
106 batch_example_indices.value().size() == example_count)
107 BatchRequestType& indices = batch_example_indices.value();
108 for (
size_t i : indices) {
109 AT_CHECK(i < data_size,
"Index out of range");
110 batch.emplace_back(std::move(data[i]));
112 remaining_size -= example_count;
115 if (!batch_queue_.empty()) {
118 auto& batch = batch_queue_.back();
119 size_t current_count = batch.batch_data.size();
120 if (current_count < batch_size_) {
122 std::min(remaining_size, batch_size_ - current_count);
123 fill_batch(example_count, batch.batch_data);
129 while (remaining_size > 0) {
130 UnwrappedBatchType current_batch;
133 current_batch.reserve(batch_size_);
135 auto example_count = std::min(remaining_size, batch_size_);
136 fill_batch(example_count, current_batch);
137 batch_queue_.emplace(std::move(current_batch));
139 total_example_count_in_queue_ += data_size;
141 cv_read_.notify_all();
147 std::unique_lock<std::mutex> lock(queue_mutex_);
148 cv_write_.wait(lock, [
this] {
151 this->total_example_count_in_queue_ < this->queue_capacity_ ||
161 batch_queue_.emplace(e_ptr);
163 cv_read_.notify_all();
179 std::lock_guard<std::mutex> lock(queue_mutex_);
184 cv_write_.notify_all();
186 cv_read_.notify_all();
191 size_t batch_size_ = 0;
194 size_t total_example_count_in_queue_ = 0;
216 std::mutex queue_mutex_;
218 std::condition_variable cv_read_;
219 std::condition_variable cv_write_;
221 ExampleSampler& example_sampler_;
224 size_t queue_capacity_;
240 size_t preloader_count,
242 size_t cache_size = 2048)
243 : preloader_count_(preloader_count),
244 batch_size_(batch_size),
245 cache_size_(cache_size) {
247 preloader_count_ > 0,
248 "Preloader count is 0. At least one preloader needs to be specified.");
251 "Batch size is 0. A positive batch size needs to be specified.");
254 "Cache size is 0. A positive cache size needs to be specified.");
256 cache_size_ >= batch_size_,
257 "Cache size is less than batch size. Cache needs to be large enough to " 258 "hold at least one batch.");
262 TORCH_ARG(
size_t, preloader_count);
265 TORCH_ARG(
size_t, batch_size);
268 TORCH_ARG(
size_t, cache_size) = 2048;
280 typename ChunkReader,
285 ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
286 typename ChunkReader::BatchType,
290 using UnwrappedBatchType =
typename ChunkReader::BatchType;
291 using BatchRequestType = size_t;
292 using ChunkSamplerType = ChunkSampler;
293 using ExampleSamplerType = ExampleSampler;
296 ChunkReader chunk_reader,
297 ChunkSampler chunk_sampler,
298 ExampleSampler example_sampler,
300 : chunk_reader_(std::move(chunk_reader)),
301 chunk_sampler_(std::move(chunk_sampler)),
302 example_sampler_(std::move(example_sampler)),
303 options_(std::move(options)),
305 running_preloaders_(0) {}
310 batch_buffer_->stop();
321 batch_buffer_ !=
nullptr,
322 "Dataset needs to call reset() before calling get_batch().");
325 batch_size == options_.batch_size_,
326 "The requested batch size does not match with the initialized batch size.\n" 327 " The requested batch size is ", batch_size,
328 ", while the dataset is created with batch size equal to ", options_.batch_size_);
330 return batch_buffer_->get_batch();
338 batch_buffer_->stop();
342 preload_threads_.clear();
344 chunk_reader_.reset();
346 chunk_sampler_.reset(chunk_reader_.chunk_count());
350 batch_buffer_ = torch::make_unique<
352 options_.batch_size_,
354 options_.cache_size_);
357 quit_worker_ =
false;
359 AT_ASSERT(running_preloaders_ == 0);
360 running_preloaders_ = options_.preloader_count_;
361 for (
size_t i = 0; i < options_.preloader_count_; ++i) {
362 preload_threads_.emplace_back([
this, i]() { this->preloader(i); });
368 return torch::nullopt;
373 ChunkSamplerType& chunk_sampler() {
374 return chunk_sampler_;
379 void preloader(
size_t id) {
380 while (!quit_worker_.load()) {
384 std::lock_guard<std::mutex> lock(chunk_index_guard_);
385 if (
auto chunk_sampler_result = chunk_sampler_.next(1)) {
386 chunk_id = chunk_sampler_result.value()[0];
391 UnwrappedBatchType data = chunk_reader_.read_chunk(chunk_id);
393 batch_buffer_->add_chunk_data(std::move(data));
396 batch_buffer_->add_chunk_data(std::current_exception());
399 AT_ASSERT(running_preloaders_.load() > 0);
400 --running_preloaders_;
401 if (running_preloaders_.load() == 0) {
403 batch_buffer_->stop();
408 void free_workers() {
409 if (!quit_worker_.load()) {
411 for (
auto& worker_thread : preload_threads_) {
412 worker_thread.join();
421 ChunkReader chunk_reader_;
424 ChunkSamplerType chunk_sampler_;
427 ExampleSamplerType example_sampler_;
430 std::shared_ptr<detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
434 std::vector<std::thread> preload_threads_;
440 std::atomic<bool> quit_worker_;
444 std::atomic<size_t> running_preloaders_;
447 std::mutex chunk_index_guard_;
Interface for chunk reader, which performs data chunking and reading of entire chunks.
optional< size_t > size() const override
size is not used for chunk dataset.
A stateful dataset that supports hierarchical sampling and prefetching of entire chunks.
BatchDataBuffer manages a queue of UnwrappedBatchData.
virtual ChunkType read_chunk(size_t chunk_index)=0
Read an entire chunk.
void reset() override
This will clear any internal state and start the internal prefetching mechanism for the chunk dataset.
Options to configure a ChunkDataset.
virtual void reset()=0
This will clear any internal state associated with this reader.
std::exception_ptr exception
exception pointer which captures any abnormal exceptions while creating the batch.
A Sampler that returns random indices.
virtual size_t chunk_count()=0
Returns the number of chunks available in this reader.
void add_chunk_data(UnwrappedBatchType data)
Push preloaded chunks to batch queue.
UnwrappedBatchType batch_data
batch data to return
BatchType get_batch()
Return batch data from the queue.
A stateful dataset is a dataset that maintains some internal state, which will be reset() at the begi...
An exception thrown when a DataLoader's worker thread throws an exception, which is caught...
BatchType get_batch(size_t batch_size) override
Default get_batch method of BatchDataset.
struct that contains a raw unwrapped batch unit.
void add_chunk_data(std::exception_ptr e_ptr)
Push exceptions thrown during preloading into batch queue.
std::queue< UnwrappedBatchData > batch_queue_
local cache to store example batches from loaded chunk