Caffe2 - C++ API
A deep learning, cross platform ML framework
event.h
1 #ifndef CAFFE2_CORE_EVENT_H_
2 #define CAFFE2_CORE_EVENT_H_
3 
4 #include <chrono>
5 
6 #include <c10/core/DeviceType.h>
7 #include "caffe2/core/common.h"
8 #include "caffe2/core/logging.h"
9 #include "caffe2/proto/caffe2_pb.h"
10 
11 namespace caffe2 {
12 
13 constexpr int MaxDeviceTypes =
14  DeviceTypeProto::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
15 class Event;
16 
// Status values reported by Event::Query().
enum EventStatus {
  // Neither scheduled nor finished; Event::CanSchedule() rejects a parent in
  // this state.
  EVENT_INITIALIZED = 0,
  // Reported by Event::IsScheduled(); a child can be scheduled after such a
  // parent only under the conditions checked by Event::CanSchedule().
  EVENT_SCHEDULED = 1,
  // Finished without error; counted as finished by Event::IsFinished().
  EVENT_SUCCESS = 2,
  // Finished with an error; also counted as finished by Event::IsFinished().
  EVENT_FAILED = 3,
};
23 
24 // For the following functions, void* shall be interpreted as the corresponding
25 // context object corresponding to the device type associated with the
26 // functions.
27 
28 // Initializes event
29 typedef void (*EventCreateFunction)(const DeviceOption& option, Event*);
30 
31 // Called on event to signal that CPU part of operation is finished,
32 // Optionally accepts error message from CPU part.
33 // Should be called no more than once per event
34 typedef void (*EventRecordFunction)(Event*, const void*, const char*);
35 
36 // Waits and returns as soon as possible in order schedule next operation,
37 // e.g. for CUDA->CUDA waits only for CPU part of CUDA op,
38 // for CUDA->CPU waits till the CUDA op is fully completed.
39 // Prepares context to synchronize device part of operation.
40 // Can be called concurrently from multiple threads
41 typedef void (*EventWaitFunction)(const Event*, void*);
42 
43 // Waits till operation is fully finished,
44 // can be called concurrently from multiple threads
45 typedef void (*EventFinishFunction)(const Event*);
46 
47 // Queries current status of operation,
48 // can be called concurrently from multiple threads
49 typedef EventStatus (*EventQueryFunction)(const Event*);
50 typedef const std::string& (*EventErrorMessageFunction)(const Event*);
51 typedef void (*EventSetFinishedFunction)(const Event*, const char*);
52 typedef void (*EventResetFunction)(Event*);
53 
54 // Sets callback that is called when event is finished
55 typedef std::function<void()> EventCallbackFunction;
56 typedef void (*EventSetCallbackFunction)(Event*, EventCallbackFunction);
57 
58 class CAFFE2_API Event {
59  public:
60  explicit Event(const DeviceOption& option)
61  : event_(), type_(option.device_type()), option_(option) {
62  CAFFE_ENFORCE_LT(type_, MaxDeviceTypes);
63  CAFFE_ENFORCE(event_creator_[type_]);
64  event_creator_[type_](option, this);
65  }
66 
67  // Nothing needs to be done in the destructor, as the event creator should
68  // set the proper destruction process for the unique_ptr.
69  ~Event() {}
70 
71  void Record(
72  DeviceType recorder_type,
73  const void* context,
74  const char* err_msg = nullptr) {
75  auto recorder_index = TypeToProto(recorder_type);
76  CAFFE_ENFORCE_EQ(
77  recorder_index,
78  type_,
79  "You are trying to record with a wrong device type.");
80  CAFFE_ENFORCE(event_recorder_[recorder_index]);
81  event_recorder_[recorder_index](this, context, err_msg);
82  }
83 
84  void Wait(DeviceType waiter_type, void* context) const {
85  auto waiter_index = TypeToProto(waiter_type);
86  CAFFE_ENFORCE(event_waiter_[waiter_index][type_]);
87  event_waiter_[waiter_index][type_](this, context);
88  }
89 
90  void Finish() const {
91  CAFFE_ENFORCE(event_finisher_[type_]);
92  event_finisher_[type_](this);
93  }
94 
95  EventStatus Query() const {
96  CAFFE_ENFORCE(event_querier_[type_]);
97  return event_querier_[type_](this);
98  }
99 
100  const std::string& ErrorMessage() const {
101  CAFFE_ENFORCE(event_err_msg_getter_[type_]);
102  return event_err_msg_getter_[type_](this);
103  }
104 
105  void Reset() {
106  CAFFE_ENFORCE(event_resetter_[type_]);
107  event_resetter_[type_](this);
108 #ifdef CAFFE2_USE_EXCEPTION_PTR
109  caught_exception_ = nullptr;
110  exception_timestamp_ = 0;
111 #endif // CAFFE2_USE_EXCEPTION_PTR
112  }
113 
114  const DeviceOption& GetDeviceOption() const {
115  return option_;
116  }
117 
118  bool IsScheduled() const {
119  return Query() == EventStatus::EVENT_SCHEDULED;
120  }
121 
122  bool IsFinished() const {
123  auto status = Query();
124  return status == EventStatus::EVENT_SUCCESS ||
125  status == EventStatus::EVENT_FAILED;
126  }
127 
128  void SetFinished(const char* err_msg = nullptr) {
129  CAFFE_ENFORCE(event_finished_setter_[type_]);
130  return event_finished_setter_[type_](this, err_msg);
131  }
132 
133  bool SupportsCallback() const {
134  return event_callback_setter_[type_] != nullptr;
135  }
136 
137  void SetCallback(EventCallbackFunction callback) {
138  CAFFE_ENFORCE(
139  event_callback_setter_[type_], "Event does not support callbacks");
140  event_callback_setter_[type_](this, callback);
141  }
142 
143  // If parent op has succeeded, then we can run any child op;
144  // If parent op is in scheduled state, we need to check that:
145  // - child op supports async scheduling
146  // - there's a way to setup synchronization between async parent and
147  // child - both child and parent should use the same type of device,
148  // non-blocking synchronization between different device types is not
149  // supported
150  // If parent op is in another state (initialized or failed) then scheduling
151  // is not possible
152  bool CanSchedule(const Event& child_event, bool supports_async) const {
153  return CanSchedule(type_, Query(), child_event.GetType(), supports_async);
154  }
155 
156  static bool CanSchedule(
157  int parent_type,
158  EventStatus parent_status,
159  int child_type,
160  bool child_supports_async) {
161  if (parent_status == EventStatus::EVENT_SUCCESS) {
162  return true;
163  }
164  if (parent_status == EventStatus::EVENT_SCHEDULED) {
165  return (parent_type == child_type) && child_supports_async;
166  }
167  return false;
168  }
169 
170  int GetType() const {
171  return type_;
172  }
173 
174  void SetFinishedWithException(const char* err_msg = nullptr) {
175 #ifdef CAFFE2_USE_EXCEPTION_PTR
176  if (!caught_exception_) {
177  caught_exception_ = std::current_exception();
178  typedef std::chrono::high_resolution_clock clock;
179  exception_timestamp_ =
180  clock::now().time_since_epoch() / std::chrono::milliseconds(1);
181  }
182  CAFFE_ENFORCE(caught_exception_, "No exception found");
183 #else
184  VLOG(1) << "No support for exceptions in Event";
185 #endif // CAFFE2_USE_EXCEPTION_PTR
186  if (err_msg) {
187  SetFinished(err_msg);
188  } else {
189  SetFinished("Error happened during an operator run");
190  }
191  }
192 
193  bool HasException() const {
194 #ifdef CAFFE2_USE_EXCEPTION_PTR
195  return (bool)caught_exception_;
196 #else
197  VLOG(1) << "No support for exceptions in Event";
198  return false;
199 #endif // CAFFE2_USE_EXCEPTION_PTR
200  }
201 
202  int64_t ExceptionTimestamp() const {
203 #ifdef CAFFE2_USE_EXCEPTION_PTR
204  return exception_timestamp_;
205 #else
206  VLOG(1) << "No support for exceptions in Event";
207  return 0;
208 #endif // CAFFE2_USE_EXCEPTION_PTR
209  }
210 
211  void RethrowException() const {
212 #ifdef CAFFE2_USE_EXCEPTION_PTR
213  if (caught_exception_) {
214  std::rethrow_exception(caught_exception_);
215  }
216 #else
217  VLOG(1) << "No support for exceptions in Event";
218 #endif // CAFFE2_USE_EXCEPTION_PTR
219  }
220 
221  // event_ is going to be accessed by the EventCreate/Record/Wait/Finish
222  // functions, but one should not use it outside the own Event functionalities.
223  // In the future we may move it to a private member.
224  std::shared_ptr<void> event_;
225 
226  private:
227  int type_;
228  DeviceOption option_;
229 
230 #ifdef CAFFE2_USE_EXCEPTION_PTR
231  std::exception_ptr caught_exception_;
232  int64_t exception_timestamp_{};
233 #endif // CAFFE2_USE_EXCEPTION_PTR
234 
235  static EventCreateFunction event_creator_[MaxDeviceTypes];
236  static EventRecordFunction event_recorder_[MaxDeviceTypes];
237  static EventWaitFunction event_waiter_[MaxDeviceTypes]
238  [MaxDeviceTypes];
239  static EventFinishFunction event_finisher_[MaxDeviceTypes];
240 
241  static EventQueryFunction event_querier_[MaxDeviceTypes];
242  static EventErrorMessageFunction
243  event_err_msg_getter_[MaxDeviceTypes];
244  static EventSetFinishedFunction
245  event_finished_setter_[MaxDeviceTypes];
246  static EventResetFunction event_resetter_[MaxDeviceTypes];
247 
248  static EventSetCallbackFunction event_callback_setter_[MaxDeviceTypes];
249 
250  template <DeviceType t>
251  friend struct EventCreateFunctionRegisterer;
252  template <DeviceType t>
253  friend struct EventRecordFunctionRegisterer;
254  template <DeviceType w, DeviceType d>
255  friend struct EventWaitFunctionRegisterer;
256  template <DeviceType t>
257  friend struct EventFinishFunctionRegisterer;
258 
259  template <DeviceType t>
260  friend struct EventQueryFunctionRegisterer;
261  template <DeviceType t>
263  template <DeviceType t>
265  template <DeviceType t>
267  template <DeviceType t>
268  friend struct EventResetFunctionRegisterer;
269 };
270 
271 template <DeviceType t>
273  explicit EventCreateFunctionRegisterer(EventCreateFunction f) {
274  auto d = TypeToProto(t);
275  Event::event_creator_[d] = f;
276  }
277 };
278 #define REGISTER_EVENT_CREATE_FUNCTION(t, f) \
279  namespace { \
280  static EventCreateFunctionRegisterer<t> g_event_create_##d(f); \
281  }
282 
283 template <DeviceType t>
285  explicit EventRecordFunctionRegisterer(EventRecordFunction f) {
286  auto d = TypeToProto(t);
287  Event::event_recorder_[d] = f;
288  }
289 };
290 #define REGISTER_EVENT_RECORD_FUNCTION(t, f) \
291  namespace { \
292  static EventRecordFunctionRegisterer<t> g_event_record_##d(f); \
293  }
294 
295 template <DeviceType waiter_type, DeviceType event_type>
297  explicit EventWaitFunctionRegisterer(EventWaitFunction f) {
298  auto waiter_index = TypeToProto(waiter_type);
299  auto event_index = TypeToProto(event_type);
300  Event::event_waiter_[waiter_index][event_index] = f;
301  }
302 };
303 #define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \
304  namespace { \
305  static EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \
306  }
307 
308 template <DeviceType t>
310  explicit EventQueryFunctionRegisterer(EventQueryFunction f) {
311  auto d = TypeToProto(t);
312  Event::event_querier_[d] = f;
313  }
314 };
315 #define REGISTER_EVENT_QUERY_FUNCTION(t, f) \
316  namespace { \
317  static EventQueryFunctionRegisterer<t> g_event_query_##d(f); \
318  }
319 
320 template <DeviceType t>
322  explicit EventErrorMessageFunctionRegisterer(EventErrorMessageFunction f) {
323  auto d = TypeToProto(t);
324  Event::event_err_msg_getter_[d] = f;
325  }
326 };
327 #define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(t, f) \
328  namespace { \
329  static EventErrorMessageFunctionRegisterer<t> g_event_err_msg_##d(f); \
330  }
331 
332 template <DeviceType t>
334  explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction f) {
335  auto d = TypeToProto(t);
336  Event::event_finished_setter_[d] = f;
337  }
338 };
339 #define REGISTER_EVENT_SET_FINISHED_FUNCTION(t, f) \
340  namespace { \
341  static EventSetFinishedFunctionRegisterer<t> g_event_set_finished_##d(f); \
342  }
343 
344 template <DeviceType t>
346  explicit EventSetCallbackFunctionRegisterer(EventSetCallbackFunction f) {
347  auto d = TypeToProto(t);
348  Event::event_callback_setter_[d] = f;
349  }
350 };
351 #define REGISTER_EVENT_SET_CALLBACK_FUNCTION(t, f) \
352  namespace { \
353  static EventSetCallbackFunctionRegisterer<t> g_event_set_callback_##d(f); \
354  }
355 
356 template <DeviceType t>
358  explicit EventFinishFunctionRegisterer(EventFinishFunction f) {
359  auto d = TypeToProto(t);
360  Event::event_finisher_[d] = f;
361  }
362 };
363 #define REGISTER_EVENT_FINISH_FUNCTION(t, f) \
364  namespace { \
365  static EventFinishFunctionRegisterer<t> g_event_finish_##d(f); \
366  }
367 
368 template <DeviceType t>
370  explicit EventResetFunctionRegisterer(EventResetFunction f) {
371  auto d = TypeToProto(t);
372  Event::event_resetter_[d] = f;
373  }
374 };
375 #define REGISTER_EVENT_RESET_FUNCTION(t, f) \
376  namespace { \
377  static EventResetFunctionRegisterer<t> g_event_reset_##d(f); \
378  }
379 
380 } // namespace caffe2
381 
382 #endif // CAFFE2_CORE_EVENT_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13