Caffe2 - C++ API
A deep learning, cross-platform ML framework
data_shuttle.h
1 #pragma once
2 
3 #include <torch/data/detail/queue.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 #include <c10/util/Optional.h>
8 
9 #include <chrono>
10 #include <utility>
11 
12 namespace torch {
13 namespace data {
14 namespace detail {
15 
25 template <typename Job, typename Result>
26 class DataShuttle {
27  public:
29  void push_job(Job job) {
30  new_jobs_.push(std::move(job));
31  ++in_flight_jobs_;
32  }
33 
35  void push_result(Result result) {
36  results_.push(std::move(result));
37  }
38 
41  Job pop_job() {
42  return new_jobs_.pop();
43  }
44 
48  optional<std::chrono::milliseconds> timeout = nullopt) {
49  if (in_flight_jobs_ > 0) {
50  auto result = results_.pop(timeout);
51  --in_flight_jobs_;
52  return result;
53  }
54  return nullopt;
55  }
56 
59  void drain() {
60  // Clear all inputs so that no further jobs are scheduled.
61  auto number_cleared = new_jobs_.clear();
62  in_flight_jobs_ -= number_cleared;
63  // Remove any outstanding results.
64  while (in_flight_jobs_ > 0) {
65  pop_result();
66  }
67  }
68 
71  size_t in_flight_jobs() const noexcept {
72  return in_flight_jobs_;
73  }
74 
75  private:
77  Queue<Job> new_jobs_;
80  size_t in_flight_jobs_ = 0;
82  Queue<Result> results_;
83 };
84 
85 } // namespace detail
86 } // namespace data
87 } // namespace torch
Job pop_job()
Returns the next job, blocking until there is one available.
Definition: data_shuttle.h:41
size_t in_flight_jobs() const noexcept
Returns the number of jobs that are still in progress.
Definition: data_shuttle.h:71
size_t clear()
Empties the queue and returns the number of elements that were present at the start of the function...
Definition: queue.h:68
optional< Result > pop_result(optional< std::chrono::milliseconds > timeout=nullopt)
Returns the result of a job, or nullopt if all jobs were exhausted.
Definition: data_shuttle.h:47
void push_job(Job job)
Pushes a new job. Called by the main thread.
Definition: data_shuttle.h:29
template <typename Job, typename Result> class DataShuttle
Encapsulates the full life cycle of DataLoader jobs.
Definition: data_shuttle.h:26
Definition: jit_type.h:17
void push(T value)
Pushes a new value to the back of the Queue and notifies one thread on the waiting side about this event.
Definition: queue.h:31
void push_result(Result result)
Pushes the result of a job. Called by worker threads.
Definition: data_shuttle.h:35
T pop(optional< std::chrono::milliseconds > timeout=nullopt)
Blocks until at least one element is ready to be popped from the front of the queue.
Definition: queue.h:43
void drain()
Discards any jobs that are not yet in flight, and waits for all in-flight jobs to finish, discarding their results.
Definition: data_shuttle.h:59