
Commit 5da46cd

[h2d] add individual sync for h2d

1 parent 39ea874 commit 5da46cd

11 files changed (+69, -19 lines)


csrc/aio.cpp

Lines changed: 4 additions & 1 deletion
@@ -138,4 +138,7 @@ void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset
     void *buffer = t.data_ptr();
     size_t n_bytes = t.numel() * t.element_size();
     this->write(fd, buffer, n_bytes, offset, callback);
-}
+}
+
+void AIOAsyncIO::register_h2d(unsigned int num_tensors) {}
+void AIOAsyncIO::sync_h2d() {}

csrc/async_file_io.cpp

Lines changed: 7 additions & 0 deletions
@@ -12,6 +12,13 @@ void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offs
     this->aio->write_tensor(this->fd, tensor, offset, callback, pinned);
 }
 
+void AsyncFileWriter::register_h2d(unsigned int num_tensors) {
+    this->aio->register_h2d(num_tensors);
+}
+
+void AsyncFileWriter::sync_h2d() {
+    this->aio->sync_h2d();
+}
 
 void AsyncFileWriter::synchronize()
 {

csrc/pthread_backend.cpp

Lines changed: 27 additions & 15 deletions
@@ -78,24 +78,36 @@ void PthreadAsyncIO::synchronize() {
 
 void PthreadAsyncIO::register_file(int fd) {}
 
+void PthreadAsyncIO::register_h2d(unsigned int num_tensors) {
+    this->h2d_in_progress.store(num_tensors); // register tensors to write for this run
+}
+
+void PthreadAsyncIO::sync_h2d() {
+    std::unique_lock<std::mutex> lock(this->mtx);
+    this->cv.wait(lock, [this] { return this->h2d_in_progress == 0; }); // block until all in-progress h2d are completed
+}
+
 void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
     auto stream = c10::cuda::getCurrentCUDAStream();
-    at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
-    auto event_ptr = std::make_shared<c10::Event>(torch::kCUDA); // make a shared ptr here since event is not copyable
-    if (t.is_cuda()) {
-        if (pinned.has_value()) {
-            pinned.value().copy_(t, /*non_blocking*/ true);
-            t = pinned.value();
-        } else {
-            t = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ true, /*copy*/ false); // modified from torch::Tensor::cpu()
-        }
-    }
-    event_ptr->record(stream);
     auto fut = this->pool.submit_task(
-        [fd, t, offset, pinned, event_ptr] {
-            event_ptr->synchronize(); // sync with comm stream
-            void *buf = t.data_ptr();
-            size_t n_bytes = t.numel() * t.element_size();
+        [this, fd, t, offset, pinned, stream] {
+            at::cuda::CUDAStreamGuard guard(stream); // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
+            torch::Tensor cpu_tensor;
+            if (t.is_cuda()) {
+                if (pinned.has_value()) {
+                    pinned.value().copy_(t, /*non_blocking*/ false);
+                    cpu_tensor = pinned.value();
+                } else {
+                    cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false); // modified from torch::Tensor::cpu()
+                }
+            }
+            this->h2d_in_progress.fetch_sub(1);
+            if (this->h2d_in_progress.load() == 0) { // notify when all h2d are completed and safe to optimizer.step()
+                std::lock_guard<std::mutex> lock(this->mtx);
+                cv.notify_one();
+            }
+            void *buf = cpu_tensor.data_ptr();
+            size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
             return pwrite(fd, buf, n_bytes, offset);
         }
     );
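The new synchronization in the pthread backend is a counter-plus-condition-variable handshake: register_h2d stores how many tensors the run will write into h2d_in_progress, each pooled write task decrements the counter once its tensor has been copied into the pinned buffer or a fresh CPU tensor, and sync_h2d blocks on cv until the counter reaches zero. A rough restatement of that handshake in Python, for illustration only (H2DTracker and its method names are invented here and are not part of the library):

    # Illustration only: the same register/decrement/notify/wait handshake that
    # PthreadAsyncIO implements with std::atomic + std::condition_variable.
    # H2DTracker is a hypothetical name, not part of TensorNVMe.
    import threading

    class H2DTracker:
        def __init__(self) -> None:
            self._pending = 0                    # plays the role of h2d_in_progress
            self._cv = threading.Condition()     # plays the role of mtx + cv

        def register_h2d(self, num_tensors: int) -> None:
            with self._cv:
                self._pending = num_tensors      # tensors to write for this run

        def mark_copy_done(self) -> None:
            # called by each worker after its tensor has been copied off the GPU
            with self._cv:
                self._pending -= 1
                if self._pending == 0:
                    self._cv.notify()            # wake sync_h2d(): safe to step the optimizer

        def sync_h2d(self) -> None:
            with self._cv:
                self._cv.wait_for(lambda: self._pending == 0)

Unlike this sketch, the C++ version decrements the atomic counter outside the mutex and only takes the lock to notify; the simplified version above keeps everything under one lock.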

csrc/py_api.cpp

Lines changed: 3 additions & 1 deletion
@@ -30,5 +30,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
         .def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
         .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
         .def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none())
-        .def("synchronize", &AsyncFileWriter::synchronize);
+        .def("synchronize", &AsyncFileWriter::synchronize)
+        .def("sync_h2d", &AsyncFileWriter::sync_h2d)
+        .def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors"));
 }

csrc/uring.cpp

Lines changed: 4 additions & 1 deletion
@@ -111,4 +111,7 @@ void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offs
     void *buffer = t.data_ptr<float>();
     size_t n_bytes = t.numel() * t.element_size();
     this->write(fd, buffer, n_bytes, offset, callback);
-}
+}
+
+void UringAsyncIO::register_h2d(unsigned int num_tensors) {}
+void UringAsyncIO::sync_h2d() {}

include/aio.h

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,8 @@ class AIOAsyncIO : public AsyncIO
     void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
     void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
 
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     void sync_write_events();
     void sync_read_events();
     void synchronize();

include/async_file_io.h

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ class AsyncFileWriter
     void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
     void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
     void synchronize();
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     ~AsyncFileWriter();
 
 private:

include/asyncio.h

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ class AsyncIO
     virtual void get_event(WaitType wt) = 0;
     virtual void sync_write_events() = 0;
     virtual void sync_read_events() = 0;
+    virtual void register_h2d(unsigned int num_tensors) = 0;
+    virtual void sync_h2d() = 0;
     virtual void synchronize() = 0;
 
     virtual void register_file(int fd) = 0;

include/pthread_backend.h

Lines changed: 9 additions & 1 deletion
@@ -12,6 +12,9 @@
 #include <iostream>
 #include <c10/cuda/CUDAStream.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
 
 #include "asyncio.h"
 #include "threadpool.hpp"
@@ -21,12 +24,15 @@ class PthreadAsyncIO : public AsyncIO
 {
 private:
     BS::thread_pool pool;
+    std::atomic<unsigned int> h2d_in_progress;
+    std::condition_variable cv;
+    std::mutex mtx;
     std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
     std::deque<std::tuple<std::future<ssize_t>, callback_t>> read_fut;
 
 public:
     PthreadAsyncIO(unsigned int n_entries)
-        : pool(n_entries) {}
+        : pool(n_entries), h2d_in_progress(0) {}
 
     ~PthreadAsyncIO() {}
 
@@ -38,6 +44,8 @@ class PthreadAsyncIO : public AsyncIO
     void get_event(WaitType wt);
     void sync_write_events();
     void sync_read_events();
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     void synchronize();
 
     void register_file(int fd);

include/uring.h

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ class UringAsyncIO : public AsyncIO
     void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
     void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
 
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     void sync_write_events();
     void sync_read_events();
     void synchronize();

tensornvme/async_file_io.py

Lines changed: 7 additions & 0 deletions
@@ -38,12 +38,19 @@ def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
         self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
         self.offset += tensor.numel() * tensor.element_size()
 
+    def sync_h2d(self) -> None:
+        self.io.sync_h2d()
+
+    def register_h2d(self, num_tensors: int) -> None:
+        self.io.register_h2d(num_tensors)
+
     def write_gpu_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
         assert tensor.device.type == 'cuda', f"tensor must be on cuda device, got {tensor.device}"
         with torch.cuda.stream(self.comm_stream):
             self.write_tensor(tensor, pinned)
 
     def sync_before_step(self):
+        self.sync_h2d()
         self.comm_stream.synchronize()
 
     @staticmethod
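Putting the pieces together at the Python level, the intended flow appears to be: register the number of GPU tensors about to be written, issue the writes on the communication stream via write_gpu_tensor, and call sync_before_step (which now waits on sync_h2d) before mutating the tensors again, for example in optimizer.step(). A minimal sketch under those assumptions; the file name, tensor shapes, constructor arguments, and backend string are hypothetical, not from this commit:

    # Hypothetical end-to-end usage of the new calls (file name, tensor shapes,
    # constructor arguments, and the optimizer step are illustrative); the h2d
    # accounting only has an effect with the pthread backend.
    import torch
    from tensornvme.async_file_io import AsyncFileWriter

    params = [torch.randn(1024, 1024, device="cuda") for _ in range(4)]

    with open("checkpoint.bin", "wb") as f:
        writer = AsyncFileWriter(f, n_entries=8, backend="pthread")  # signature assumed from the binding's (fd, n_entries, backend)

        writer.register_h2d(len(params))   # announce how many tensors this run will write
        for p in params:
            writer.write_gpu_tensor(p)     # copies the tensor off the GPU and writes it asynchronously

        writer.sync_before_step()          # now waits for the h2d copies, then syncs the comm stream
        # the GPU tensors may be mutated again here, e.g. by optimizer.step()

        writer.synchronize()               # wait for the file writes themselves (assuming the wrapper exposes it)

As the diff shows, sync_before_step waits only for the h2d copies and the comm stream, not for the file writes, so the writes can keep draining in the background while the optimizer proceeds.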
