Caffe2 - C++ API
A deep learning, cross platform ML framework
prefetch_op.h
1 #ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_
2 #define CAFFE2_OPERATORS_PREFETCH_OP_H_
3 
#include <atomic>
#include <condition_variable>
#include <exception>
#include <memory>
#include <mutex>
#include <thread> // NOLINT

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
10 
11 namespace caffe2 {
12 
13 // PrefetchOperator is an operator that prefetches the next batch. It should
14 // almost always be used to read things from disk, so I am setting the input to
15 // zero blobs.
16 //
17 // For any operator that is derived from PrefetchOperator, it should
18 // explicitly call the Finalize() function in its destructor, so that the
19 // prefetching thread is properly destructed.
20 
21 // Note: We inherit from OperatorBase since we control the
22 // synchronization properties of this operator ourselves (we inform
23 // the waiting producer after we synchronize). This is a special-case
24 // - you should generally inherit from Operator<Context> directly.
25 template <class Context>
27  public:
28  PrefetchOperator(const OperatorDef& operator_def, Workspace* ws)
29  : OperatorBase(operator_def, ws),
30  context_(operator_def.device_option()),
31  prefetched_(false),
32  prefetch_success_(true),
33  finalize_(false),
34  no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
35  context_.SwitchToDevice();
36  }
37 
38  virtual ~PrefetchOperator() noexcept {
39  CHECK(finalize_ || !prefetch_thread_.get()) <<
40  "YOU MADE A PROGRAMING ERROR: derived class of PrefetchOperator "
41  "should call Finalize() in its destructor so the prefetching "
42  "thread is joined. ";
43  }
44 
45  void Finalize() {
46  if (prefetch_thread_.get()) {
47  {
48  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
49  while (!prefetched_)
50  consumer_.wait(lock);
51  finalize_ = true;
52  prefetched_ = false;
53  }
54  producer_.notify_one();
55  prefetch_thread_->join();
56  prefetch_thread_.reset();
57  } else {
58  // If we never initialized the prefetch thread, just set
59  // finalize anyway.
60  finalize_ = true;
61  }
62  }
63 
64  bool Run(int /* unused */ /*stream_id*/) override {
65  if (no_prefetch_) {
66  context_.SwitchToDevice();
67  bool result = Prefetch() && CopyPrefetched();
68  context_.FinishDeviceComputation();
69  return result;
70  }
71  // Note(jiayq): We only start the prefetch_thread at the Run() function
72  // instead of in the constructor, because the prefetch_thread needs to start
73  // after all derived classes' constructors finish.
74  if (!prefetch_thread_) {
75  prefetch_thread_.reset(
76  new std::thread([this] { this->PrefetchWorker(); }));
77  }
78  context_.SwitchToDevice();
79  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
80  while (!prefetched_)
81  consumer_.wait(lock);
82  if (!prefetch_success_) {
83  LOG(ERROR) << "Prefetching failed.";
84  return false;
85  }
86  if (!CopyPrefetched()) {
87  LOG(ERROR) << "Error when copying prefetched data.";
88  return false;
89  }
90  prefetched_ = false;
91  context_.FinishDeviceComputation();
92  producer_.notify_one();
93  return true;
94  }
95 
96  void PrefetchWorker() {
97  context_.SwitchToDevice();
98  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
99  while (prefetched_)
100  producer_.wait(lock);
101  while (!finalize_) {
102  // We will need to run a FinishDeviceComputation() call because the
103  // prefetcher thread and the main thread are potentially using different
104  // streams (like on GPU).
105  try {
106  prefetch_success_ = Prefetch();
107  context_.FinishDeviceComputation();
108  } catch (const std::exception& e) {
109  // TODO: propagate exception_ptr to the caller side
110  LOG(ERROR) << "Prefetching error " << e.what();
111  prefetch_success_ = false;
112  }
113  prefetched_ = true;
114  consumer_.notify_one();
115  while (prefetched_)
116  producer_.wait(lock);
117  }
118  }
119 
120  // You will need to implement this instead of the Run function.
121  virtual bool Prefetch() = 0;
122  virtual bool CopyPrefetched() = 0;
123 
124  protected:
125  Context context_;
126  std::mutex prefetch_access_mutex_;
127  std::condition_variable producer_, consumer_;
128  // prefetched_ is used to tell the operator that it is done.
129  std::atomic<bool> prefetched_;
130  // prefetch_success_ is used to see if prefetching failed or not.
131  std::atomic<bool> prefetch_success_;
132  // finalize_ is used to tell the prefetcher to quit.
133  std::atomic<bool> finalize_;
134  unique_ptr<std::thread> prefetch_thread_;
135 
136  // Whether to do prefetching or run this as a normal operator
137  const bool no_prefetch_;
138 };
139 
140 } // namespace caffe2
141 
142 #endif // CAFFE2_OPERATORS_PREFETCH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13