|
27 | 27 | #elif USE_ASCEND_ADAPTOR
|
28 | 28 | #include <ATen/cuda/CUDAEvent.h>
|
29 | 29 | #include <cuda_runtime.h>
|
| 30 | +#include <acl/acl.h> |
30 | 31 | #endif
|
31 | 32 |
|
32 | 33 | namespace c10d {
|
@@ -205,31 +206,45 @@ class flagcxXpuEvent : public flagcxEvent {
|
205 | 206 | at::cuda::CUDAEvent cudaEvent_;
|
206 | 207 | };
|
207 | 208 | #elif USE_ASCEND_ADAPTOR
|
208 |
| -class flagcxXpuEvent : public flagcxEvent { |
| 209 | +class flagcxNpuEvent() { |
209 | 210 | public:
|
210 |
| - flagcxAscendEvent() { cudaEvent_ = at::cuda::CUDAEvent(cudaEventDisableTiming); } |
211 |
| - |
212 |
| - void record(const int deviceId) override { |
213 |
| - cudaEvent_.record(at::cuda::getCurrentCUDAStream(deviceId)); |
214 |
| - } |
215 |
| - |
216 |
| - void record(const flagcxStream_t &stream, const int deviceId) override { |
217 |
| - cudaEvent_.record( |
218 |
| - at::cuda::getStreamFromExternal(*(cudaStream_t *)stream, deviceId)); |
219 |
| - } |
220 |
| - |
221 |
| - void block(const int deviceId) override { |
222 |
| - cudaEvent_.block(at::cuda::getCurrentCUDAStream(deviceId)); |
223 |
| - } |
224 |
| - |
225 |
| - void block(const flagcxStream_t &stream, const int deviceId) override { |
226 |
| - cudaEvent_.block( |
227 |
| - at::cuda::getStreamFromExternal(*(cudaStream_t *)stream, deviceId)); |
228 |
| - } |
| 211 | + flagcxNpuEvent() { |
| 212 | + aclError ret = aclrtCreateEvent(&aclEvent_); |
| 213 | + if (ret != ACL_SUCCESS) { |
| 214 | + throw std::runtime_error("Failed to create NPU event"); |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + ~flagcxNpuEvent() { |
| 219 | + aclrtDestoryEvent(aclEvent_); |
| 220 | + } |
| 221 | + |
| 222 | + void record(const int devicedId) override { |
| 223 | + aclrtStream currentStream; |
| 224 | + aclrtGetCurrentStream(¤tStream); |
| 225 | + aclrtEventRecord(aclEvent_, currentStream); |
| 226 | + } |
| 227 | + |
| 228 | + void record(const flagcxStream_t &stream, const int deviceId) override { |
| 229 | + aclrtStream targetStream = reinterpret_cast<aclrtStream>(stream); |
| 230 | + aclrtEventRecord(aclEvent_, targetStream); |
| 231 | + } |
| 232 | + |
| 233 | + void block(const int deviceId) override { |
| 234 | + aclrtStream currentStream; |
| 235 | + aclrtGetCurrentStream(¤tStream); |
| 236 | + aclrtStreamWaitEvent(currentStream, aclEvent_); |
| 237 | + } |
| 238 | + |
| 239 | + void block(const flagcxStream_t &stream, const int deviceId) override { |
| 240 | + aclrtStream targetStream = reinterpret_cast<aclrtStream>(stream); |
| 241 | + aclrtStreamWaitEvent(targetStream, aclEvent_); |
| 242 | + } |
229 | 243 |
|
230 | 244 | private:
|
231 |
| - at::cuda::CUDAEvent cudaEvent_; |
| 245 | + aclrtEvent aclEvent_; |
232 | 246 | };
|
| 247 | + |
233 | 248 | #endif
|
234 | 249 |
|
235 | 250 | } // namespace c10d
|
0 commit comments