Skip to content

Commit ebc660e

Browse files
authored
[backend] add environ to overwrite passed backend (#47)
* [backend] add environ to overwrite passed backend * [environ] move environ overwrite to create_asyncio * [backend] add overwrite info * [backend] add backend option to async file writer * [backend] add debug flag
1 parent 51ed242 commit ebc660e

File tree

9 files changed

+60
-27
lines changed

9 files changed

+60
-27
lines changed

csrc/async_file_io.cpp

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,8 @@
11
#include "async_file_io.h"
22
#include "backend.h"
33
#include <stdexcept>
4-
#include <string>
54

6-
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries) : fd(fd)
7-
{
8-
for (const std::string &backend : get_backends())
9-
{
10-
if (probe_backend(backend))
11-
{
12-
this->aio = create_asyncio(n_entries, backend);
13-
return;
14-
}
15-
}
16-
throw std::runtime_error("No asyncio backend is installed");
17-
}
5+
AsyncFileWriter::AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend) : fd(fd), aio(create_asyncio(n_entries, backend)) {}
186

197
void AsyncFileWriter::write(size_t buffer, size_t n_bytes, unsigned long long offset)
208
{

csrc/backend.cpp

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,50 @@ bool probe_backend(const std::string &backend)
119119
}
120120
}
121121

122-
AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend)
122+
std::string get_default_backend() {
123+
const char* env = getenv("TENSORNVME_BACKEND");
124+
if (env == nullptr) {
125+
return std::string("");
126+
}
127+
return std::string(env);
128+
}
129+
130+
bool get_debug_flag() {
131+
const char* env_ = getenv("TENSORNVME_DEBUG");
132+
if (env_ == nullptr) {
133+
return false;
134+
}
135+
std::string env(env_);
136+
std::transform(env.begin(), env.end(), env.begin(),
137+
[](unsigned char c) { return std::tolower(c); });
138+
return env == "1" || env == "true";
139+
}
140+
141+
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
123142
{
124143
std::unordered_set<std::string> backends = get_backends();
144+
std::string default_backend = get_default_backend();
145+
bool is_debugging = get_debug_flag();
146+
125147
if (backends.empty())
126148
throw std::runtime_error("No asyncio backend is installed");
127-
if (backends.find(backend) == backends.end())
128-
throw std::runtime_error("Unsupported backend: " + backend);
149+
150+
if (default_backend.size() > 0) { // priority 1: environ is set
151+
if (is_debugging) {
152+
std::cout << "[backend] backend is overwritten by environ TENSORNVME_BACKEND from " << backend << " to " << default_backend << std::endl;
153+
}
154+
backend = default_backend;
155+
} else if (backend.size() > 0) { // priority 2: backend is set
156+
if (backends.find(backend) == backends.end())
157+
throw std::runtime_error("Unsupported backend: " + backend);
158+
}
159+
if (is_debugging) {
160+
std::cout << "[backend] using backend: " << backend << std::endl;
161+
}
162+
129163
if (!probe_backend(backend))
130164
throw std::runtime_error("Backend \"" + backend + "\" is not install correctly");
165+
131166
#ifndef DISABLE_URING
132167
if (backend == "uring")
133168
return new UringAsyncIO(n_entries);

csrc/py_api.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace py = pybind11;
1212
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
1313
{
1414
py::class_<Offloader>(m, "Offloader")
15-
.def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "uring")
15+
.def(py::init<const std::string &, unsigned int, const std::string &>(), py::arg("filename"), py::arg("n_entries"), py::arg("backend") = "aio")
1616
.def("async_write", &Offloader::async_write, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
1717
.def("async_read", &Offloader::async_read, py::arg("tensor"), py::arg("key"), py::arg("callback") = py::none())
1818
.def("sync_write", &Offloader::sync_write, py::arg("tensor"), py::arg("key"))
@@ -27,7 +27,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2727
m.def("get_backends", get_backends);
2828
m.def("probe_backend", probe_backend, py::arg("backend"));
2929
py::class_<AsyncFileWriter>(m, "AsyncFileWriter")
30-
.def(py::init<int, unsigned int>(), py::arg("fd"), py::arg("n_entries"))
30+
.def(py::init<int, unsigned int, const std::string &>(), py::arg("fd"), py::arg("n_entries"), py::arg("backend") = "aio")
3131
.def("write", &AsyncFileWriter::write, py::arg("buffer"), py::arg("n_bytes"), py::arg("offset"))
3232
.def("synchronize", &AsyncFileWriter::synchronize);
3333
}

include/async_file_io.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#include <string>
23
#include "asyncio.h"
34
#ifndef DISABLE_URING
45
#include "uring.h"
@@ -10,7 +11,7 @@
1011
class AsyncFileWriter
1112
{
1213
public:
13-
AsyncFileWriter(int fd, unsigned int n_entries);
14+
AsyncFileWriter(int fd, unsigned int n_entries, const std::string &backend);
1415
void write(size_t buffer, size_t n_bytes, unsigned long long offset);
1516
void synchronize();
1617
~AsyncFileWriter();

include/backend.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
#include "asyncio.h"
22
#include <string>
3+
#include <algorithm>
4+
#include <cctype>
35
#include <unordered_set>
6+
#include <cstdlib>
7+
#include <iostream>
48

59
std::unordered_set<std::string> get_backends();
610

711
bool probe_backend(const std::string &backend);
812

9-
AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend);
13+
std::string get_default_backend();
14+
15+
bool get_debug_flag();
16+
17+
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend);

include/offload.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class Offloader
1515
{
1616
public:
17-
Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend = "uring");
17+
Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend);
1818
SpaceInfo prepare_write(const at::Tensor &tensor, const std::string &key);
1919
SpaceInfo prepare_read(const at::Tensor &tensor, const std::string &key);
2020
void async_write(const at::Tensor &tensor, const std::string &key, callback_t callback = nullptr);

tensornvme/_C/__init__.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Set
33
from torch import Tensor
44

55
class Offloader:
6-
def __init__(self, filename: str, n_entries: int, backend: str = "uring") -> None: ...
6+
def __init__(self, filename: str, n_entries: int, backend: str = "aio") -> None: ...
77
def async_write(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
88
def async_read(self, tensor: Tensor, key: str, callback: Optional[Callable[[], None]] = None) -> None: ...
99
def sync_write(self, tensor: Tensor, key: str) -> None: ...
@@ -20,6 +20,6 @@ def get_backends() -> Set[str]: ...
2020
def probe_backend(backend: str) -> bool: ...
2121

2222
class AsyncFileWriter:
23-
def __init__(self, fd: int, n_entries: int) -> None: ...
23+
def __init__(self, fd: int, n_entries: int, backend: str = "aio") -> None: ...
2424
def write(self, buffer, n_bytes: int, offset: int) -> None: ...
2525
def synchronize(self) -> None: ...

tensornvme/async_file_io.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55

66

77
class AsyncFileWriter:
8-
def __init__(self, fp: IOBase, n_entries: int = 16) -> None:
8+
def __init__(self, fp: IOBase, n_entries: int = 16, backend=None) -> None:
99
fd = fp.fileno()
10-
self.io = AsyncFileWriterC(fd, n_entries)
10+
if backend is not None:
11+
self.io = AsyncFileWriterC(fd, n_entries, backend=backend)
12+
else:
13+
self.io = AsyncFileWriterC(fd, n_entries)
1114
self.fp = fp
1215
self.offset = 0
1316
# must ensure the data is not garbage collected

tensornvme/offload.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
class DiskOffloader(Offloader):
99
def __init__(self, dir_name: str, n_entries: int = 16, backend: str = 'uring') -> None:
10-
assert backend in get_backends(
11-
), f'Unsupported backend: {backend}, please install tensornvme with this backend'
1210
if not os.path.exists(dir_name):
1311
os.mkdir(dir_name)
1412
assert os.path.isdir(dir_name)

0 commit comments

Comments
 (0)