| #ifndef CAFFE2_CORE_EVENT_H_ |
| #define CAFFE2_CORE_EVENT_H_ |
| |
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2.pb.h"
| |
| namespace caffe2 { |
| |
| constexpr int MaxDeviceTypes = DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES; |
| class Event; |
| |
// Status values reported by an event's query function.
// EVENT_SUCCESS and EVENT_FAILED are the two states that Event::IsFinished()
// treats as "finished"; EVENT_SCHEDULED is the state tested by
// Event::IsScheduled().
enum EventStatus {
  EVENT_INITIALIZED = 0,
  EVENT_SCHEDULED, // 1
  EVENT_SUCCESS, // 2
  EVENT_FAILED, // 3
};
| |
| // For the following functions, void* shall be interpreted as the corresponding |
| // context object corresponding to the device type associated with the |
| // functions. |
| |
| // Initializes event |
| typedef void (*EventCreateFunction)(const DeviceOption& option, Event*); |
| |
| // Called on event to signal that CPU part of operation is finished, |
| // Optionally accepts error message from CPU part. |
| // Should be called no more than once per event |
| typedef void (*EventRecordFunction)(Event*, const void*, const char*); |
| |
| // Waits and returns as soon as possible in order schedule next operation, |
| // e.g. for CUDA->CUDA waits only for CPU part of CUDA op, |
| // for CUDA->CPU waits till the CUDA op is fully completed. |
| // Prepares context to synchronize device part of operation. |
| // Can be called concurrently from multiple threads |
| typedef void (*EventWaitFunction)(const Event*, void*); |
| |
| // Waits till operation is fully finished, |
| // can be called concurrently from multiple threads |
| typedef void (*EventFinishFunction)(const Event*); |
| |
| // Queries current status of operation, |
| // can be called concurrently from multiple threads |
| typedef EventStatus (*EventQueryFunction)(const Event*); |
| typedef const std::string& (*EventErrorMessageFunction)(const Event*); |
| typedef void (*EventSetFinishedFunction)(const Event*, const char*); |
| typedef void (*EventResetFunction)(Event*); |
| |
| // Sets callback that is called when event is finished |
| typedef std::function<void()> EventCallbackFunction; |
| typedef void (*EventSetCallbackFunction)(Event*, EventCallbackFunction); |
| |
| class Event { |
| public: |
| explicit Event(const DeviceOption& option) |
| : event_(), type_(option.device_type()), option_(option) { |
| CAFFE_ENFORCE_LT(type_, MaxDeviceTypes); |
| CAFFE_ENFORCE(event_creator_[type_]); |
| event_creator_[type_](option, this); |
| } |
| |
| // Nothing needs to be done in the destructor, as the event creator should |
| // set the proper destruction process for the unique_ptr. |
| ~Event() {} |
| |
| void Record( |
| int recorder_type, |
| const void* context, |
| const char* err_msg = nullptr) { |
| CAFFE_ENFORCE_EQ( |
| recorder_type, |
| type_, |
| "You are trying to record with a wrong device type."); |
| CAFFE_ENFORCE(event_recorder_[recorder_type]); |
| event_recorder_[recorder_type](this, context, err_msg); |
| } |
| |
| void Wait(int waiter_type, void* context) const { |
| CAFFE_ENFORCE(event_waiter_[waiter_type][type_]); |
| event_waiter_[waiter_type][type_](this, context); |
| } |
| |
| void Finish() const { |
| CAFFE_ENFORCE(event_finisher_[type_]); |
| event_finisher_[type_](this); |
| } |
| |
| EventStatus Query() const { |
| CAFFE_ENFORCE(event_querier_[type_]); |
| return event_querier_[type_](this); |
| } |
| |
| const std::string& ErrorMessage() const { |
| CAFFE_ENFORCE(event_err_msg_getter_[type_]); |
| return event_err_msg_getter_[type_](this); |
| } |
| |
| void Reset() { |
| CAFFE_ENFORCE(event_resetter_[type_]); |
| event_resetter_[type_](this); |
| } |
| |
| const DeviceOption& GetDeviceOption() const { |
| return option_; |
| } |
| |
| bool IsScheduled() const { |
| return Query() == EventStatus::EVENT_SCHEDULED; |
| } |
| |
| bool IsFinished() const { |
| auto status = Query(); |
| return status == EventStatus::EVENT_SUCCESS || |
| status == EventStatus::EVENT_FAILED; |
| } |
| |
| void SetFinished(const char* err_msg = nullptr) { |
| CAFFE_ENFORCE(event_finished_setter_[type_]); |
| return event_finished_setter_[type_](this, err_msg); |
| } |
| |
| bool SupportsCallback() const { |
| return event_callback_setter_[type_] != nullptr; |
| } |
| |
| void SetCallback(EventCallbackFunction callback) { |
| CAFFE_ENFORCE( |
| event_callback_setter_[type_], "Event does not support callbacks"); |
| event_callback_setter_[type_](this, callback); |
| } |
| |
| // If parent op has succeeded, then we can run any child op; |
| // If parent op is in scheduled state, we need to check that: |
| // - child op supports async scheduling |
| // - there's a way to setup synchronization between async parent and |
| // child - both child and parent should use the same type of device, |
| // non-blocking synchronization between different device types is not |
| // supported |
| // If parent op is in another state (initialized or failed) then scheduling |
| // is not possible |
| bool CanSchedule(const Event& child_event, bool supports_async) const { |
| return CanSchedule(type_, Query(), child_event.GetType(), supports_async); |
| } |
| |
| static bool CanSchedule( |
| int parent_type, |
| EventStatus parent_status, |
| int child_type, |
| bool child_supports_async) { |
| if (parent_status == EventStatus::EVENT_SUCCESS) { |
| return true; |
| } |
| if (parent_status == EventStatus::EVENT_SCHEDULED) { |
| return (parent_type == child_type) && child_supports_async; |
| } |
| return false; |
| } |
| |
| int GetType() const { |
| return type_; |
| } |
| |
| // event_ is going to be accessed by the EventCreate/Record/Wait/Finish |
| // functions, but one should not use it outside the own Event functionalities. |
| // In the future we may move it to a private member. |
| std::shared_ptr<void> event_; |
| |
| private: |
| int type_; |
| DeviceOption option_; |
| |
| CAFFE2_API static EventCreateFunction event_creator_[MaxDeviceTypes]; |
| CAFFE2_API static EventRecordFunction event_recorder_[MaxDeviceTypes]; |
| CAFFE2_API static EventWaitFunction event_waiter_[MaxDeviceTypes] |
| [MaxDeviceTypes]; |
| CAFFE2_API static EventFinishFunction event_finisher_[MaxDeviceTypes]; |
| |
| CAFFE2_API static EventQueryFunction event_querier_[MaxDeviceTypes]; |
| CAFFE2_API static EventErrorMessageFunction |
| event_err_msg_getter_[MaxDeviceTypes]; |
| CAFFE2_API static EventSetFinishedFunction |
| event_finished_setter_[MaxDeviceTypes]; |
| CAFFE2_API static EventResetFunction event_resetter_[MaxDeviceTypes]; |
| |
| CAFFE2_API static EventSetCallbackFunction |
| event_callback_setter_[MaxDeviceTypes]; |
| |
| template <int d> |
| friend struct EventCreateFunctionRegisterer; |
| template <int d> |
| friend struct EventRecordFunctionRegisterer; |
| template <int w, int d> |
| friend struct EventWaitFunctionRegisterer; |
| template <int d> |
| friend struct EventFinishFunctionRegisterer; |
| |
| template <int d> |
| friend struct EventQueryFunctionRegisterer; |
| template <int d> |
| friend struct EventErrorMessageFunctionRegisterer; |
| template <int d> |
| friend struct EventSetFinishedFunctionRegisterer; |
| template <int d> |
| friend struct EventSetCallbackFunctionRegisterer; |
| template <int d> |
| friend struct EventResetFunctionRegisterer; |
| }; |
| |
| template <int d> |
| struct EventCreateFunctionRegisterer { |
| explicit EventCreateFunctionRegisterer(EventCreateFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_creator_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_CREATE_FUNCTION(d, f) \ |
| namespace { \ |
| static EventCreateFunctionRegisterer<d> g_event_create_##d(f); \ |
| } |
| |
| template <int d> |
| struct EventRecordFunctionRegisterer { |
| explicit EventRecordFunctionRegisterer(EventRecordFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_recorder_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_RECORD_FUNCTION(d, f) \ |
| namespace { \ |
| static EventRecordFunctionRegisterer<d> g_event_record_##d(f); \ |
| } |
| |
| template <int waiter_type, int event_type> |
| struct EventWaitFunctionRegisterer { |
| explicit EventWaitFunctionRegisterer(EventWaitFunction f) { |
| static_assert(waiter_type < MaxDeviceTypes, ""); |
| static_assert(event_type < MaxDeviceTypes, ""); |
| Event::event_waiter_[waiter_type][event_type] = f; |
| } |
| }; |
| #define REGISTER_EVENT_WAIT_FUNCTION(w, d, f) \ |
| namespace { \ |
| static EventWaitFunctionRegisterer<w, d> g_event_wait_##w##_##d(f); \ |
| } |
| |
| template <int d> |
| struct EventQueryFunctionRegisterer { |
| explicit EventQueryFunctionRegisterer(EventQueryFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_querier_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_QUERY_FUNCTION(d, f) \ |
| namespace { \ |
| static EventQueryFunctionRegisterer<d> g_event_query_##d(f); \ |
| } |
| |
| template <int d> |
| struct EventErrorMessageFunctionRegisterer { |
| explicit EventErrorMessageFunctionRegisterer(EventErrorMessageFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_err_msg_getter_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(d, f) \ |
| namespace { \ |
| static EventErrorMessageFunctionRegisterer<d> g_event_err_msg_##d(f); \ |
| } |
| |
| template <int d> |
| struct EventSetFinishedFunctionRegisterer { |
| explicit EventSetFinishedFunctionRegisterer(EventSetFinishedFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_finished_setter_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_SET_FINISHED_FUNCTION(d, f) \ |
| namespace { \ |
| static EventSetFinishedFunctionRegisterer<d> g_event_set_finished_##d(f); \ |
| } |
| |
| template <int d> |
| struct EventSetCallbackFunctionRegisterer { |
| explicit EventSetCallbackFunctionRegisterer(EventSetCallbackFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_callback_setter_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_SET_CALLBACK_FUNCTION(d, f) \ |
| namespace { \ |
| static EventSetCallbackFunctionRegisterer<d> g_event_set_callback_##d(f); \ |
| } |
| |
| template <int d> |
| struct EventFinishFunctionRegisterer { |
| explicit EventFinishFunctionRegisterer(EventFinishFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_finisher_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_FINISH_FUNCTION(d, f) \ |
| namespace { \ |
| static EventFinishFunctionRegisterer<d> g_event_finish_##d(f); \ |
| } |
| |
| template <int d> |
| struct EventResetFunctionRegisterer { |
| explicit EventResetFunctionRegisterer(EventResetFunction f) { |
| static_assert(d < MaxDeviceTypes, ""); |
| Event::event_resetter_[d] = f; |
| } |
| }; |
| #define REGISTER_EVENT_RESET_FUNCTION(d, f) \ |
| namespace { \ |
| static EventResetFunctionRegisterer<d> g_event_reset_##d(f); \ |
| } |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_CORE_EVENT_H_ |