19
19
20
20
namespace ygm {
21
21
22
- class comm ::impl {
22
+ class comm ::impl : public std::enable_shared_from_this<comm::impl> {
23
23
public:
24
24
impl (MPI_Comm c, int buffer_capacity) {
25
25
ASSERT_MPI (MPI_Comm_dup (c, &m_comm_async));
@@ -35,7 +35,6 @@ class comm::impl {
35
35
}
36
36
37
37
~impl () {
38
- barrier ();
39
38
// send kill signal to self (listener thread)
40
39
ASSERT_RELEASE (MPI_Send (NULL , 0 , MPI_BYTE, m_comm_rank, 0 , m_comm_async) ==
41
40
MPI_SUCCESS);
@@ -399,12 +398,12 @@ class comm::impl {
399
398
std::forward<const PackArgs>(args)...);
400
399
ASSERT_DEBUG (sizeof (Lambda) == 1 );
401
400
402
- void (*fun_ptr)(impl *, cereal::YGMInputArchive &) =
403
- [](impl *t , cereal::YGMInputArchive &bia) {
401
+ void (*fun_ptr)(comm *, cereal::YGMInputArchive &) =
402
+ [](comm *c , cereal::YGMInputArchive &bia) {
404
403
std::tuple<PackArgs...> ta;
405
404
bia (ta);
406
405
Lambda *pl = nullptr ;
407
- auto t1 = std::make_tuple ((impl *)t );
406
+ auto t1 = std::make_tuple ((comm *)c );
408
407
409
408
// \pp was: std::apply(*pl, std::tuple_cat(t1, ta));
410
409
ygm::meta::apply_optional (*pl, std::move (t1), std::move (ta));
@@ -430,6 +429,7 @@ class comm::impl {
430
429
*/
431
430
bool process_receive_queue () {
432
431
bool received = false ;
432
+ comm tmp_comm (shared_from_this ());
433
433
while (true ) {
434
434
auto buffer = receive_queue_try_pop ();
435
435
if (buffer.second == 0 ) break ;
@@ -439,9 +439,9 @@ class comm::impl {
439
439
int64_t iptr;
440
440
iarchive (iptr);
441
441
iptr += (int64_t )&reference;
442
- void (*fun_ptr)(impl *, cereal::YGMInputArchive &);
442
+ void (*fun_ptr)(comm *, cereal::YGMInputArchive &);
443
443
memcpy (&fun_ptr, &iptr, sizeof (uint64_t ));
444
- fun_ptr (this , iarchive);
444
+ fun_ptr (&tmp_comm , iarchive);
445
445
m_recv_count++;
446
446
m_local_rpc_calls++;
447
447
}
@@ -496,10 +496,13 @@ inline comm::comm(MPI_Comm mcomm, int buffer_capacity = 16 * 1024 * 1024) {
496
496
pimpl = std::make_shared<comm::impl>(mcomm, buffer_capacity);
497
497
}
498
498
499
+ inline comm::comm (std::shared_ptr<impl> impl_ptr) : pimpl(impl_ptr) {}
500
+
499
501
inline comm::~comm () {
500
- ASSERT_RELEASE (MPI_Barrier (MPI_COMM_WORLD) == MPI_SUCCESS);
502
+ if (pimpl.use_count () == 1 ) {
503
+ barrier ();
504
+ }
501
505
pimpl.reset ();
502
- ASSERT_RELEASE (MPI_Barrier (MPI_COMM_WORLD) == MPI_SUCCESS);
503
506
pimpl_if.reset ();
504
507
}
505
508
0 commit comments