3 #include <torch/data/dataloader_options.h> 4 #include <torch/data/detail/data_shuttle.h> 5 #include <torch/data/detail/sequencers.h> 6 #include <torch/data/iterator.h> 7 #include <torch/data/samplers/random.h> 8 #include <torch/data/worker_exception.h> 9 #include <torch/types.h> 11 #include <torch/csrc/utils/memory.h> 12 #include <torch/csrc/utils/variadic.h> 14 #include <c10/util/Exception.h> 20 #include <type_traits> 26 template <
typename Dataset,
typename Batch,
typename BatchRequest>
29 using BatchType = Batch;
30 using BatchRequestType = BatchRequest;
37 std::unique_ptr<Dataset> main_thread_dataset =
nullptr)
60 "Attempted to get a new DataLoader iterator " 61 "while another iterator is not yet exhausted");
63 return Iterator<Batch>(torch::make_unique<detail::ValidIterator<Batch>>(
64 [
this] {
return this->
next(); }));
71 torch::make_unique<detail::SentinelIterator<Batch>>());
85 for (
size_t w = 0; w <
options_.workers; ++w) {
103 Sequenced(
size_t sqn) : sequence_number(sqn) {}
104 size_t sequence_number;
114 Job(BatchRequest&& i,
size_t sqn)
115 :
Sequenced(sqn), batch_request(std::move(i)) {}
125 Result(std::exception_ptr exception,
size_t sqn)
126 :
Sequenced(sqn), exception(std::move(exception)) {}
128 std::exception_ptr exception;
148 for (
size_t r = 0; r < requested_jobs; ++r) {
150 this->
push_job(std::move(*batch_request));
168 if (result->exception) {
170 }
else if (result->batch) {
172 return std::move(result->batch);
189 auto batch = dataset.get_batch(std::move(*job.batch_request));
190 shuttle_.push_result({std::move(batch), job.sequence_number});
192 shuttle_.push_result({std::current_exception(), job.sequence_number});
199 template <
typename T>
214 return torch::make_unique<detail::sequencers::OrderedSequencer<Result>>(
217 return torch::make_unique<detail::sequencers::NoSequencer<Result>>();
240 std::unique_ptr<detail::sequencers::Sequencer<Result>>
sequencer_;
void prefetch()
Schedules the maximum number of jobs (based on the max_jobs option).
void join()
Joins the DataLoader's worker threads and drains internal queues.
size_t sequence_number_
The sequence number for the next batch to be retrieved from the dataset.
Like DataLoaderOptions, but without any unconfigured state.
const FullDataLoaderOptions & options() const noexcept
Returns the options with which the DataLoader was configured.
bool joined_
True if the DataLoader has joined its worker threads.
const FullDataLoaderOptions options_
The options the DataLoader was configured with.
Simple mix-in to give something a sequence number.
The finished result of a job.
virtual optional< BatchRequestType > get_batch_request()=0
Subclass hook for getting the next batch request.
Options to configure a DataLoader.
Iterator< Batch > end()
Returns a special "sentinel" iterator that compares equal with a non-sentinel iterator once the DataLoader is exhausted.
DataLoaderBase(DataLoaderOptions options, std::unique_ptr< Dataset > main_thread_dataset=nullptr)
Constructs a new DataLoader from a dataset to sample from, options to configure the DataLoader with...
Encapsulates the full life cycle of DataLoader jobs.
virtual void reset()
Resets the internal state of the DataLoader, optionally pre-fetching new jobs.
void prefetch(size_t requested_jobs)
Schedules requested_jobs many new batches to be fetched.
optional< Result > pop_result()
Convenience method that gets the next result from the sequencer.
An exception thrown when a DataLoader's worker thread throws an exception, which is caught...
detail::DataShuttle< Job, Result > shuttle_
The DataShuttle which takes care of the life cycle of a job.
std::vector< std::thread > workers_
The worker threads, running the worker_thread() method.
std::unique_ptr< detail::sequencers::Sequencer< Result > > sequencer_
The Sequencer, which handles optional ordering of batches.
Iterator< Batch > begin()
Returns an iterator into the DataLoader.
void push_job(T value)
Convenience method that calls shuttle_.push_job() with the next sequence number.
optional< BatchType > next()
Returns the next batch of data, or an empty optional if the DataLoader is exhausted.
std::unique_ptr< detail::sequencers::Sequencer< Result > > new_sequencer()
Convenience method that creates a new sequencer based on the enforce_ordering option.
A Job is either a BatchRequest (new indices to fetch data at) or a QuitWorker object, to indicate the worker should shut down.
std::unique_ptr< Dataset > main_thread_dataset_
The dataset for the main thread; it only has a value if the number of worker threads was configured as zero.
void worker_thread(Dataset &dataset)
The function that worker threads run.