forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
event.cc
148 lines (122 loc) · 4.86 KB
/
event.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include "caffe2/core/event_cpu.h"
namespace caffe2 {
// Out-of-class definitions of Event's static handler tables. Each table is
// sized by MaxDeviceTypes; event_waiter_ is two-dimensional, i.e. it is looked
// up by a pair of device types.
// NOTE(review): entries are presumably filled in by the REGISTER_EVENT_*
// macros (see the CPU registrations at the end of this file) — confirm
// against the macro definitions.
CAFFE2_API EventCreateFunction Event::event_creator_[MaxDeviceTypes];
CAFFE2_API EventRecordFunction Event::event_recorder_[MaxDeviceTypes];
CAFFE2_API EventWaitFunction
Event::event_waiter_[MaxDeviceTypes][MaxDeviceTypes];
CAFFE2_API EventFinishFunction Event::event_finisher_[MaxDeviceTypes];
CAFFE2_API EventQueryFunction Event::event_querier_[MaxDeviceTypes];
CAFFE2_API EventErrorMessageFunction
Event::event_err_msg_getter_[MaxDeviceTypes];
CAFFE2_API EventSetFinishedFunction
Event::event_finished_setter_[MaxDeviceTypes];
CAFFE2_API EventResetFunction Event::event_resetter_[MaxDeviceTypes];
CAFFE2_API EventSetCallbackFunction
Event::event_callback_setter_[MaxDeviceTypes];
namespace {
// Message returned by EventErrorMessageCPU for any event that is not in the
// EVENT_FAILED state. Returned by reference, so it must outlive all callers;
// being a namespace-scope constant it lives for the whole program.
const std::string kNoError = "No error";
}
// Creates the CPU-side state of `event`: builds a new CPUEventWrapper from
// `option` and installs it as `event->event_`, replacing whatever was stored
// there before.
void EventCreateCPU(const DeviceOption& option, Event* event) {
  auto cpu_state = std::make_shared<CPUEventWrapper>(option);
  event->event_ = std::move(cpu_state);
}
// Records a CPU event.
//
// State machine of a CPU event:
//   INITIALIZED -> SCHEDULED or SUCCESS/FAILED
//   SCHEDULED   -> SUCCESS/FAILED
//   SUCCESS/FAILED are terminal: status_/err_msg_ are not touched again here.
//
// @param event    Event whose event_ holds a CPUEventWrapper.
// @param err_msg  nullptr to mark the event as scheduled; otherwise the event
//                 is marked failed with this message and waiters are woken.
// The second (context) parameter is unused.
// Enforces (CAFFE_ENFORCE) that the event is not already SCHEDULED, i.e. that
// Record is not called twice.
void EventRecordCPU(
    Event* event,
    const void* /* unused */,
    const char* err_msg) {
  auto* wrapper = static_cast<CPUEventWrapper*>(event->event_.get());
  std::unique_lock<std::mutex> guard(wrapper->mutex_);

  CAFFE_ENFORCE(
      wrapper->status_ != EventStatus::EVENT_SCHEDULED,
      "Calling Record multiple times");

  // The event may already be SUCCESS/FAILED when an op completed the async
  // part of its execution before Record was reached; nothing to do then.
  if (wrapper->status_ != EventStatus::EVENT_INITIALIZED) {
    return;
  }

  if (err_msg == nullptr) {
    wrapper->status_ = EventStatus::EVENT_SCHEDULED;
    return;
  }

  // Failure reported at record time: publish the message, then wake waiters.
  wrapper->err_msg_ = err_msg;
  wrapper->status_ = EventStatus::EVENT_FAILED;
  wrapper->cv_completed_.notify_all();
}
void EventFinishCPU(const Event* event) {
auto* wrapper = static_cast<CPUEventWrapper*>(event->event_.get());
std::unique_lock<std::mutex> lock(wrapper->mutex_);
while (wrapper->status_ != EventStatus::EVENT_SUCCESS &&
wrapper->status_ != EventStatus::EVENT_FAILED) {
wrapper->cv_completed_.wait(lock);
}
}
// Wait handler registered for the (CPU, CPU) device pair. It ignores the
// context argument and simply blocks in EventFinishCPU until `event` is in a
// terminal state.
void EventWaitCPUCPU(const Event* event, void* /* context */) {
EventFinishCPU(event);
}
// Returns the current status of `event`. The wrapper's mutex_ is not taken;
// the status is read with a single status_.load().
EventStatus EventQueryCPU(const Event* event) {
  auto* state = static_cast<CPUEventWrapper*>(event->event_.get());
  const auto raw_status = state->status_.load();
  return static_cast<EventStatus>(raw_status);
}
// Returns the error message of `event`.
//
// @return a reference to the wrapper's err_msg_ when the event is
//         EVENT_FAILED, otherwise a reference to the shared kNoError string.
// The wrapper's mutex_ is not taken.
const std::string& EventErrorMessageCPU(const Event* event) {
  auto* state = static_cast<CPUEventWrapper*>(event->event_.get());
  if (state->status_ != EventStatus::EVENT_FAILED) {
    return kNoError;
  }
  // FAILED is a terminal state, so err_msg_ is not expected to change any
  // more; that is why it is read here without synchronizing.
  return state->err_msg_;
}
// Moves `event` into a terminal state, runs its callbacks and wakes waiters.
//
// @param event    Event whose event_ holds a CPUEventWrapper.
// @param err_msg  nullptr to finish with EVENT_SUCCESS; otherwise the event
//                 becomes EVENT_FAILED and this text is stored as err_msg_.
//
// If the event is already EVENT_FAILED the call only logs a warning and
// returns, leaving the stored message untouched. Any other already-finished
// state (EVENT_SUCCESS) trips the CAFFE_ENFORCE below.
//
// NOTE: the registered callbacks are invoked while mutex_ (a std::mutex) is
// still held, so a callback must not call back into a function that locks the
// same event (e.g. EventSetCallbackCPU / EventFinishCPU on this event).
void EventSetFinishedCPU(const Event* event, const char* err_msg) {
  auto* wrapper = static_cast<CPUEventWrapper*>(event->event_.get());
  std::unique_lock<std::mutex> lock(wrapper->mutex_);

  if (wrapper->status_ == EventStatus::EVENT_FAILED) {
    // err_msg is null when the late caller is reporting success; streaming a
    // null `const char*` is undefined behaviour, so substitute a placeholder.
    LOG(WARNING) << "SetFinished called on a finished event. "
                 << "Most likely caused by an external cancellation. "
                 << "old message: " << wrapper->err_msg_ << ", "
                 << "new message: " << (err_msg ? err_msg : "(none)");
    return;
  }

  CAFFE_ENFORCE(
      wrapper->status_ == EventStatus::EVENT_INITIALIZED ||
          wrapper->status_ == EventStatus::EVENT_SCHEDULED,
      "Calling SetFinished on finished event");

  if (!err_msg) {
    wrapper->status_ = EventStatus::EVENT_SUCCESS;
  } else {
    // Store the message before publishing FAILED so that lock-free readers
    // that observe FAILED find the message already in place.
    wrapper->err_msg_ = err_msg;
    wrapper->status_ = EventStatus::EVENT_FAILED;
  }

  // Callbacks run first, in registration order; waiters are released after.
  for (auto& callback : wrapper->callbacks_) {
    callback();
  }

  wrapper->cv_completed_.notify_all();
}
// Registers `callback` on `event`. The callback is always appended to the
// wrapper's callbacks_ list; if the event has already reached a terminal
// state (EVENT_SUCCESS or EVENT_FAILED) it is additionally invoked right
// away, on the calling thread and with the wrapper's mutex_ held.
void EventSetCallbackCPU(Event* event, EventCallbackFunction callback) {
  auto* state = static_cast<CPUEventWrapper*>(event->event_.get());
  std::unique_lock<std::mutex> guard(state->mutex_);

  state->callbacks_.push_back(callback);

  const bool already_finished =
      state->status_ == EventStatus::EVENT_SUCCESS ||
      state->status_ == EventStatus::EVENT_FAILED;
  if (already_finished) {
    callback();
  }
}
// Returns `event` to its pristine state under the wrapper's mutex_: status
// back to EVENT_INITIALIZED, error message emptied, all registered callbacks
// dropped. Waiters on cv_completed_ are not notified.
void EventResetCPU(Event* event) {
  auto* state = static_cast<CPUEventWrapper*>(event->event_.get());
  std::unique_lock<std::mutex> guard(state->mutex_);

  // Status is reset first, then the data it guards.
  state->status_ = EventStatus::EVENT_INITIALIZED;
  state->err_msg_.clear();
  state->callbacks_.clear();
}
// Register the CPU implementations defined above as the handlers for the CPU
// device type. The wait handler is registered for the (CPU, CPU) pair.
// NOTE(review): the REGISTER_EVENT_* macros presumably store these functions
// in the corresponding Event::event_*_ static tables — confirm against the
// macro definitions.
REGISTER_EVENT_CREATE_FUNCTION(CPU, EventCreateCPU);
REGISTER_EVENT_RECORD_FUNCTION(CPU, EventRecordCPU);
REGISTER_EVENT_WAIT_FUNCTION(CPU, CPU, EventWaitCPUCPU);
REGISTER_EVENT_FINISH_FUNCTION(CPU, EventFinishCPU);
REGISTER_EVENT_QUERY_FUNCTION(CPU, EventQueryCPU);
REGISTER_EVENT_ERROR_MESSAGE_FUNCTION(CPU, EventErrorMessageCPU);
REGISTER_EVENT_SET_FINISHED_FUNCTION(CPU, EventSetFinishedCPU);
REGISTER_EVENT_RESET_FUNCTION(CPU, EventResetCPU);
REGISTER_EVENT_SET_CALLBACK_FUNCTION(CPU, EventSetCallbackCPU);
} // namespace caffe2