Skip to content

Commit 2b313a5

Browse files
committed
[backend] add backend option to async file writer
1 parent 8f63b21 commit 2b313a5

File tree

7 files changed

+22
-26
lines changed

7 files changed

+22
-26
lines changed

csrc/async_file_io.cpp

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,8 @@
11
#include "async_file_io.h"
22
#include "backend.h"
33
#include <stdexcept>
4-
#include <string>
54

6-
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries) : fd(fd)
7-
{
8-
for (const std::string &backend : get_backends())
9-
{
10-
if (probe_backend(backend))
11-
{
12-
this->aio = create_asyncio(n_entries, backend);
13-
return;
14-
}
15-
}
16-
throw std::runtime_error("No asyncio backend is installed");
17-
}
5+
// Constructs a writer for the file descriptor `fd`.
// The async I/O engine is produced by create_asyncio(n_entries, backend);
// any exception raised by that call propagates out of this constructor.
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend)
    : fd(fd),
      aio(create_asyncio(n_entries, backend))
{
}
186

197
void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset)
208
{

csrc/backend.cpp

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -130,19 +130,23 @@ std::string get_default_backend() {
130130
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
131131
{
132132
std::unordered_set<std::string> backends = get_backends();
133+
std::string default_backend = get_default_backend();
134+
133135
if (backends.empty())
134136
throw std::runtime_error("No asyncio backend is installed");
135137

136-
std::string default_backend = get_default_backend();
137-
if (default_backend.size() > 0) {
138-
std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << std::endl;
138+
if (default_backend.size() > 0) { // priority 1: environ is set
139+
std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << " to " << default_backend << std::endl;
139140
backend = default_backend;
141+
} else if (backend.size() > 0) { // priority 2: backend is set
142+
if (backends.find(backend) == backends.end())
143+
throw std::runtime_error("Unsupported backend: " + backend);
140144
}
141145
std::cout << "[backend] using backend: " << backend << std::endl;
142-
if (backends.find(backend) == backends.end())
143-
throw std::runtime_error("Unsupported backend: " + backend);
146+
144147
if (!probe_backend(backend))
145148
throw std::runtime_error("Backend \"" + backend + "\" is not install correctly");
149+
146150
#ifndef DISABLE_URING
147151
if (backend == "uring")
148152
return new UringAsyncIO(n_entries);

csrc/py_api.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ namespace py = pybind11;
1212
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
1313
{
1414
py::class_<Offloader>(m, "Offloader")
15-
.def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "uring")
15+
.def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "aio")
1616
.def("async_write", &Offloader::async_write, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
1717
.def("async_read", &Offloader::async_read, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
1818
.def("sync_write", &Offloader::sync_write, py::arg("tensor"), py::arg("key"))
@@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2727
m.def("get_backends", get_backends);
2828
m.def("probe_backend", probe_backend, py::arg("backend"));
2929
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
30-
.def(py::init<int, unsigned int>(), py::arg("fd"), py::arg("n_entries"))
30+
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
3131
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"))
3232
.def("synchronize", &AsyncFileWriter::synchronize);
3333
}

include/async_file_io.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#include <string>
23
#include "asyncio.h"
34
#ifndef DISABLE_URING
45
#include "uring.h"
@@ -10,7 +11,7 @@
1011
class AsyncFileWriter
1112
{
1213
public:
13-
AsyncFileWriter(int fd, unsigned int n_entries);
14+
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
1415
void write(size_t buffer, size_t n_bytes, unsigned long long offset);
1516
void synchronize();
1617
~AsyncFileWriter();

include/offload.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
class Offloader
1515
{
1616
public:
17-
Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend = "uring");
17+
Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend);
1818
SpaceInfo prepare_write(const at::Tensor &tensor, const std::string &key);
1919
SpaceInfo prepare_read(const at::Tensor &tensor, const std::string &key);
2020
void async_write(const at::Tensor &tensor, const std::string &key, callback_t callback = nullptr);

tensornvme/_C/__init__.pyi

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Set
33
from torch import Tensor
44

55
class Offloader:
6-
def __init__(self, filename: str, n_entries: int, backend: str = "uring") -> None: ...
6+
def __init__(self, filename: str, n_entries: int, backend: str = "aio") -> None: ...
77
def async_write(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
88
def async_read(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
99
def sync_write(self, tensor: Tensor, key: str) -> None: ...
@@ -20,6 +20,6 @@ def get_backends() -> Set[str]: ...
2020
def probe_backend(backend: str) -> bool: ...
2121

2222
class AsyncFileWriter:
23-
def __init__(self, fd: int, n_entries: int) -> None: ...
23+
def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
2424
def write(self, buffer, n_bytes: int, offset: int) -> None: ...
2525
def synchronize(self) -> None: ...

tensornvme/async_file_io.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,12 @@
55

66

77
class AsyncFileWriter:
8-
def __init__(self, fp: IOBase, n_entries: int = 16) -> None:
8+
def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
99
fd = fp.fileno()
10-
self.io = AsyncFileWriterC(fd, n_entries)
10+
if backend is not None:
11+
self.io = AsyncFileWriterC(fd, n_entries, backend=backend)
12+
else:
13+
self.io = AsyncFileWriterC(fd, n_entries)
1114
self.fp = fp
1215
self.offset = 0
1316
# must ensure the data is not garbage collected

0 commit comments

Comments
 (0)