1 #ifndef CAFFE2_CORE_EVENT_H_ 2 #define CAFFE2_CORE_EVENT_H_ 6 #include <c10/core/DeviceType.h> 7 #include "caffe2/core/common.h" 8 #include "caffe2/core/logging.h" 9 #include "caffe2/proto/caffe2_pb.h" 13 constexpr
int MaxDeviceTypes =
14 DeviceTypeProto::PROTO_COMPILE_TIME_MAX_DEVICE_TYPES;
18 EVENT_INITIALIZED = 0,
29 typedef void (*EventCreateFunction)(
const DeviceOption& option, Event*);
34 typedef void (*EventRecordFunction)(Event*,
const void*,
const char*);
41 typedef void (*EventWaitFunction)(
const Event*,
void*);
45 typedef void (*EventFinishFunction)(
const Event*);
49 typedef EventStatus (*EventQueryFunction)(
const Event*);
50 typedef const std::string& (*EventErrorMessageFunction)(
const Event*);
51 typedef void (*EventSetFinishedFunction)(
const Event*,
const char*);
52 typedef void (*EventResetFunction)(Event*);
55 typedef std::function<void()> EventCallbackFunction;
56 typedef void (*EventSetCallbackFunction)(Event*, EventCallbackFunction);
58 class CAFFE2_API Event {
// Constructs an event for the device type named by `option`.
// Stores option.device_type() in type_ and a copy of `option` in option_,
// enforces that type_ is below MaxDeviceTypes and that a creation hook is
// registered for it, then calls that hook with (option, this).
60 explicit Event(
const DeviceOption& option)
61 : event_(), type_(option.device_type()), option_(option) {
// Bounds check comes first: type_ indexes event_creator_ on the next lines.
62 CAFFE_ENFORCE_LT(type_, MaxDeviceTypes);
63 CAFFE_ENFORCE(event_creator_[type_]);
64 event_creator_[type_](option,
this);
72 DeviceType recorder_type,
74 const char* err_msg =
nullptr) {
75 auto recorder_index = TypeToProto(recorder_type);
79 "You are trying to record with a wrong device type.");
80 CAFFE_ENFORCE(event_recorder_[recorder_index]);
81 event_recorder_[recorder_index](
this, context, err_msg);
// Dispatches to the wait hook registered for the pair
// (waiter_type, this event's type_). Enforces that such a hook is set, then
// calls it with (this, context). `context` is an opaque pointer handed
// through to the hook unchanged.
84 void Wait(DeviceType waiter_type,
void* context)
const {
// Convert the DeviceType to the index used by the hook tables.
85 auto waiter_index = TypeToProto(waiter_type);
86 CAFFE_ENFORCE(event_waiter_[waiter_index][type_]);
87 event_waiter_[waiter_index][type_](
this, context);
91 CAFFE_ENFORCE(event_finisher_[type_]);
92 event_finisher_[type_](
this);
// Returns the EventStatus reported by the query hook registered for this
// event's type_. Enforces that the hook is set before calling it.
95 EventStatus Query()
const {
96 CAFFE_ENFORCE(event_querier_[type_]);
97 return event_querier_[type_](
this);
// Returns the string reference produced by the error-message hook registered
// for this event's type_. Enforces that the hook is set before calling it.
// NOTE(review): the lifetime of the referenced string is determined by the
// hook implementation, which is not visible here.
100 const std::string& ErrorMessage()
const {
101 CAFFE_ENFORCE(event_err_msg_getter_[type_]);
102 return event_err_msg_getter_[type_](
this);
106 CAFFE_ENFORCE(event_resetter_[type_]);
107 event_resetter_[type_](
this);
108 #ifdef CAFFE2_USE_EXCEPTION_PTR 109 caught_exception_ =
nullptr;
110 exception_timestamp_ = 0;
111 #endif // CAFFE2_USE_EXCEPTION_PTR 114 const DeviceOption& GetDeviceOption()
const {
// True when Query() currently reports EventStatus::EVENT_SCHEDULED.
118 bool IsScheduled()
const {
119 return Query() == EventStatus::EVENT_SCHEDULED;
// True when the event has reached either terminal status checked below:
// EVENT_SUCCESS or EVENT_FAILED.
122 bool IsFinished()
const {
// Query once and reuse the result so both comparisons see the same status.
123 auto status = Query();
124 return status == EventStatus::EVENT_SUCCESS ||
125 status == EventStatus::EVENT_FAILED;
// Enforces that a set-finished hook is registered for this event's type_ and
// forwards (this, err_msg) to it. `err_msg` defaults to nullptr.
128 void SetFinished(
const char* err_msg =
nullptr) {
129 CAFFE_ENFORCE(event_finished_setter_[type_]);
// The hook returns void; `return` here just forwards that void expression.
130 return event_finished_setter_[type_](
this, err_msg);
// True when a callback-setter hook is registered (non-null) for this event's
// type_. Unlike the other accessors shown here, this does not enforce anything.
133 bool SupportsCallback()
const {
134 return event_callback_setter_[type_] !=
nullptr;
137 void SetCallback(EventCallbackFunction callback) {
139 event_callback_setter_[type_],
"Event does not support callbacks");
140 event_callback_setter_[type_](
this, callback);
// Convenience overload treating this event as the parent: forwards this
// event's type_, its current Query() status, the child's GetType() and
// `supports_async` to the four-argument CanSchedule overload and returns
// that result.
152 bool CanSchedule(
const Event& child_event,
bool supports_async)
const {
153 return CanSchedule(type_, Query(), child_event.GetType(), supports_async);
156 static bool CanSchedule(
158 EventStatus parent_status,
160 bool child_supports_async) {
161 if (parent_status == EventStatus::EVENT_SUCCESS) {
164 if (parent_status == EventStatus::EVENT_SCHEDULED) {
165 return (parent_type == child_type) && child_supports_async;
170 int GetType()
const {
174 void SetFinishedWithException(
const char* err_msg =
nullptr) {
175 #ifdef CAFFE2_USE_EXCEPTION_PTR 176 if (!caught_exception_) {
177 caught_exception_ = std::current_exception();
178 typedef std::chrono::high_resolution_clock clock;
179 exception_timestamp_ =
180 clock::now().time_since_epoch() / std::chrono::milliseconds(1);
182 CAFFE_ENFORCE(caught_exception_,
"No exception found");
184 VLOG(1) <<
"No support for exceptions in Event";
185 #endif // CAFFE2_USE_EXCEPTION_PTR 187 SetFinished(err_msg);
189 SetFinished(
"Error happened during an operator run");
193 bool HasException()
const {
194 #ifdef CAFFE2_USE_EXCEPTION_PTR 195 return (
bool)caught_exception_;
197 VLOG(1) <<
"No support for exceptions in Event";
199 #endif // CAFFE2_USE_EXCEPTION_PTR 202 int64_t ExceptionTimestamp()
const {
203 #ifdef CAFFE2_USE_EXCEPTION_PTR 204 return exception_timestamp_;
206 VLOG(1) <<
"No support for exceptions in Event";
208 #endif // CAFFE2_USE_EXCEPTION_PTR 211 void RethrowException()
const {
212 #ifdef CAFFE2_USE_EXCEPTION_PTR 213 if (caught_exception_) {
214 std::rethrow_exception(caught_exception_);
217 VLOG(1) <<
"No support for exceptions in Event";
218 #endif // CAFFE2_USE_EXCEPTION_PTR 224 std::shared_ptr<void> event_;
228 DeviceOption option_;
230 #ifdef CAFFE2_USE_EXCEPTION_PTR 231 std::exception_ptr caught_exception_;
232 int64_t exception_timestamp_{};
233 #endif // CAFFE2_USE_EXCEPTION_PTR 235 static EventCreateFunction event_creator_[MaxDeviceTypes];
236 static EventRecordFunction event_recorder_[MaxDeviceTypes];
237 static EventWaitFunction event_waiter_[MaxDeviceTypes]
239 static EventFinishFunction event_finisher_[MaxDeviceTypes];
241 static EventQueryFunction event_querier_[MaxDeviceTypes];
242 static EventErrorMessageFunction
243 event_err_msg_getter_[MaxDeviceTypes];
244 static EventSetFinishedFunction
245 event_finished_setter_[MaxDeviceTypes];
246 static EventResetFunction event_resetter_[MaxDeviceTypes];
248 static EventSetCallbackFunction event_callback_setter_[MaxDeviceTypes];
250 template <DeviceType t>
252 template <DeviceType t>
254 template <DeviceType w, DeviceType d>
256 template <DeviceType t>
259 template <DeviceType t>
261 template <DeviceType t>
263 template <DeviceType t>
265 template <DeviceType t>
267 template <DeviceType t>
271 template <DeviceType t>
274 auto d = TypeToProto(t);
275 Event::event_creator_[d] = f;
278 #define REGISTER_EVENT_CREATE_FUNCTION(t, f) \ 280 static EventCreateFunctionRegisterer<t> g_event_create_##d(f); \ 283 template <DeviceType t>
286 auto d = TypeToProto(t);
287 Event::event_recorder_[d] = f;
290 #define REGISTER_EVENT_RECORD_FUNCTION(t, f) \ 292 static EventRecordFunctionRegisterer<t> g_event_record_##d(f); \ 295 template <DeviceType waiter_type, DeviceType event_type>
298 auto waiter_index = TypeToProto(waiter_type);
299 auto event_index = TypeToProto(event_type);
300 Event::event_waiter_[waiter_index][event_index] = f;
303 #define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \ 305 static EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \ 308 template <DeviceType t>
311 auto d = TypeToProto(t);
312 Event::event_querier_[d] = f;
315 #define REGISTER_EVENT_QUERY_FUNCTION(t, f) \ 317 static EventQueryFunctionRegisterer<t> g_event_query_##d(f); \ 320 template <DeviceType t>
323 auto d = TypeToProto(t);
324 Event::event_err_msg_getter_[d] = f;
327 #define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(t, f) \ 329 static EventErrorMessageFunctionRegisterer<t> g_event_err_msg_##d(f); \ 332 template <DeviceType t>
335 auto d = TypeToProto(t);
336 Event::event_finished_setter_[d] = f;
339 #define REGISTER_EVENT_SET_FINISHED_FUNCTION(t, f) \ 341 static EventSetFinishedFunctionRegisterer<t> g_event_set_finished_##d(f); \ 344 template <DeviceType t>
347 auto d = TypeToProto(t);
348 Event::event_callback_setter_[d] = f;
351 #define REGISTER_EVENT_SET_CALLBACK_FUNCTION(t, f) \ 353 static EventSetCallbackFunctionRegisterer<t> g_event_set_callback_##d(f); \ 356 template <DeviceType t>
359 auto d = TypeToProto(t);
360 Event::event_finisher_[d] = f;
363 #define REGISTER_EVENT_FINISH_FUNCTION(t, f) \ 365 static EventFinishFunctionRegisterer<t> g_event_finish_##d(f); \ 368 template <DeviceType t>
371 auto d = TypeToProto(t);
372 Event::event_resetter_[d] = f;
375 #define REGISTER_EVENT_RESET_FUNCTION(t, f) \ 377 static EventResetFunctionRegisterer<t> g_event_reset_##d(f); \ 382 #endif // CAFFE2_CORE_EVENT_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...