Skip to content

Commit

Permalink
Fix usrsctp usage in Rust
Browse files Browse the repository at this point in the history
**WIP**

Fixes #1352

### Details

- Basically as described in the ticket. But not everything is done at all.
- Also, I'm testing this in Node by using UV async stuff (which doesn't make sense in mediasoup for Node but anyway).

### TODO

- None of these changes should take effect when in Node, so we need to pass (or to NOT pass) some `define` only from Rust to enable this in the C++ code. We don't want to deal with UV async stuff when in Node because it's not needed at all, so let's see how to do it.

- Missing thread X to initialize usrsctp and run the `Checker` singleton. And many other things.

- Crash when a `SctpAssociation` is closed. I think it's because somehow the `onAsync` callback is invoked asynchronously (of course) so when it calls `sctpAssociation->OnUsrSctpSendSctpData()` it happens that such a `SctpAssociation` has already been freed. Not sure how to resolve it. Here the logs:
  ```
  mediasoup:Transport close() +18s
  mediasoup:Channel request() [method:ROUTER_CLOSE_TRANSPORT] +8s
  mediasoup:Producer transportClosed() +19s
  mediasoup:DataProducer transportClosed() +18s
  mediasoup:DataProducer transportClosed() +0ms
  mediasoup:Transport close() +1ms
  mediasoup:Channel request() [method:ROUTER_CLOSE_TRANSPORT] +1ms
  mediasoup:Consumer transportClosed() +19s
  mediasoup:DataConsumer transportClosed() +18s
  mediasoup:DataConsumer transportClosed() +1ms
  mediasoup:Channel [pid:98040] RTC::SctpAssociation::ResetSctpStream() | SCTP_RESET_STREAMS sent [streamId:1] +1ms
  mediasoup:Channel request succeeded [method:ROUTER_CLOSE_TRANSPORT, id:39] +0ms
  DepUsrSCTP::onAsync() | ---------- onAsync!!
  DepUsrSCTP::onAsync() | ---------- onAsync, sending SCTP data!!
  mediasoup:Channel Producer Channel ended by the worker process +1ms
  mediasoup:ERROR:Worker worker process died unexpectedly [pid:98040, code:null, signal:SIGSEGV] +0ms
  ```
  • Loading branch information
ibc committed Mar 5, 2024
1 parent af7fb8c commit b85120f
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 12 deletions.
11 changes: 11 additions & 0 deletions worker/include/DepUsrSCTP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@

class DepUsrSCTP
{
public:
struct SendSctpDataStore
{
RTC::SctpAssociation* sctpAssociation;
uint8_t* data;
size_t len;
};

private:
class Checker : public TimerHandle::Listener
{
Expand Down Expand Up @@ -37,12 +45,15 @@ class DepUsrSCTP
static void RegisterSctpAssociation(RTC::SctpAssociation* sctpAssociation);
static void DeregisterSctpAssociation(RTC::SctpAssociation* sctpAssociation);
static RTC::SctpAssociation* RetrieveSctpAssociation(uintptr_t id);
static void SendSctpData(RTC::SctpAssociation* sctpAssociation, uint8_t* data, size_t len);
static SendSctpDataStore* GetSendSctpDataStore(uv_async_t* handle);

private:
thread_local static Checker* checker;
static uint64_t numSctpAssociations;
static uintptr_t nextSctpAssociationId;
static absl::flat_hash_map<uintptr_t, RTC::SctpAssociation*> mapIdSctpAssociation;
static absl::flat_hash_map<const uv_async_t*, SendSctpDataStore> mapAsyncHandlerSendSctpData;
};

#endif
11 changes: 9 additions & 2 deletions worker/include/RTC/SctpAssociation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "RTC/DataConsumer.hpp"
#include "RTC/DataProducer.hpp"
#include <usrsctp.h>
#include <uv.h>

namespace RTC
{
Expand Down Expand Up @@ -80,7 +81,11 @@ namespace RTC
public:
flatbuffers::Offset<FBS::SctpParameters::SctpParameters> FillBuffer(
flatbuffers::FlatBufferBuilder& builder) const;
void TransportConnected();
uv_async_t* GetAsyncHandle() const
{
return this->uvAsyncHandle;
}
void InitializeSyncHandle(uv_async_cb callback);
SctpState GetState() const
{
return this->state;
Expand All @@ -89,6 +94,7 @@ namespace RTC
{
return this->sctpBufferedAmount;
}
void TransportConnected();
void ProcessSctpData(const uint8_t* data, size_t len) const;
void SendSctpMessage(
RTC::DataConsumer* dataConsumer,
Expand All @@ -106,7 +112,7 @@ namespace RTC

/* Callbacks fired by usrsctp events. */
public:
void OnUsrSctpSendSctpData(void* buffer, size_t len);
void OnUsrSctpSendSctpData(uint8_t* data, size_t len);
void OnUsrSctpReceiveSctpData(
uint16_t streamId, uint16_t ssn, uint32_t ppid, int flags, const uint8_t* data, size_t len);
void OnUsrSctpReceiveSctpNotification(union sctp_notification* notification, size_t len);
Expand All @@ -125,6 +131,7 @@ namespace RTC
size_t sctpBufferedAmount{ 0u };
bool isDataChannel{ false };
// Allocated by this.
uv_async_t* uvAsyncHandle{ nullptr };
uint8_t* messageBuffer{ nullptr };
// Others.
SctpState state{ SctpState::NEW };
Expand Down
107 changes: 100 additions & 7 deletions worker/src/DepUsrSCTP.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#define MS_CLASS "DepUsrSCTP"
// #define MS_LOG_DEV_LEVEL 3
#define MS_LOG_DEV_LEVEL 3

#include "DepUsrSCTP.hpp"
#ifdef MS_LIBURING_SUPPORTED
Expand All @@ -8,7 +8,8 @@
#include "DepLibUV.hpp"
#include "Logger.hpp"
#include <usrsctp.h>
#include <cstdio> // std::vsnprintf()
#include <cstdio> // std::vsnprintf()
#include <cstring> // std::memcpy()
#include <mutex>

/* Static. */
Expand All @@ -17,10 +18,40 @@ static constexpr size_t CheckerInterval{ 10u }; // In ms.
static std::mutex GlobalSyncMutex;
static size_t GlobalInstances{ 0u };

/* Static methods for UV callbacks. */

inline static void onAsync(uv_async_t* handle)
{
MS_TRACE();
MS_DUMP("---------- onAsync!!");

const std::lock_guard<std::mutex> lock(GlobalSyncMutex);

// Get the sending data from the map.
auto* store = DepUsrSCTP::GetSendSctpDataStore(handle);

if (!store)
{
MS_WARN_DEV("store not found");

return;
}

auto* sctpAssociation = store->sctpAssociation;
auto* data = store->data;
auto len = store->len;

MS_DUMP("---------- onAsync, sending SCTP data!!");

sctpAssociation->OnUsrSctpSendSctpData(data, len);
}

/* Static methods for usrsctp global callbacks. */

inline static int onSendSctpData(void* addr, void* data, size_t len, uint8_t /*tos*/, uint8_t /*setDf*/)
{
MS_TRACE();

auto* sctpAssociation = DepUsrSCTP::RetrieveSctpAssociation(reinterpret_cast<uintptr_t>(addr));

if (!sctpAssociation)
Expand All @@ -30,7 +61,7 @@ inline static int onSendSctpData(void* addr, void* data, size_t len, uint8_t /*t
return -1;
}

sctpAssociation->OnUsrSctpSendSctpData(data, len);
DepUsrSCTP::SendSctpData(sctpAssociation, static_cast<uint8_t*>(data), len);

// NOTE: Must not free data, usrsctp lib does it.

Expand Down Expand Up @@ -60,6 +91,7 @@ thread_local DepUsrSCTP::Checker* DepUsrSCTP::checker{ nullptr };
uint64_t DepUsrSCTP::numSctpAssociations{ 0u };
uintptr_t DepUsrSCTP::nextSctpAssociationId{ 0u };
absl::flat_hash_map<uintptr_t, RTC::SctpAssociation*> DepUsrSCTP::mapIdSctpAssociation;
absl::flat_hash_map<const uv_async_t*, DepUsrSCTP::SendSctpDataStore> DepUsrSCTP::mapAsyncHandlerSendSctpData;

/* Static methods. */

Expand Down Expand Up @@ -91,6 +123,7 @@ void DepUsrSCTP::ClassDestroy()
MS_TRACE();

const std::lock_guard<std::mutex> lock(GlobalSyncMutex);

--GlobalInstances;

if (GlobalInstances == 0)
Expand All @@ -101,6 +134,7 @@ void DepUsrSCTP::ClassDestroy()
nextSctpAssociationId = 0u;

DepUsrSCTP::mapIdSctpAssociation.clear();
DepUsrSCTP::mapAsyncHandlerSendSctpData.clear();
}
}

Expand Down Expand Up @@ -158,13 +192,20 @@ void DepUsrSCTP::RegisterSctpAssociation(RTC::SctpAssociation* sctpAssociation)

MS_ASSERT(DepUsrSCTP::checker != nullptr, "Checker not created");

auto it = DepUsrSCTP::mapIdSctpAssociation.find(sctpAssociation->id);
auto it = DepUsrSCTP::mapIdSctpAssociation.find(sctpAssociation->id);
auto it2 = DepUsrSCTP::mapAsyncHandlerSendSctpData.find(sctpAssociation->GetAsyncHandle());

MS_ASSERT(
it == DepUsrSCTP::mapIdSctpAssociation.end(),
"the id of the SctpAssociation is already in the map");
"the id of the SctpAssociation is already in the mapIdSctpAssociation map");
MS_ASSERT(
it2 == DepUsrSCTP::mapAsyncHandlerSendSctpData.end(),
"the id of the SctpAssociation is already in the mapAsyncHandlerSendSctpData map");

DepUsrSCTP::mapIdSctpAssociation[sctpAssociation->id] = sctpAssociation;
DepUsrSCTP::mapAsyncHandlerSendSctpData[sctpAssociation->GetAsyncHandle()];

sctpAssociation->InitializeSyncHandle(onAsync);

if (++DepUsrSCTP::numSctpAssociations == 1u)
{
Expand All @@ -180,9 +221,11 @@ void DepUsrSCTP::DeregisterSctpAssociation(RTC::SctpAssociation* sctpAssociation

MS_ASSERT(DepUsrSCTP::checker != nullptr, "Checker not created");

auto found = DepUsrSCTP::mapIdSctpAssociation.erase(sctpAssociation->id);
auto found1 = DepUsrSCTP::mapIdSctpAssociation.erase(sctpAssociation->id);
auto found2 = DepUsrSCTP::mapAsyncHandlerSendSctpData.erase(sctpAssociation->GetAsyncHandle());

MS_ASSERT(found > 0, "SctpAssociation not found");
MS_ASSERT(found1 > 0, "SctpAssociation not found in mapIdSctpAssociation map");
MS_ASSERT(found2 > 0, "SctpAssociation not found in mapAsyncHandlerSendSctpData map");
MS_ASSERT(DepUsrSCTP::numSctpAssociations > 0u, "numSctpAssociations was not higher than 0");

if (--DepUsrSCTP::numSctpAssociations == 0u)
Expand All @@ -207,6 +250,56 @@ RTC::SctpAssociation* DepUsrSCTP::RetrieveSctpAssociation(uintptr_t id)
return it->second;
}

void DepUsrSCTP::SendSctpData(RTC::SctpAssociation* sctpAssociation, uint8_t* data, size_t len)
{
MS_TRACE();

const std::lock_guard<std::mutex> lock(GlobalSyncMutex);

// Store the sending data into the map.

auto it = DepUsrSCTP::mapAsyncHandlerSendSctpData.find(sctpAssociation->GetAsyncHandle());

MS_ASSERT(
it != DepUsrSCTP::mapAsyncHandlerSendSctpData.end(),
"SctpAssociation not found in mapAsyncHandlerSendSctpData map");

SendSctpDataStore& store = it->second;

// NOTE: In Rust, DepUsrSCTP::SendSctpData() is called from onSendSctpData()
// callback from a different thread and usrsctp immediately frees |data| when
// the callback execution finishes. So we have to mem copy it.
store.sctpAssociation = sctpAssociation;
store.data = new uint8_t[len];
store.len = len;

std::memcpy(store.data, data, len);

// Invoke UV async send.
int err = uv_async_send(sctpAssociation->GetAsyncHandle());

if (err != 0)
{
MS_WARN_TAG(sctp, "uv_async_send() failed: %s", uv_strerror(err));
}
}

DepUsrSCTP::SendSctpDataStore* DepUsrSCTP::GetSendSctpDataStore(uv_async_t* handle)
{
MS_TRACE();

auto it = DepUsrSCTP::mapAsyncHandlerSendSctpData.find(handle);

if (it == DepUsrSCTP::mapAsyncHandlerSendSctpData.end())
{
return nullptr;
}

SendSctpDataStore& store = it->second;

return std::addressof(store);
}

/* DepUsrSCTP::Checker instance methods. */

DepUsrSCTP::Checker::Checker() : timer(new TimerHandle(this))
Expand Down
21 changes: 18 additions & 3 deletions worker/src/RTC/SctpAssociation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// #define MS_LOG_DEV_LEVEL 3

#include "RTC/SctpAssociation.hpp"
#include "DepLibUV.hpp"
#include "DepUsrSCTP.hpp"
#include "Logger.hpp"
#include "MediaSoupErrors.hpp"
Expand Down Expand Up @@ -121,6 +122,9 @@ namespace RTC
{
MS_TRACE();

// Create a uv_async_t handle.
this->uvAsyncHandle = new uv_async_t;

// Register ourselves in usrsctp.
// NOTE: This must be done before calling usrsctp_bind().
usrsctp_register_address(reinterpret_cast<void*>(this->id));
Expand Down Expand Up @@ -293,6 +297,7 @@ namespace RTC
// Register the SctpAssociation from the global map.
DepUsrSCTP::DeregisterSctpAssociation(this);

delete this->uvAsyncHandle;
delete[] this->messageBuffer;
}

Expand Down Expand Up @@ -381,6 +386,18 @@ namespace RTC
this->isDataChannel);
}

void SctpAssociation::InitializeSyncHandle(uv_async_cb callback)
{
MS_TRACE();

int err = uv_async_init(DepLibUV::GetLoop(), this->uvAsyncHandle, callback);

if (err != 0)
{
MS_ABORT("uv_async_init() failed: %s", uv_strerror(err));
}
}

void SctpAssociation::ProcessSctpData(const uint8_t* data, size_t len) const
{
MS_TRACE();
Expand Down Expand Up @@ -667,12 +684,10 @@ namespace RTC
}
}

void SctpAssociation::OnUsrSctpSendSctpData(void* buffer, size_t len)
void SctpAssociation::OnUsrSctpSendSctpData(uint8_t* data, size_t len)
{
MS_TRACE();

const uint8_t* data = static_cast<uint8_t*>(buffer);

#if MS_LOG_DEV_LEVEL == 3
MS_DUMP_DATA(data, len);
#endif
Expand Down

0 comments on commit b85120f

Please sign in to comment.