Skip to content

Commit b2f9944

Browse files
authored
[fio] implement async io with safetensors format (#48)
* [fio] implement async io with safetensors format
* [fio] use raw tensor ptr instead of numpy
* [chore] refactor
* [fio] add callback
* [chore] refactor
1 parent ebc660e commit b2f9944

File tree

5 files changed: 21 additions and 11 deletions.

csrc/async_file_io.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
1+
#include "asyncio.h"
12
#include "async_file_io.h"
23
#include "backend.h"
34
#include <stdexcept>
45

56
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}
67

7-
void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset)
8+
void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback)
89
{
910
void *ptr = reinterpret_cast<void *>(buffer);
10-
this->aio->write(this->fd, ptr, n_bytes, offset, nullptr);
11+
this->aio->write(this->fd, ptr, n_bytes, offset, callback);
1112
}
1213

1314
void AsyncFileWriter::synchronize()

csrc/py_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2828
m.def("probe_backend", probe_backend, py::arg("backend"));
2929
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
3030
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
31-
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"))
31+
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"), py::arg("callback") = py::none())
3232
.def("synchronize", &AsyncFileWriter::synchronize);
3333
}

include/async_file_io.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ class AsyncFileWriter
1212
{
1313
public:
1414
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
15-
void write(size_t buffer, size_t n_bytes, unsigned long long offset);
15+
void write(size_t buffer, size_t n_bytes, unsigned long long offset, callback_t callback);
1616
void synchronize();
1717
~AsyncFileWriter();
1818

tensornvme/_C/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,5 +21,5 @@ def probe_backend(backend: str) -> bool: ...
2121

2222
class AsyncFileWriter:
2323
def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
24-
def write(self, buffer, n_bytes: int, offset: int) -> None: ...
24+
def write(self, buffer: int, n_bytes: int, offset: int, callback: Optional[Callable[[], None]] = None) -> None: ...
2525
def synchronize(self) -> None: ...

tensornvme/async_file_io.py

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
11
import ctypes
2-
from io import IOBase
2+
from functools import partial
33

4+
from typing import List
5+
from io import IOBase
46
from tensornvme._C import AsyncFileWriter as AsyncFileWriterC
57

6-
78
class AsyncFileWriter:
89
def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
910
fd = fp.fileno()
@@ -17,15 +18,23 @@ def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
1718
self.buffers = []
1819

1920
def write(self, data: bytes) -> int:
20-
if isinstance(data, memoryview):
21-
data = data.tobytes()
2221
ptr = ctypes.cast(data, ctypes.POINTER(ctypes.c_char))
2322
addr = ctypes.addressof(ptr.contents)
24-
self.io.write(addr, len(data), self.offset)
23+
self.buffers.append(data) # append before callback is called
24+
self.io.write(addr, len(data), self.offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
2525
self.offset += len(data)
26-
self.buffers.append(data)
26+
2727
return len(data)
2828

29+
def write_raw(self, py_ref: object, buffer: int, n_bytes: int, offset: int) -> None:
30+
self.buffers.append(py_ref) # append before callback is called
31+
self.io.write(buffer, n_bytes, offset, partial(AsyncFileWriter.gc_callback, self.buffers, len(self.buffers) - 1))
32+
self.offset += n_bytes
33+
34+
@staticmethod
35+
def gc_callback(listt: List, idx: int) -> None:
36+
listt[idx] = None
37+
2938
def flush(self) -> None:
3039
pass
3140

0 commit comments

Comments
 (0)