Skip to content

Commit 5eed5fb

Browse files
committed
[environ] move environ overwrite to create_asyncio
1 parent e28a80a commit 5eed5fb

File tree

4 files changed

+23
-26
lines changed

4 files changed

+23
-26
lines changed

csrc/backend.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cstring>
55
#include <cassert>
66
#include <memory>
7+
#include <iostream>
78

89
#ifndef DISABLE_URING
910
#include "uring.h"
@@ -119,11 +120,26 @@ bool probe_backend(const std::string &backend)
119120
}
120121
}
121122

122-
AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend)
123+
std::string get_default_backend() {
124+
const char* env = getenv("TENSORNVME_BACKEND");
125+
if (env == nullptr) {
126+
return std::string("");
127+
}
128+
return std::string(env);
129+
}
130+
131+
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend)
123132
{
124133
std::unordered_set<std::string> backends = get_backends();
125134
if (backends.empty())
126135
throw std::runtime_error("No asyncio backend is installed");
136+
137+
std::string default_backend = get_default_backend();
138+
if (default_backend.size() > 0) {
139+
backend = default_backend;
140+
}
141+
std::cout << "current backend: " << backend << std::endl;
142+
127143
if (backends.find(backend) == backends.end())
128144
throw std::runtime_error("Unsupported backend: " + backend);
129145
if (!probe_backend(backend))

csrc/offload.cpp

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,9 @@ iovec *tensors_to_iovec(const std::vector<at::Tensor> &tensors)
2626
return iovs;
2727
}
2828

29-
std::string Offloader::get_default_backend() {
30-
const char* env = getenv("TENSORNVME_BACKEND");
31-
if (env == nullptr) {
32-
return std::string("");
33-
}
34-
return std::string(env);
35-
}
36-
3729
Offloader::Offloader(const std::string &filename, unsigned int n_entries, const std::string &backend) : filename(filename), space_mgr(SpaceManager(0))
38-
{
39-
std::string default_backend = get_default_backend();
40-
if (default_backend.size() > 0) {
41-
if (get_backends().count(default_backend) == 0) {
42-
throw std::runtime_error("Cannot find backend: " + default_backend + ", please check if TENSORNVME_BACKEND is set correctly");
43-
}
44-
this->aio = create_asyncio(n_entries, default_backend);
45-
} else {
46-
if (get_backends().count(backend) == 0) {
47-
throw std::runtime_error("Cannot find backend: " + backend + ", please check the passed backend is set correctly");
48-
}
49-
this->aio = create_asyncio(n_entries, backend);
50-
}
51-
30+
{
31+
this->aio = create_asyncio(n_entries, backend);
5232
this->fd = open(filename.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
5333
this->aio->register_file(fd);
5434
}

include/backend.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#include "asyncio.h"
22
#include <string>
33
#include <unordered_set>
4+
#include <cstdlib>
45

56
std::unordered_set<std::string> get_backends();
67

78
bool probe_backend(const std::string &backend);
89

9-
AsyncIO *create_asyncio(unsigned int n_entries, const std::string &backend);
10+
std::string get_default_backend();
11+
12+
AsyncIO *create_asyncio(unsigned int n_entries, std::string backend);

include/offload.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "aio.h"
1212
#endif
1313

14-
#include <cstdlib>
1514
class Offloader
1615
{
1716
public:
@@ -32,7 +31,6 @@ class Offloader
3231
void async_readv(const std::vector<at::Tensor> &tensors, const std::string &key, callback_t callback = nullptr);
3332
void sync_writev(const std::vector<at::Tensor> &tensors, const std::string &key);
3433
void sync_readv(const std::vector<at::Tensor> &tensors, const std::string &key);
35-
static std::string get_default_backend();
3634
private:
3735
const std::string filename;
3836
int fd;

0 commit comments

Comments
 (0)