Skip to content

Commit

Permalink
webrtc: fix race condition when starting the UDP muxer
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 20, 2023
1 parent a773e70 commit 82a4196
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
4 changes: 3 additions & 1 deletion p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack
}

l.ctx, l.cancel = context.WithCancel(context.Background())
l.mux = udpmux.NewUDPMux(socket, func(ufrag string, addr net.Addr) bool {
mux := udpmux.NewUDPMux(socket, func(ufrag string, addr net.Addr) bool {
select {
case <-inFlightQueueCh:
// we have space to accept, Yihaa
Expand Down Expand Up @@ -149,6 +149,8 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack

return true
})
l.mux = mux
mux.Start()

return l, err
}
Expand Down
33 changes: 18 additions & 15 deletions p2p/transport/webrtc/udpmux/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var log = logging.Logger("webrtc-udpmux")

const ReceiveMTU = 1500

// udpMux multiplexes multiple ICE connections over a single net.PacketConn,
// UDPMux multiplexes multiple ICE connections over a single net.PacketConn,
// generally a UDP socket.
//
// The connections are indexed by (ufrag, IP address family)
Expand All @@ -33,7 +33,7 @@ const ReceiveMTU = 1500
// is a connection associated with the (ufrag, IP address family) pair. If found
// we add the association to the address map. If not found, it is a previously
// unseen IP address and the `unknownUfragCallback` callback is invoked.
type udpMux struct {
type UDPMux struct {
socket net.PacketConn
unknownUfragCallback func(string, net.Addr) bool

Expand All @@ -45,28 +45,31 @@ type udpMux struct {
cancel context.CancelFunc
}

var _ ice.UDPMux = &udpMux{}
var _ ice.UDPMux = &UDPMux{}

func NewUDPMux(socket net.PacketConn, unknownUfragCallback func(string, net.Addr) bool) *udpMux {
func NewUDPMux(socket net.PacketConn, unknownUfragCallback func(string, net.Addr) bool) *UDPMux {
ctx, cancel := context.WithCancel(context.Background())
mux := &udpMux{
mux := &UDPMux{
ctx: ctx,
cancel: cancel,
socket: socket,
unknownUfragCallback: unknownUfragCallback,
storage: newUDPMuxStorage(),
}

return mux
}

func (mux *UDPMux) Start() {
mux.wg.Add(1)
go func() {
defer mux.wg.Done()
mux.readLoop()
}()
return mux
}

// GetListenAddresses implements ice.UDPMux
func (mux *udpMux) GetListenAddresses() []net.Addr {
func (mux *UDPMux) GetListenAddresses() []net.Addr {
return []net.Addr{mux.socket.LocalAddr()}
}

Expand All @@ -76,7 +79,7 @@ func (mux *udpMux) GetListenAddresses() []net.Addr {
// as a remote is capable of being reachable through multiple different
// UDP addresses of the same IP address family (eg. Server-reflexive addresses
// and peer-reflexive addresses).
func (mux *udpMux) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
func (mux *UDPMux) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
a, ok := addr.(*net.UDPAddr)
if !ok && addr != nil {
return nil, fmt.Errorf("unexpected address type: %T", addr)
Expand All @@ -86,7 +89,7 @@ func (mux *udpMux) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error)
}

// Close implements ice.UDPMux
func (mux *udpMux) Close() error {
func (mux *UDPMux) Close() error {
select {
case <-mux.ctx.Done():
return nil
Expand All @@ -99,13 +102,13 @@ func (mux *udpMux) Close() error {
}

// RemoveConnByUfrag implements ice.UDPMux
func (mux *udpMux) RemoveConnByUfrag(ufrag string) {
func (mux *UDPMux) RemoveConnByUfrag(ufrag string) {
if ufrag != "" {
mux.storage.RemoveConnByUfrag(ufrag)
}
}

func (mux *udpMux) getOrCreateConn(ufrag string, isIPv6 bool, addr net.Addr) (net.PacketConn, error) {
func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, addr net.Addr) (net.PacketConn, error) {
select {
case <-mux.ctx.Done():
return nil, io.ErrClosedPipe
Expand All @@ -116,11 +119,11 @@ func (mux *udpMux) getOrCreateConn(ufrag string, isIPv6 bool, addr net.Addr) (ne
}

// writeTo writes a packet to the underlying net.PacketConn
func (mux *udpMux) writeTo(buf []byte, addr net.Addr) (int, error) {
func (mux *UDPMux) writeTo(buf []byte, addr net.Addr) (int, error) {
return mux.socket.WriteTo(buf, addr)
}

func (mux *udpMux) readLoop() {
func (mux *UDPMux) readLoop() {
for {
select {
case <-mux.ctx.Done():
Expand All @@ -144,7 +147,7 @@ func (mux *udpMux) readLoop() {
}
}

func (mux *udpMux) processPacket(buf []byte, addr net.Addr) (processed bool) {
func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
log.Errorf("received a non-UDP address: %s", addr)
Expand Down Expand Up @@ -250,7 +253,7 @@ func (s *udpMuxStorage) removeConnByUfrag(ufrag string, closeConn bool) {
}
}

func (s *udpMuxStorage) GetOrCreateConn(ufrag string, isIPv6 bool, mux *udpMux, addr net.Addr) (created bool, _ *muxedConnection) {
func (s *udpMuxStorage) GetOrCreateConn(ufrag string, isIPv6 bool, mux *UDPMux, addr net.Addr) (created bool, _ *muxedConnection) {
key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}

s.Lock()
Expand Down
2 changes: 1 addition & 1 deletion p2p/transport/webrtc/udpmux/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, nil
}

func hasConn(m *udpMux, ufrag string, isIPv6 bool) bool {
func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool {
m.storage.Lock()
_, ok := m.storage.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}]
m.storage.Unlock()
Expand Down
4 changes: 2 additions & 2 deletions p2p/transport/webrtc/udpmux/muxed_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ type muxedConnection struct {
onClose func()
pq *packetQueue
addr net.Addr
mux *udpMux
mux *UDPMux
}

var _ net.PacketConn = (*muxedConnection)(nil)

func newMuxedConnection(mux *udpMux, onClose func(), addr net.Addr) *muxedConnection {
func newMuxedConnection(mux *UDPMux, onClose func(), addr net.Addr) *muxedConnection {
ctx, cancel := context.WithCancel(mux.ctx)
return &muxedConnection{
ctx: ctx,
Expand Down

0 comments on commit 82a4196

Please sign in to comment.