|
1 | 1 | #include "pthread_backend.h"
|
2 | 2 |
|
3 |
| -#include <iostream> |
4 |
| - |
5 | 3 | void PthreadAsyncIO::write(int fd, void *buffer, size_t n_bytes, unsigned long long offset, callback_t callback) {
|
6 | 4 | auto fut = this->pool.submit_task(
|
7 | 5 | [fd, buffer, n_bytes, offset] {
|
@@ -81,21 +79,23 @@ void PthreadAsyncIO::synchronize() {
|
81 | 79 | void PthreadAsyncIO::register_file(int fd) {} // Empty body: this backend does nothing with `fd` when a file is registered.
|
82 | 80 |
|
83 | 81 | void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
|
| 82 | + auto stream = c10::cuda::getCurrentCUDAStream(); |
| 83 | + at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html |
| 84 | + auto event_ptr = std::make_shared<c10::Event>(torch::kCUDA); // make a shared ptr here since event is not copyable |
| 85 | + if (t.is_cuda()) { |
| 86 | + if (pinned.has_value()) { |
| 87 | + pinned.value().copy_(t, /*non_blocking*/ true); |
| 88 | + t = pinned.value(); |
| 89 | + } else { |
| 90 | + t = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ true, /*copy*/ false); // modified from torch::Tensor::cpu() |
| 91 | + } |
| 92 | + } |
| 93 | + event_ptr->record(stream); |
84 | 94 | auto fut = this->pool.submit_task(
|
85 |
| - [fd, t, offset, pinned] { |
86 |
| - torch::Tensor cpu_tensor; |
87 |
| - if (t.is_cuda()) { |
88 |
| - if (pinned.has_value()) { |
89 |
| - pinned.value().copy_(t); |
90 |
| - cpu_tensor = pinned.value(); |
91 |
| - } else { |
92 |
| - cpu_tensor = t.to(torch::kCPU); |
93 |
| - } |
94 |
| - } else { |
95 |
| - cpu_tensor = t; |
96 |
| - } |
97 |
| - void *buf = cpu_tensor.data_ptr(); |
98 |
| - size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size(); |
| 95 | + [fd, t, offset, pinned, event_ptr] { |
| 96 | + event_ptr->synchronize(); // sync with comm stream |
| 97 | + void *buf = t.data_ptr(); |
| 98 | + size_t n_bytes = t.numel() * t.element_size(); |
99 | 99 | return pwrite(fd, buf, n_bytes, offset);
|
100 | 100 | }
|
101 | 101 | );
|
|
0 commit comments