Caffe2 - C++ API
A deep learning, cross-platform ML framework
event.h
1 
17 #ifndef CAFFE2_CORE_EVENT_H_
18 #define CAFFE2_CORE_EVENT_H_
19 
20 #include "caffe2/core/common.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/proto/caffe2.pb.h"
23 
24 namespace caffe2 {
25 
26 constexpr int MaxDeviceTypes = DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
27 class Event;
28 
// States reported by Event::Query(). EVENT_SUCCESS and EVENT_FAILED are the
// two states Event::IsFinished() treats as "finished".
enum EventStatus {
  // Initialized; per Event::CanSchedule, children cannot be scheduled yet.
  EVENT_INITIALIZED = 0,
  // Scheduled; only async children of the same device type may be scheduled.
  EVENT_SCHEDULED = 1,
  // Finished successfully; any child may be scheduled.
  EVENT_SUCCESS = 2,
  // Finished with an error; children cannot be scheduled.
  EVENT_FAILED = 3
};
35 
36 // For the following functions, void* shall be interpreted as the corresponding
37 // context object corresponding to the device type associated with the
38 // functions.
39 
40 // Initializes event
41 typedef void (*EventCreateFunction)(const DeviceOption& option, Event*);
42 
43 // Called on event to signal that CPU part of operation is finished,
44 // Optionally accepts error message from CPU part.
45 // Should be called no more than once per event
46 typedef void (*EventRecordFunction)(Event*, const void*, const char*);
47 
48 // Waits and returns as soon as possible in order schedule next operation,
49 // e.g. for CUDA->CUDA waits only for CPU part of CUDA op,
50 // for CUDA->CPU waits till the CUDA op is fully completed.
51 // Prepares context to synchronize device part of operation.
52 // Can be called concurrently from multiple threads
53 typedef void (*EventWaitFunction)(const Event*, void*);
54 
55 // Waits till operation is fully finished,
56 // can be called concurrently from multiple threads
57 typedef void (*EventFinishFunction)(const Event*);
58 
59 // Queries current status of operation,
60 // can be called concurrently from multiple threads
61 typedef EventStatus (*EventQueryFunction)(const Event*);
62 typedef const std::string& (*EventErrorMessageFunction)(const Event*);
63 typedef void (*EventSetFinishedFunction)(const Event*, const char*);
64 typedef void (*EventResetFunction)(Event*);
65 
66 class Event {
67  public:
68  explicit Event(const DeviceOption& option)
69  : event_(), type_(option.device_type()), option_(option) {
70  CAFFE_ENFORCE_LT(type_, MaxDeviceTypes);
71  CAFFE_ENFORCE(event_creator_[type_]);
72  event_creator_[type_](option, this);
73  }
74 
75  // Nothing needs to be done in the destructor, as the event creator should
76  // set the proper destruction process for the unique_ptr.
77  ~Event() {}
78 
79  void Record(
80  int recorder_type,
81  const void* context,
82  const char* err_msg = nullptr) {
83  CAFFE_ENFORCE_EQ(
84  recorder_type,
85  type_,
86  "You are trying to record with a wrong device type.");
87  CAFFE_ENFORCE(event_recorder_[recorder_type]);
88  event_recorder_[recorder_type](this, context, err_msg);
89  }
90 
91  void Wait(int waiter_type, void* context) const {
92  CAFFE_ENFORCE(event_waiter_[waiter_type][type_]);
93  event_waiter_[waiter_type][type_](this, context);
94  }
95 
96  void Finish() const {
97  CAFFE_ENFORCE(event_finisher_[type_]);
98  event_finisher_[type_](this);
99  }
100 
101  EventStatus Query() const {
102  CAFFE_ENFORCE(event_querier_[type_]);
103  return event_querier_[type_](this);
104  }
105 
106  const std::string& ErrorMessage() const {
107  CAFFE_ENFORCE(event_err_msg_getter_[type_]);
108  return event_err_msg_getter_[type_](this);
109  }
110 
111  void Reset() {
112  CAFFE_ENFORCE(event_resetter_[type_]);
113  event_resetter_[type_](this);
114  }
115 
116  const DeviceOption& GetDeviceOption() const {
117  return option_;
118  }
119 
120  bool IsScheduled() const {
121  return Query() == EventStatus::EVENT_SCHEDULED;
122  }
123 
124  bool IsFinished() const {
125  auto status = Query();
126  return status == EventStatus::EVENT_SUCCESS ||
127  status == EventStatus::EVENT_FAILED;
128  }
129 
130  void SetFinished(const char* err_msg = nullptr) {
131  CAFFE_ENFORCE(event_finished_setter_[type_]);
132  return event_finished_setter_[type_](this, err_msg);
133  }
134 
135  // If parent op has succeeded, then we can run any child op;
136  // If parent op is in scheduled state, we need to check that:
137  // - child op supports async scheduling
138  // - there's a way to setup synchronization between async parent and
139  // child - both child and parent should use the same type of device,
140  // non-blocking synchronization between different device types is not
141  // supported
142  // If parent op is in another state (initialized or failed) then scheduling
143  // is not possible
144  bool CanSchedule(const Event& child_event, bool supports_async) const {
145  return CanSchedule(type_, Query(), child_event.GetType(), supports_async);
146  }
147 
148  static bool CanSchedule(
149  int parent_type,
150  EventStatus parent_status,
151  int child_type,
152  bool child_supports_async) {
153  if (parent_status == EventStatus::EVENT_SUCCESS) {
154  return true;
155  }
156  if (parent_status == EventStatus::EVENT_SCHEDULED) {
157  return (parent_type == child_type) && child_supports_async;
158  }
159  return false;
160  }
161 
162  int GetType() const {
163  return type_;
164  }
165 
166  // event_ is going to be accessed by the EventCreate/Record/Wait/Finish
167  // functions, but one should not use it outside the own Event functionalities.
168  // In the future we may move it to a private member.
169  std::shared_ptr<void> event_;
170 
171  private:
172  int type_;
173  DeviceOption option_;
174 
175  CAFFE2_API static EventCreateFunction event_creator_[MaxDeviceTypes];
176  CAFFE2_API static EventRecordFunction event_recorder_[MaxDeviceTypes];
177  CAFFE2_API static EventWaitFunction event_waiter_[MaxDeviceTypes]
178  [MaxDeviceTypes];
179  CAFFE2_API static EventFinishFunction event_finisher_[MaxDeviceTypes];
180 
181  CAFFE2_API static EventQueryFunction event_querier_[MaxDeviceTypes];
182  CAFFE2_API static EventErrorMessageFunction
183  event_err_msg_getter_[MaxDeviceTypes];
184  CAFFE2_API static EventSetFinishedFunction
185  event_finished_setter_[MaxDeviceTypes];
186  CAFFE2_API static EventResetFunction event_resetter_[MaxDeviceTypes];
187 
188  template <int d>
189  friend struct EventCreateFunctionRegisterer;
190  template <int d>
191  friend struct EventRecordFunctionRegisterer;
192  template <int w, int d>
193  friend struct EventWaitFunctionRegisterer;
194  template <int d>
195  friend struct EventFinishFunctionRegisterer;
196 
197  template <int d>
198  friend struct EventQueryFunctionRegisterer;
199  template <int d>
201  template <int d>
203  template <int d>
204  friend struct EventResetFunctionRegisterer;
205 };
206 
207 template <int d>
209  explicit EventCreateFunctionRegisterer(EventCreateFunction f) {
210  static_assert(d < MaxDeviceTypes, "");
211  Event::event_creator_[d] = f;
212  }
213 };
214 #define REGISTER_EVENT_CREATE_FUNCTION(d, f) \
215  namespace { \
216  static EventCreateFunctionRegisterer<d> g_event_create_##d(f); \
217  }
218 
219 template <int d>
221  explicit EventRecordFunctionRegisterer(EventRecordFunction f) {
222  static_assert(d < MaxDeviceTypes, "");
223  Event::event_recorder_[d] = f;
224  }
225 };
226 #define REGISTER_EVENT_RECORD_FUNCTION(d, f) \
227  namespace { \
228  static EventRecordFunctionRegisterer<d> g_event_record_##d(f); \
229  }
230 
231 template <int waiter_type, int event_type>
233  explicit EventWaitFunctionRegisterer(EventWaitFunction f) {
234  static_assert(waiter_type < MaxDeviceTypes, "");
235  static_assert(event_type < MaxDeviceTypes, "");
236  Event::event_waiter_[waiter_type][event_type] = f;
237  }
238 };
239 #define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \
240  namespace { \
241  static EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \
242  }
243 
244 template <int d>
246  explicit EventQueryFunctionRegisterer(EventQueryFunction f) {
247  static_assert(d < MaxDeviceTypes, "");
248  Event::event_querier_[d] = f;
249  }
250 };
251 #define REGISTER_EVENT_QUERY_FUNCTION(d, f) \
252  namespace { \
253  static EventQueryFunctionRegisterer<d> g_event_query_##d(f); \
254  }
255 
256 template <int d>
258  explicit EventErrorMessageFunctionRegisterer(EventErrorMessageFunction f) {
259  static_assert(d < MaxDeviceTypes, "");
260  Event::event_err_msg_getter_[d] = f;
261  }
262 };
263 #define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(d, f) \
264  namespace { \
265  static EventErrorMessageFunctionRegisterer<d> g_event_err_msg_##d(f); \
266  }
267 
268 template <int d>
270  explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction f) {
271  static_assert(d < MaxDeviceTypes, "");
272  Event::event_finished_setter_[d] = f;
273  }
274 };
275 #define REGISTER_EVENT_SET_FINISHED_FUNCTION(d, f) \
276  namespace { \
277  static EventSetFinishedFunctionRegisterer<d> g_event_set_finished_##d(f); \
278  }
279 
280 template <int d>
282  explicit EventFinishFunctionRegisterer(EventFinishFunction f) {
283  static_assert(d < MaxDeviceTypes, "");
284  Event::event_finisher_[d] = f;
285  }
286 };
287 #define REGISTER_EVENT_FINISH_FUNCTION(d, f) \
288  namespace { \
289  static EventFinishFunctionRegisterer<d> g_event_finish_##d(f); \
290  }
291 
292 template <int d>
294  explicit EventResetFunctionRegisterer(EventResetFunction f) {
295  static_assert(d < MaxDeviceTypes, "");
296  Event::event_resetter_[d] = f;
297  }
298 };
299 #define REGISTER_EVENT_RESET_FUNCTION(d, f) \
300  namespace { \
301  static EventResetFunctionRegisterer<d> g_event_reset_##d(f); \
302  }
303 
304 } // namespace caffe2
305 
306 #endif // CAFFE2_CORE_EVENT_H_
Copyright (c) 2016-present, Facebook, Inc.