Skip to content

Commit 4164e6e

Browse files
committed
net/tls: Wait for data_{source,sink}::close()
Fixes scylladb#799 data_{source,sink}::close() return a future. If it is not ready on close() return, then the current tls session close() may result in use after free. Converting close_after_shutdown() to a coroutine and sequentially co_awaiting on close() addresses this issue. The waiting is done sequentially, as this is shutdown path anyway.
1 parent a2cb707 commit 4164e6e

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

src/net/tls.cc

+25-23
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ module;
2424
#endif
2525

2626
#include <any>
27+
#include <coroutine>
2728
#include <filesystem>
2829
#include <stdexcept>
2930
#include <system_error>
@@ -129,7 +130,7 @@ static future<file_result> read_fully(const sstring& name, const sstring& what)
129130
return do_with(std::move(f), [name = std::move(name)](file& f) mutable {
130131
return f.stat().then([&f, name = std::move(name)](struct stat s) mutable {
131132
return f.dma_read_bulk<char>(0, s.st_size).then([s, name = std::move(name)](temporary_buffer<char> buf) mutable {
132-
return file_result{ std::move(buf), file_info{
133+
return file_result{ std::move(buf), file_info{
133134
std::move(name), std::chrono::system_clock::from_time_t(s.st_mtim.tv_sec) +
134135
std::chrono::duration_cast<std::chrono::system_clock::duration>(std::chrono::nanoseconds(s.st_mtim.tv_nsec))
135136
} };
@@ -223,14 +224,14 @@ class tls::dh_params::impl : gnutlsobj {
223224
return dh_ptr(params, &gnutls_dh_params_deinit);
224225
}
225226
public:
226-
impl(dh_ptr p)
227-
: _params(std::move(p))
227+
impl(dh_ptr p)
228+
: _params(std::move(p))
228229
{}
229230
impl(level lvl)
230231
#if GNUTLS_VERSION_NUMBER >= 0x030506
231232
: _params(nullptr, &gnutls_dh_params_deinit)
232233
, _sec_param(to_gnutls_level(lvl))
233-
#else
234+
#else
234235
: impl([&] {
235236
auto bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_DH, to_gnutls_level(lvl));
236237
auto ptr = new_dh_params();
@@ -245,14 +246,14 @@ class tls::dh_params::impl : gnutlsobj {
245246
blob_wrapper w(pkcs3);
246247
gtls_chk(gnutls_dh_params_import_pkcs3(ptr.get(), &w, gnutls_x509_crt_fmt_t(fmt)));
247248
return ptr;
248-
}())
249+
}())
249250
{}
250251
impl(const impl& v)
251252
: impl([&v] {
252253
auto ptr = new_dh_params();
253254
gtls_chk(gnutls_dh_params_cpy(ptr.get(), v));
254255
return ptr;
255-
}())
256+
}())
256257
{}
257258
~impl() = default;
258259

@@ -854,9 +855,9 @@ class tls::reloadable_credentials_base {
854855
auto i = _watches.find(e.id);
855856
if (i != _watches.end()) {
856857
auto& filename = i->second.second;
857-
// only add actual file watches to
858+
// only add actual file watches to
858859
// query set. If this was a directory
859-
// watch, the file should already be
860+
// watch, the file should already be
860861
// in there.
861862
if (_all_files.count(filename)) {
862863
_files[filename] = e.mask;
@@ -918,7 +919,7 @@ class tls::reloadable_credentials_base {
918919
} catch (...) {
919920
if (std::any_of(_files.begin(), _files.end(), [](auto& p) { return p.second == fsnotifier::flags::ignored; })) {
920921
// if any file in the reload set was deleted - i.e. we have not seen a "closed" yet - assume
921-
// this is a spurious reload and we'd better wait for next event - hopefully a "closed" -
922+
// this is a spurious reload and we'd better wait for next event - hopefully a "closed" -
922923
// and try again
923924
return;
924925
}
@@ -931,7 +932,7 @@ class tls::reloadable_credentials_base {
931932
}
932933
void on_success() {
933934
_files.clear();
934-
// remove all directory watches, since we've successfully
935+
// remove all directory watches, since we've successfully
935936
// reloaded -> the file watches themselves should suffice now
936937
auto i = _watches.begin();
937938
auto e = _watches.end();
@@ -967,7 +968,7 @@ class tls::reloadable_credentials_base {
967968
future<fsnotifier::watch_token> add_watch(const sstring& filename, fsnotifier::flags flags = fsnotifier::flags::close_write|fsnotifier::flags::delete_self) {
968969
return _fsn.create_watch(filename, flags).then([this, filename = filename](fsnotifier::watch w) {
969970
auto t = w.token();
970-
// we might create multiple watches for same token in case of dirs, avoid deleting previously
971+
// we might create multiple watches for same token in case of dirs, avoid deleting previously
971972
// created one
972973
if (_watches.count(t)) {
973974
w.release();
@@ -1086,15 +1087,15 @@ class session : public enable_lw_shared_from_this<session> {
10861087
}
10871088
// Maybe set up server session ticket support
10881089
switch (_creds->get_session_resume_mode()) {
1089-
case session_resume_mode::NONE:
1090+
case session_resume_mode::NONE:
10901091
default:
10911092
break;
10921093
case session_resume_mode::TLS13_SESSION_TICKET:
10931094
gnutls_session_ticket_enable_server(*this, _creds->get_session_resume_key());
10941095
break;
10951096
}
10961097
}
1097-
1098+
10981099
auto prio = _creds->get_priority();
10991100
if (prio) {
11001101
gtls_chk(gnutls_priority_set(*this, prio));
@@ -1307,7 +1308,7 @@ class session : public enable_lw_shared_from_this<session> {
13071308
ss << stat_str;
13081309
if (stat_str.back() != ' ') {
13091310
ss << ' ';
1310-
}
1311+
}
13111312
ss << "(Issuer=[" << dn->issuer << "], Subject=[" << dn->subject << "])";
13121313
stat_str = ss.str();
13131314
}
@@ -1584,8 +1585,8 @@ class session : public enable_lw_shared_from_this<session> {
15841585
std::bind(&session::do_shutdown, this)).then(
15851586
std::bind(&session::wait_for_eof, this)).finally([me = shared_from_this()] {});
15861587
// note moved finally clause above. It is theorethically possible
1587-
// that we could complete do_shutdown just before the close calls
1588-
// below, get pre-empted, have "close()" finish, get freed, and
1588+
// that we could complete do_shutdown just before the close calls
1589+
// below, get pre-empted, have "close()" finish, get freed, and
15891590
// then call wait_for_eof on stale pointer.
15901591
}
15911592
void close() noexcept {
@@ -1604,17 +1605,18 @@ class session : public enable_lw_shared_from_this<session> {
16041605
future<> close_after_shutdown() {
16051606
_eof = true;
16061607
try {
1607-
(void)_in.close().handle_exception([](std::exception_ptr) {}); // should wake any waiters
1608+
co_await _in.close(); // should wake any waiters
16081609
} catch (...) {
16091610
}
16101611
try {
1611-
(void)_out.close().handle_exception([](std::exception_ptr) {});
1612+
co_await _out.close();
16121613
} catch (...) {
16131614
}
1615+
16141616
// make sure to wait for handshake attempt to leave semaphores. Must be in same order as
16151617
// handshake aqcuire, because in worst case, we get here while a reader is attempting
16161618
// re-handshake.
1617-
return with_semaphore(_in_sem, 1, [this] {
1619+
co_await with_semaphore(_in_sem, 1, [this] {
16181620
return with_semaphore(_out_sem, 1, [] {});
16191621
});
16201622
}
@@ -1659,13 +1661,13 @@ class session : public enable_lw_shared_from_this<session> {
16591661
future<session_data> get_session_resume_data() {
16601662
return state_checked_access([this] {
16611663
/**
1662-
* Session ticket data is not available just because handshake
1664+
* Session ticket data is not available just because handshake
16631665
* was done. First off, of course other part must support it,
16641666
* but we also (mostly?) need to actually transfer data before
16651667
* the ticket is received.
1666-
*
1668+
*
16671669
* Check session flags so we can return no data in the case
1668-
* none is avail. Gnutls returns a 4-byte "empty marker"
1670+
* none is avail. Gnutls returns a 4-byte "empty marker"
16691671
* on none avail.
16701672
*/
16711673
auto flags = gnutls_session_get_flags(*this);
@@ -1681,7 +1683,7 @@ class session : public enable_lw_shared_from_this<session> {
16811683
return state_checked_access([this] {
16821684
return extract_dn_information();
16831685
});
1684-
}
1686+
}
16851687
future<std::vector<subject_alt_name>> get_alt_name_information(std::unordered_set<subject_alt_name_type> types) {
16861688
return state_checked_access([this](std::unordered_set<subject_alt_name_type> types) {
16871689
std::vector<subject_alt_name> res;

0 commit comments

Comments
 (0)