Skip to content

Commit 2005d82

Browse files
committed
fix event wip
1 parent 45df2a9 commit 2005d82

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

plugin/torch/flagcx/include/event_flagcx.hpp

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#elif USE_ASCEND_ADAPTOR
2828
#include <ATen/cuda/CUDAEvent.h>
2929
#include <cuda_runtime.h>
30+
#include <acl/acl.h>
3031
#endif
3132

3233
namespace c10d {
@@ -205,31 +206,45 @@ class flagcxXpuEvent : public flagcxEvent {
205206
at::cuda::CUDAEvent cudaEvent_;
206207
};
207208
#elif USE_ASCEND_ADAPTOR
208-
class flagcxXpuEvent : public flagcxEvent {
209+
class flagcxNpuEvent() {
209210
public:
210-
flagcxAscendEvent() { cudaEvent_ = at::cuda::CUDAEvent(cudaEventDisableTiming); }
211-
212-
void record(const int deviceId) override {
213-
cudaEvent_.record(at::cuda::getCurrentCUDAStream(deviceId));
214-
}
215-
216-
void record(const flagcxStream_t &stream, const int deviceId) override {
217-
cudaEvent_.record(
218-
at::cuda::getStreamFromExternal(*(cudaStream_t *)stream, deviceId));
219-
}
220-
221-
void block(const int deviceId) override {
222-
cudaEvent_.block(at::cuda::getCurrentCUDAStream(deviceId));
223-
}
224-
225-
void block(const flagcxStream_t &stream, const int deviceId) override {
226-
cudaEvent_.block(
227-
at::cuda::getStreamFromExternal(*(cudaStream_t *)stream, deviceId));
228-
}
211+
flagcxNpuEvent() {
212+
aclError ret = aclrtCreateEvent(&aclEvent_);
213+
if (ret != ACL_SUCCESS) {
214+
throw std::runtime_error("Failed to create NPU event");
215+
}
216+
}
217+
218+
~flagcxNpuEvent() {
219+
aclrtDestoryEvent(aclEvent_);
220+
}
221+
222+
void record(const int devicedId) override {
223+
aclrtStream currentStream;
224+
aclrtGetCurrentStream(&currentStream);
225+
aclrtEventRecord(aclEvent_, currentStream);
226+
}
227+
228+
void record(const flagcxStream_t &stream, const int deviceId) override {
229+
aclrtStream targetStream = reinterpret_cast<aclrtStream>(stream);
230+
aclrtEventRecord(aclEvent_, targetStream);
231+
}
232+
233+
void block(const int deviceId) override {
234+
aclrtStream currentStream;
235+
aclrtGetCurrentStream(&currentStream);
236+
aclrtStreamWaitEvent(currentStream, aclEvent_);
237+
}
238+
239+
void block(const flagcxStream_t &stream, const int deviceId) override {
240+
aclrtStream targetStream = reinterpret_cast<aclrtStream>(stream);
241+
aclrtStreamWaitEvent(targetStream, aclEvent_);
242+
}
229243

230244
private:
231-
at::cuda::CUDAEvent cudaEvent_;
245+
aclrtEvent aclEvent_;
232246
};
247+
233248
#endif
234249

235250
} // namespace c10d

plugin/torch/flagcx/src/backend_flagcx.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ std::unique_ptr<flagcxEvent> &flagcxBackend::getEventByIndex(int eventId) {
209209
#elif USE_KUNLUNXIN_ADAPTOR
210210
flagcxEvents_[eventId] = std::make_unique<flagcxXpuEvent>();
211211
#elif USE_ASCEND_ADAPTOR
212-
flagcxEvents_[eventId] = std::make_unique<flagcxXpuEvent>();
212+
flagcxEvents_[eventId] = std::make_unique<flagcxNpuEvent>();
213213
#endif
214214
return flagcxEvents_[eventId];
215215
}

0 commit comments

Comments
 (0)