1
1
// single thread right now.
2
+ #include " infinistore.h"
3
+
2
4
#include < arpa/inet.h>
3
5
#include < assert.h>
4
6
#include < cuda.h>
11
13
#include < sys/socket.h>
12
14
#include < time.h>
13
15
#include < unistd.h>
14
- #include < uv.h>
15
16
16
17
#include < boost/lockfree/spsc_queue.hpp>
17
18
#include < chrono>
21
22
#include < string>
22
23
#include < unordered_map>
23
24
24
- #include " config.h"
25
25
#include " ibv_helper.h"
26
- #include " log.h"
27
- #include " mempool.h"
28
26
#include " protocol.h"
29
- #include " utils.h"
30
27
31
28
server_config_t global_config;
32
29
@@ -48,22 +45,6 @@ bool extend_in_flight = false;
48
45
// indicate the number of cudaIpcOpenMemHandle
49
46
std::atomic<unsigned int > opened_ipc{0 };
50
47
51
- // PTR is shared by kv_map and inflight_rdma_kv_map
52
- class PTR : public IntrusivePtrTarget {
53
- public:
54
- void *ptr;
55
- size_t size;
56
- int pool_idx;
57
- bool committed;
58
- PTR (void *ptr, size_t size, int pool_idx, bool committed = false )
59
- : ptr(ptr), size(size), pool_idx(pool_idx), committed(committed) {}
60
- ~PTR () {
61
- if (ptr) {
62
- DEBUG (" deallocate ptr: {}, size: {}, pool_idx: {}" , ptr, size, pool_idx);
63
- mm->deallocate (ptr, size, pool_idx);
64
- }
65
- }
66
- };
67
48
68
49
enum CUDA_TASK_TYPE {
69
50
CUDA_READ,
@@ -80,10 +61,9 @@ struct CUDA_TASK {
80
61
cudaEvent_t event;
81
62
};
82
63
83
- std::unordered_map<uintptr_t , boost::intrusive_ptr<PTR>> inflight_rdma_kv_map;
84
- std::unordered_map<std::string, boost::intrusive_ptr<PTR>> kv_map;
64
+ std::unordered_map<uintptr_t , boost::intrusive_ptr<PTR>> inflight_rdma_writes;
85
65
86
- int get_kvmap_len () { return kv_map. size (); }
66
+ std::unordered_map<std::string, boost::intrusive_ptr<PTR>> kv_map;
87
67
88
68
typedef enum {
89
69
READ_HEADER,
@@ -130,11 +110,6 @@ struct Client {
130
110
131
111
uv_poll_t poll_handle_;
132
112
133
- struct block {
134
- uint32_t lkey;
135
- uintptr_t local_addr;
136
- };
137
-
138
113
Client () = default ;
139
114
Client (const Client &) = delete ;
140
115
~Client ();
@@ -173,6 +148,7 @@ Client::~Client() {
173
148
if (poll_handle_.data ) {
174
149
uv_poll_stop (&poll_handle_);
175
150
}
151
+
176
152
if (handle_) {
177
153
free (handle_);
178
154
handle_ = NULL ;
@@ -274,15 +250,15 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) {
274
250
ERROR (" remote_addrs size should not be 0" );
275
251
}
276
252
for (auto addr : *request->remote_addrs ()) {
277
- auto it = inflight_rdma_kv_map .find (addr);
278
- if (it == inflight_rdma_kv_map .end ()) {
253
+ auto it = inflight_rdma_writes .find (addr);
254
+ if (it == inflight_rdma_writes .end ()) {
279
255
ERROR (" commit msg: Key not found: {}" , addr);
280
256
continue ;
281
257
}
282
258
it->second ->committed = true ;
283
- inflight_rdma_kv_map .erase (it);
259
+ inflight_rdma_writes .erase (it);
284
260
}
285
- DEBUG (" inflight_rdma_kv_map size: {}" , inflight_rdma_kv_map .size ());
261
+ DEBUG (" inflight_rdma_kv_map size: {}" , inflight_rdma_writes .size ());
286
262
break ;
287
263
}
288
264
default :
@@ -332,6 +308,12 @@ void Client::cq_poll_handle(uv_poll_t *handle, int status, int events) {
332
308
delete[] sges;
333
309
outstanding_rdma_writes_queue_.pop_front ();
334
310
}
311
+
312
+ if (wc.wr_id > 0 ) {
313
+ // last WR will inform that all RDMA write is finished,so we can dereference PTR
314
+ auto inflight_rdma_reads = (std::vector<boost::intrusive_ptr<PTR>> *)wc.wr_id ;
315
+ delete inflight_rdma_reads;
316
+ }
335
317
}
336
318
else {
337
319
ERROR (" Unexpected wc opcode: {}" , (int )wc.opcode );
@@ -376,7 +358,7 @@ int Client::allocate_rdma(const RemoteMetaRequest *req) {
376
358
377
359
// save in inflight_rdma_kv_map, when write is finished, we can merge it
378
360
// into kv_map
379
- inflight_rdma_kv_map [(uintptr_t )addr] = ptr;
361
+ inflight_rdma_writes [(uintptr_t )addr] = ptr;
380
362
381
363
blocks.push_back (RemoteBlock (rkey, (uint64_t )addr));
382
364
key_idx++;
@@ -439,8 +421,9 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) {
439
421
return -1 ;
440
422
}
441
423
442
- std::vector<block> blocks;
443
- blocks.reserve (remote_meta_req->keys ()->size ());
424
+ auto *inflight_rdma_reads = new std::vector<boost::intrusive_ptr<PTR>>;
425
+
426
+ inflight_rdma_reads->reserve (remote_meta_req->keys ()->size ());
444
427
445
428
for (const auto *key : *remote_meta_req->keys ()) {
446
429
auto it = kv_map.find (key->str ());
@@ -459,7 +442,7 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) {
459
442
DEBUG (" rkey: {}, local_addr: {}, size : {}" , mm->get_lkey (ptr->pool_idx ),
460
443
(uintptr_t )ptr->ptr , ptr->size );
461
444
462
- blocks. push_back ({. lkey = mm-> get_lkey (ptr-> pool_idx ), . local_addr = ( uintptr_t )ptr-> ptr } );
445
+ inflight_rdma_reads-> push_back (ptr);
463
446
}
464
447
465
448
const size_t max_wr = MAX_WR_BATCH;
@@ -477,13 +460,12 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) {
477
460
wrs = new struct ibv_send_wr [max_wr];
478
461
sges = new struct ibv_sge [max_wr];
479
462
}
480
-
481
463
for (size_t i = 0 ; i < remote_meta_req->keys ()->size (); i++) {
482
- sges[num_wr].addr = blocks [i]. local_addr ;
464
+ sges[num_wr].addr = ( uintptr_t )(*inflight_rdma_reads) [i]-> ptr ;
483
465
sges[num_wr].length = remote_meta_req->block_size ();
484
- sges[num_wr].lkey = blocks [i]. lkey ;
466
+ sges[num_wr].lkey = mm-> get_lkey ((*inflight_rdma_reads) [i]-> pool_idx ) ;
485
467
486
- wrs[num_wr].wr_id = i ;
468
+ wrs[num_wr].wr_id = 0 ;
487
469
wrs[num_wr].opcode = (i == remote_meta_req->keys ()->size () - 1 ) ? IBV_WR_RDMA_WRITE_WITH_IMM
488
470
: IBV_WR_RDMA_WRITE;
489
471
wrs[num_wr].sg_list = &sges[num_wr];
@@ -498,6 +480,10 @@ int Client::read_rdma_cache(const RemoteMetaRequest *remote_meta_req) {
498
480
? IBV_SEND_SIGNALED
499
481
: 0 ;
500
482
483
+ if (i == remote_meta_req->keys ()->size () - 1 ) {
484
+ wrs[num_wr].wr_id = (uintptr_t )inflight_rdma_reads;
485
+ }
486
+
501
487
num_wr++;
502
488
503
489
if (num_wr == max_wr || i == remote_meta_req->keys ()->size () - 1 ) {
@@ -617,19 +603,13 @@ int Client::read_cache(const LocalMetaRequest *meta_req) {
617
603
618
604
DEBUG (" key: {}, local_addr: {}, size : {}" , key, (uintptr_t )h_src, block_size);
619
605
620
- task->ptrs .push_back (kv_map[key]);
606
+ task->ptrs .push_back (kv_map[key]); // keep PTR in task as reference count.
621
607
remote_addrs.push_back ((uintptr_t )((char *)d_ptr + block->offset ()));
622
608
idx++;
623
609
}
624
610
625
611
assert (task->ptrs .size () == remote_addrs.size ());
626
612
627
- // for (auto &task : tasks) {
628
- // // CHECK_CUDA(cudaMemcpyAsync((void *)task.dst, task.ptr->ptr, block_size,
629
- // // cudaMemcpyHostToDevice, cuda_streams[task.stream_idx]));
630
-
631
- // }
632
-
633
613
for (int i = 0 ; i < task->ptrs .size (); i++) {
634
614
CHECK_CUDA (cudaMemcpyAsync ((void *)remote_addrs[i], task->ptrs [i]->ptr , block_size,
635
615
cudaMemcpyHostToDevice, cuda_stream));
@@ -711,8 +691,7 @@ void add_mempool_completion(uv_work_t *req, int status) {
711
691
}
712
692
713
693
int Client::write_cache (const LocalMetaRequest *meta_req) {
714
- INFO (" do write_cache..., num of blocks: {}, stream num {}" , meta_req->blocks ()->size (),
715
- global_config.num_stream );
694
+ INFO (" do write_cache..., num of blocks: {}" , meta_req->blocks ()->size ());
716
695
717
696
void *d_ptr;
718
697
int return_code = TASK_ACCEPTED;
@@ -1270,11 +1249,11 @@ void on_new_connection(uv_stream_t *server, int status) {
1270
1249
}
1271
1250
1272
1251
void signal_handler (int signum) {
1273
- void *array[10 ];
1252
+ void *array[32 ];
1274
1253
size_t size;
1275
1254
if (signum == SIGSEGV) {
1276
1255
ERROR (" Caught SIGSEGV: segmentation fault" );
1277
- size = backtrace (array, 10 );
1256
+ size = backtrace (array, 32 );
1278
1257
// print signum's name
1279
1258
ERROR (" Error: signal {}" , signum);
1280
1259
// backtrace_symbols_fd(array, size, STDERR_FILENO);
@@ -1335,4 +1314,4 @@ int register_server(unsigned long loop_ptr, server_config_t config) {
1335
1314
INFO (" register server done" );
1336
1315
1337
1316
return 0 ;
1338
- }
1317
+ }
0 commit comments