Commit a1bf816

[pthread] init async gpu -> cpu (#49)
* [pthread] init async gpu -> cpu
* [chore] add callback
* [chore] add pinned mem buffer
* [tmp] non-blocking somehow not working
* [h2d] add individual sync for h2d
* [chore] enable notify when submitting tensor write task
* [chore] remove api
* [chore] remove api
1 parent b2f9944 commit a1bf816

File tree: 12 files changed, +151 -12 lines


csrc/aio.cpp

Lines changed: 18 additions & 3 deletions
@@ -1,5 +1,3 @@
-#include <stdexcept>
-#include <memory>
 #include "aio.h"
 
 AIOAsyncIO::AIOAsyncIO(unsigned int n_entries)
@@ -126,4 +124,21 @@ void AIOAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned l
     io_submit(this->io_ctx, 1, &iocbs); /* submitting this I/O does not block */
 
     this->n_read_events++;
 }
+
+void AIOAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
+    if (t.is_cuda()) {
+        if (pinned.has_value()) {
+            pinned.value().copy_(t);
+            t = pinned.value();
+        } else {
+            t = t.to(torch::kCPU);
+        }
+    }
+    void *buffer = t.data_ptr();
+    size_t n_bytes = t.numel() * t.element_size();
+    this->write(fd, buffer, n_bytes, offset, callback);
+}
+
+void AIOAsyncIO::register_h2d(unsigned int num_tensors) {}
+void AIOAsyncIO::sync_h2d() {}
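
Both this aio backend and the uring backend further down stage a CUDA tensor on the host before submitting the write: the tensor is copied into a caller-supplied pinned buffer when one is given, otherwise it falls back to a blocking t.to(torch::kCPU). A minimal sketch of that staging step using only public PyTorch calls; the helper name and the buffer allocation are illustrative, not part of the extension:

from typing import Optional

import torch

def stage_on_host(t: torch.Tensor, pinned: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Mirror the GPU -> CPU staging done before the aio/uring write is issued."""
    if not t.is_cuda:
        return t                    # CPU tensors are written as-is
    if pinned is None:
        return t.to("cpu")          # fallback: blocking copy into pageable memory
    pinned.copy_(t)                 # preferred: reuse a pre-allocated page-locked buffer
    return pinned

if torch.cuda.is_available():
    t = torch.randn(1024, device="cuda")
    pinned = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    host_view = stage_on_host(t, pinned)  # ready to be handed to write(fd, ptr, n_bytes, offset)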

csrc/async_file_io.cpp

Lines changed: 12 additions & 3 deletions
@@ -1,7 +1,4 @@
-#include "asyncio.h"
 #include "async_file_io.h"
-#include "backend.h"
-#include <stdexcept>
 
 AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}
 
@@ -11,6 +8,18 @@ void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long of
     this->aio->write(this->fd, ptr, n_bytes, offset, callback);
 }
 
+void AsyncFileWriter::write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
+    this->aio->write_tensor(this->fd, tensor, offset, callback, pinned);
+}
+
+void AsyncFileWriter::register_h2d(unsigned int num_tensors) {
+    this->aio->register_h2d(num_tensors);
+}
+
+void AsyncFileWriter::sync_h2d() {
+    this->aio->sync_h2d();
+}
+
 void AsyncFileWriter::synchronize()
 {
     this->aio->synchronize();

csrc/pthread_backend.cpp

Lines changed: 46 additions & 1 deletion
@@ -76,4 +76,49 @@ void PthreadAsyncIO::synchronize() {
     this->get_event(WAIT);
 }
 
 void PthreadAsyncIO::register_file(int fd) {}
+
+void PthreadAsyncIO::register_h2d(unsigned int num_tensors) {
+    this->h2d_in_progress.store(num_tensors);  // register tensors to write for this run
+}
+
+void PthreadAsyncIO::sync_h2d() {
+    std::unique_lock<std::mutex> lock(this->mtx);
+    this->cv.wait(lock, [this] { return this->h2d_in_progress == 0; });  // block until all in-progress h2d are completed
+}
+
+void PthreadAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
+    auto stream = c10::cuda::getCurrentCUDAStream();
+    if (!t.is_cuda()) {
+        this->h2d_in_progress.fetch_sub(1);  // already moved to cpu
+        if (this->h2d_in_progress.load() == 0) {  // notify when all h2d are completed and safe to optimizer.step()
+            std::lock_guard<std::mutex> lock(this->mtx);
+            cv.notify_one();
+        }
+    }
+    auto fut = this->pool.submit_task(
+        [this, fd, t, offset, pinned, stream] {
+            torch::Tensor cpu_tensor;
+            if (t.is_cuda()) {
+                at::cuda::CUDAStreamGuard guard(stream);  // https://pytorch.org/cppdocs/notes/tensor_cuda_stream.html
+                if (pinned.has_value()) {
+                    pinned.value().copy_(t, /*non_blocking*/ false);
+                    cpu_tensor = pinned.value();
+                } else {
+                    cpu_tensor = t.to(t.options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false);  // modified from torch::Tensor::cpu()
+                }
+                this->h2d_in_progress.fetch_sub(1);
+                if (this->h2d_in_progress.load() == 0) {  // notify when all h2d are completed and safe to optimizer.step()
+                    std::lock_guard<std::mutex> lock(this->mtx);
+                    cv.notify_one();
+                }
+            } else {
+                cpu_tensor = t;
+            }
+            void *buf = cpu_tensor.data_ptr();
+            size_t n_bytes = cpu_tensor.numel() * cpu_tensor.element_size();
+            return pwrite(fd, buf, n_bytes, offset);
+        }
+    );
+    this->write_fut.push_back(std::make_tuple(std::move(fut), callback));
+}
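
The pthread backend gates optimizer.step() on a small handshake: register_h2d(n) arms an atomic counter, each write_tensor decrements it once the device-to-host copy for that tensor is done, and sync_h2d() blocks on a condition variable until the counter reaches zero. A rough Python analogue of that handshake (the class and method names here are illustrative only, not part of the extension):

import threading

class H2DGate:
    """Counter plus condition variable, mirroring h2d_in_progress / mtx / cv above."""

    def __init__(self) -> None:
        self._cv = threading.Condition()
        self._in_progress = 0

    def register(self, num_tensors: int) -> None:
        # ~ register_h2d: arm the counter for this round of tensor writes
        with self._cv:
            self._in_progress = num_tensors

    def mark_copied(self) -> None:
        # ~ fetch_sub + notify_one, called once a tensor's device-to-host copy is done
        with self._cv:
            self._in_progress -= 1
            if self._in_progress == 0:
                self._cv.notify_all()

    def wait_all_copied(self) -> None:
        # ~ sync_h2d: once this returns it is safe to run optimizer.step()
        with self._cv:
            self._cv.wait_for(lambda: self._in_progress == 0)

Note that in the C++ code the decrement happens either immediately (for tensors already on the CPU) or inside the worker task after the copy, so sync_h2d only waits for the copies; synchronize() still covers the pwrite calls to disk.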

csrc/py_api.cpp

Lines changed: 4 additions & 1 deletion
@@ -29,5 +29,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
         .def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
         .def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
-        .def("synchronize", &AsyncFileWriter::synchronize);
+        .def("write_tensor", &AsyncFileWriter::write_tensor, py::arg("tensor"), py::arg("offset"), py::arg("callback") = py::none(), py::arg("pinned") = py::none())
+        .def("synchronize", &AsyncFileWriter::synchronize)
+        .def("sync_h2d", &AsyncFileWriter::sync_h2d)
+        .def("register_h2d", &AsyncFileWriter::register_h2d, py::arg("num_tensors"));
 }
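
These bindings can also be driven directly on a raw file descriptor, with an explicit offset and an optional completion callback, matching the signatures registered above and the .pyi stub further down. A short sketch against the raw binding; the file name and the "pthread" backend choice are assumptions for illustration:

import os
import torch
from tensornvme._C import AsyncFileWriter

if torch.cuda.is_available():
    fd = os.open("raw.bin", os.O_CREAT | os.O_WRONLY, 0o644)
    w = AsyncFileWriter(fd, 8, backend="pthread")
    t = torch.randn(256, device="cuda")
    w.register_h2d(1)
    w.write_tensor(t, 0, None, None)  # tensor, offset, callback, pinned
    w.sync_h2d()                      # GPU -> CPU copy has finished
    w.synchronize()                   # disk write has finished
    os.close(fd)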

csrc/uring.cpp

Lines changed: 18 additions & 1 deletion
@@ -97,4 +97,21 @@ void UringAsyncIO::readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned
     io_uring_sqe_set_data(sqe, data);
     io_uring_submit(&this->ring);
     this->n_read_events++;
 }
+
+void UringAsyncIO::write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) {
+    if (t.is_cuda()) {
+        if (pinned.has_value()) {
+            pinned.value().copy_(t);
+            t = pinned.value();
+        } else {
+            t = t.to(torch::kCPU);
+        }
+    }
+    void *buffer = t.data_ptr<float>();
+    size_t n_bytes = t.numel() * t.element_size();
+    this->write(fd, buffer, n_bytes, offset, callback);
+}
+
+void UringAsyncIO::register_h2d(unsigned int num_tensors) {}
+void UringAsyncIO::sync_h2d() {}

include/aio.h

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,9 @@
 #pragma once
 
 #include <libaio.h>
+#include <torch/torch.h>
+#include <stdexcept>
+#include <memory>
 #include "asyncio.h"
 
 class AIOAsyncIO : public AsyncIO
@@ -24,9 +27,12 @@ class AIOAsyncIO : public AsyncIO
     void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
     void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
 
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     void sync_write_events();
     void sync_read_events();
     void synchronize();
 
     void register_file(int fd);
+    void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
 };

include/async_file_io.h

Lines changed: 9 additions & 0 deletions
@@ -1,9 +1,15 @@
 #pragma once
 #include <string>
+#include <torch/torch.h>
+#include <optional>
+
 #include "asyncio.h"
+#include "backend.h"
+
 #ifndef DISABLE_URING
 #include "uring.h"
 #endif
+
 #ifndef DISABLE_AIO
 #include "aio.h"
 #endif
@@ -13,7 +19,10 @@ class AsyncFileWriter
 public:
     AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
     void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
+    void write_tensor(torch::Tensor tensor, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
     void synchronize();
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     ~AsyncFileWriter();
 
 private:

include/asyncio.h

Lines changed: 4 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 #include <fcntl.h>
 #include <functional>
+#include <torch/torch.h>
 
 using callback_t = std::function<void()>;
 
@@ -44,7 +45,10 @@ class AsyncIO
     virtual void get_event(WaitType wt) = 0;
     virtual void sync_write_events() = 0;
     virtual void sync_read_events() = 0;
+    virtual void register_h2d(unsigned int num_tensors) = 0;
+    virtual void sync_h2d() = 0;
     virtual void synchronize() = 0;
 
     virtual void register_file(int fd) = 0;
+    virtual void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned) = 0;
 };

include/pthread_backend.h

Lines changed: 14 additions & 1 deletion
@@ -9,6 +9,12 @@
 #include <queue>
 #include <tuple>
 #include <functional>
+#include <iostream>
+#include <c10/cuda/CUDAStream.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
 
 #include "asyncio.h"
 #include "threadpool.hpp"
@@ -18,12 +24,15 @@ class PthreadAsyncIO
 {
 private:
     BS::thread_pool pool;
+    std::atomic<unsigned int> h2d_in_progress;
+    std::condition_variable cv;
+    std::mutex mtx;
     std::deque<std::tuple<std::future<ssize_t>, callback_t>> write_fut;
     std::deque<std::tuple<std::future<ssize_t>, callback_t>> read_fut;
 
 public:
     PthreadAsyncIO(unsigned int n_entries)
-        : pool(n_entries) {}
+        : pool(n_entries), h2d_in_progress(0) {}
 
     ~PthreadAsyncIO() {}
 
@@ -35,7 +44,11 @@ class PthreadAsyncIO
     void get_event(WaitType wt);
     void sync_write_events();
    void sync_read_events();
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     void synchronize();
 
     void register_file(int fd);
+
+    void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
 };

include/uring.h

Lines changed: 3 additions & 0 deletions
@@ -21,9 +21,12 @@ class UringAsyncIO : public AsyncIO
     void writev(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
     void readv(int fd, const iovec *iov, unsigned int iovcnt, unsigned long long offset, callback_t callback);
 
+    void register_h2d(unsigned int num_tensors);
+    void sync_h2d();
     void sync_write_events();
     void sync_read_events();
     void synchronize();
 
     void register_file(int fd);
+    void write_tensor(int fd, torch::Tensor t, unsigned long long offset, callback_t callback, std::optional<torch::Tensor> pinned);
 };

tensornvme/_C/__init__.pyi

Lines changed: 1 addition & 0 deletions
@@ -22,4 +22,5 @@ def probe_backend(backend: str) -> bool: ...
 class AsyncFileWriter:
     def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
     def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ...
+    def write_tensor(self, tensor: Tensor, offset: int, callback: Optional[Callable[[], None]] = None, pinned: Optional[Tensor] = None) -> None: ...
     def synchronize(self) -> None: ...

tensornvme/async_file_io.py

Lines changed: 16 additions & 2 deletions
@@ -1,7 +1,8 @@
 import ctypes
+import torch
 from functools import partial
-
-from typing import List
+from torch import Tensor
+from typing import List, Optional
 from io import IOBase
 from tensornvme._C import AsyncFileWriter as AsyncFileWriterC
 
@@ -16,6 +17,7 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
         self.offset = 0
         # must ensure the data is not garbage collected
         self.buffers = []
+        self.comm_stream = torch.cuda.Stream()
 
     def write(self, data: bytes) -> int:
         ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char))
@@ -31,6 +33,18 @@ def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> N
         self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
         self.offset += n_bytes
 
+    def write_tensor(self, tensor: Tensor, pinned: Optional[Tensor] = None) -> None:
+        with torch.cuda.stream(self.comm_stream):
+            self.buffers.append(tensor)  # append before callback is called
+            self.io.write_tensor(tensor, self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1), pinned)
+        self.offset += tensor.numel() * tensor.element_size()
+
+    def register_h2d(self, num_tensors: int) -> None:
+        self.io.register_h2d(num_tensors)
+
+    def sync_before_step(self):
+        self.io.sync_h2d()
+
     @staticmethod
     def gc_callback(listt: List, idx: int) -> None:
         listt[idx] = None
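
At the wrapper level, the intended call pattern appears to be: arm the counter with register_h2d, submit one write_tensor per tensor (offsets are tracked internally), call sync_before_step() before touching the GPU tensors again, then drain the disk writes. A hedged end-to-end sketch; the file name, tensor sizes, the backend="pthread" choice, and the wrapper-level synchronize() call are assumptions not shown in this diff:

import torch
from tensornvme.async_file_io import AsyncFileWriter

if torch.cuda.is_available():
    tensors = [torch.randn(1024, device="cuda") for _ in range(4)]

    with open("checkpoint.bin", "wb") as f:
        writer = AsyncFileWriter(f, n_entries=16, backend="pthread")
        writer.register_h2d(len(tensors))  # arm the h2d counter for this round
        for t in tensors:
            writer.write_tensor(t)         # async GPU -> CPU copy, then pwrite on a worker thread
        writer.sync_before_step()          # blocks until every GPU -> CPU copy has finished
        writer.synchronize()               # assumed wrapper method; waits for the disk writes too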
