Caffe2 - C++ API
A deep learning, cross platform ML framework
prefetch_op.h
1 
17 #ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_
18 #define CAFFE2_OPERATORS_PREFETCH_OP_H_
19 
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread> // NOLINT

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
26 
27 namespace caffe2 {
28 
// PrefetchOperator is an operator that prefetches the next batch. It is almost
// always used to read data from disk, so it is designed to take zero input
// blobs.
//
// Any operator derived from PrefetchOperator must explicitly call Finalize()
// in its destructor, so that the prefetching thread is stopped and joined.
36 
// Note: We inherit from OperatorBase because we control the synchronization
// properties of this operator ourselves (we notify the waiting producer only
// after we have synchronized). This is a special case; in general, you should
// inherit from Operator<Context> directly.
41 template <class Context>
43  public:
44  PrefetchOperator(const OperatorDef& operator_def, Workspace* ws)
45  : OperatorBase(operator_def, ws),
46  context_(operator_def.device_option()),
47  prefetched_(false),
48  prefetch_success_(true),
49  finalize_(false),
50  no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
51  context_.SwitchToDevice(0);
52  }
53 
54  virtual ~PrefetchOperator() noexcept {
55  CHECK(finalize_ || !prefetch_thread_.get()) <<
56  "YOU MADE A PROGRAMING ERROR: derived class of PrefetchOperator "
57  "should call Finalize() in its destructor so the prefetching "
58  "thread is joined. ";
59  }
60 
61  void Finalize() {
62  if (prefetch_thread_.get()) {
63  {
64  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
65  while (!prefetched_)
66  consumer_.wait(lock);
67  finalize_ = true;
68  prefetched_ = false;
69  }
70  producer_.notify_one();
71  prefetch_thread_->join();
72  prefetch_thread_.reset();
73  } else {
74  // If we never initialized the prefetch thread, just set
75  // finalize anyway.
76  finalize_ = true;
77  }
78  }
79 
80  bool Run(int /* unused */ /*stream_id*/) override {
81  if (no_prefetch_) {
82  context_.SwitchToDevice(0);
83  bool result = Prefetch() && CopyPrefetched();
84  context_.FinishDeviceComputation();
85  return result;
86  }
87  // Note(jiayq): We only start the prefetch_thread at the Run() function
88  // instead of in the constructor, because the prefetch_thread needs to start
89  // after all derived classes' constructors finish.
90  if (!prefetch_thread_) {
91  prefetch_thread_.reset(
92  new std::thread([this] { this->PrefetchWorker(); }));
93  }
94  context_.SwitchToDevice(0);
95  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
96  while (!prefetched_)
97  consumer_.wait(lock);
98  if (!prefetch_success_) {
99  LOG(ERROR) << "Prefetching failed.";
100  return false;
101  }
102  if (!CopyPrefetched()) {
103  LOG(ERROR) << "Error when copying prefetched data.";
104  return false;
105  }
106  prefetched_ = false;
107  context_.FinishDeviceComputation();
108  producer_.notify_one();
109  return true;
110  }
111 
112  void PrefetchWorker() {
113  context_.SwitchToDevice();
114  std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
115  while (prefetched_)
116  producer_.wait(lock);
117  while (!finalize_) {
118  // We will need to run a FinishDeviceComputation() call because the
119  // prefetcher thread and the main thread are potentially using different
120  // streams (like on GPU).
121  try {
122  prefetch_success_ = Prefetch();
123  context_.FinishDeviceComputation();
124  } catch (const std::exception& e) {
125  // TODO: propagate exception_ptr to the caller side
126  LOG(ERROR) << "Prefetching error " << e.what();
127  prefetch_success_ = false;
128  }
129  prefetched_ = true;
130  consumer_.notify_one();
131  while (prefetched_)
132  producer_.wait(lock);
133  }
134  }
135 
136  // You will need to implement this instead of the Run function.
137  virtual bool Prefetch() = 0;
138  virtual bool CopyPrefetched() = 0;
139 
140  protected:
141  Context context_;
142  std::mutex prefetch_access_mutex_;
143  std::condition_variable producer_, consumer_;
144  // prefetched_ is used to tell the operator that it is done.
145  std::atomic<bool> prefetched_;
146  // prefetch_success_ is used to see if prefetching failed or not.
147  std::atomic<bool> prefetch_success_;
148  // finalize_ is used to tell the prefetcher to quit.
149  std::atomic<bool> finalize_;
150  unique_ptr<std::thread> prefetch_thread_;
151 
152  // Whether to do prefetching or run this as a normal operator
153  const bool no_prefetch_;
154 };
155 
156 } // namespace caffe2
157 
158 #endif // CAFFE2_OPERATORS_PREFETCH_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.