From 82a4196bc7fb4ee996d1ecc9528d42f971a9dcd5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 20 Aug 2023 08:48:59 +0700 Subject: [PATCH] webrtc: fix race condition when starting the UDP muxer --- p2p/transport/webrtc/listener.go | 4 ++- p2p/transport/webrtc/udpmux/mux.go | 33 ++++++++++--------- p2p/transport/webrtc/udpmux/mux_test.go | 2 +- .../webrtc/udpmux/muxed_connection.go | 4 +-- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 1caa5592c8..408698c985 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -111,7 +111,7 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack } l.ctx, l.cancel = context.WithCancel(context.Background()) - l.mux = udpmux.NewUDPMux(socket, func(ufrag string, addr net.Addr) bool { + mux := udpmux.NewUDPMux(socket, func(ufrag string, addr net.Addr) bool { select { case <-inFlightQueueCh: // we have space to accept, Yihaa @@ -149,6 +149,8 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack return true }) + l.mux = mux + mux.Start() return l, err } diff --git a/p2p/transport/webrtc/udpmux/mux.go b/p2p/transport/webrtc/udpmux/mux.go index 08b105a704..cfbfe26638 100644 --- a/p2p/transport/webrtc/udpmux/mux.go +++ b/p2p/transport/webrtc/udpmux/mux.go @@ -18,7 +18,7 @@ var log = logging.Logger("webrtc-udpmux") const ReceiveMTU = 1500 -// udpMux multiplexes multiple ICE connections over a single net.PacketConn, +// UDPMux multiplexes multiple ICE connections over a single net.PacketConn, // generally a UDP socket. // // The connections are indexed by (ufrag, IP address family) @@ -33,7 +33,7 @@ const ReceiveMTU = 1500 // is a connection associated with the (ufrag, IP address family) pair. If found // we add the association to the address map. If not found, it is a previously // unseen IP address and the `unknownUfragCallback` callback is invoked. -type udpMux struct { +type UDPMux struct { socket net.PacketConn unknownUfragCallback func(string, net.Addr) bool @@ -45,11 +45,11 @@ type udpMux struct { cancel context.CancelFunc } -var _ ice.UDPMux = &udpMux{} +var _ ice.UDPMux = &UDPMux{} -func NewUDPMux(socket net.PacketConn, unknownUfragCallback func(string, net.Addr) bool) *udpMux { +func NewUDPMux(socket net.PacketConn, unknownUfragCallback func(string, net.Addr) bool) *UDPMux { ctx, cancel := context.WithCancel(context.Background()) - mux := &udpMux{ + mux := &UDPMux{ ctx: ctx, cancel: cancel, socket: socket, @@ -57,16 +57,19 @@ func NewUDPMux(socket net.PacketConn, unknownUfragCallback func(string, net.Addr storage: newUDPMuxStorage(), } + return mux +} + +func (mux *UDPMux) Start() { mux.wg.Add(1) go func() { defer mux.wg.Done() mux.readLoop() }() - return mux } // GetListenAddresses implements ice.UDPMux -func (mux *udpMux) GetListenAddresses() []net.Addr { +func (mux *UDPMux) GetListenAddresses() []net.Addr { return []net.Addr{mux.socket.LocalAddr()} } @@ -76,7 +79,7 @@ func (mux *udpMux) GetListenAddresses() []net.Addr { // as a remote is capable of being reachable through multiple different // UDP addresses of the same IP address family (eg. Server-reflexive addresses // and peer-reflexive addresses). -func (mux *udpMux) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { +func (mux *UDPMux) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { a, ok := addr.(*net.UDPAddr) if !ok && addr != nil { return nil, fmt.Errorf("unexpected address type: %T", addr) @@ -86,7 +89,7 @@ func (mux *udpMux) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) } // Close implements ice.UDPMux -func (mux *udpMux) Close() error { +func (mux *UDPMux) Close() error { select { case <-mux.ctx.Done(): return nil @@ -99,13 +102,13 @@ func (mux *udpMux) Close() error { } // RemoveConnByUfrag implements ice.UDPMux -func (mux *udpMux) RemoveConnByUfrag(ufrag string) { +func (mux *UDPMux) RemoveConnByUfrag(ufrag string) { if ufrag != "" { mux.storage.RemoveConnByUfrag(ufrag) } } -func (mux *udpMux) getOrCreateConn(ufrag string, isIPv6 bool, addr net.Addr) (net.PacketConn, error) { +func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, addr net.Addr) (net.PacketConn, error) { select { case <-mux.ctx.Done(): return nil, io.ErrClosedPipe @@ -116,11 +119,11 @@ func (mux *udpMux) getOrCreateConn(ufrag string, isIPv6 bool, addr net.Addr) (ne } // writeTo writes a packet to the underlying net.PacketConn -func (mux *udpMux) writeTo(buf []byte, addr net.Addr) (int, error) { +func (mux *UDPMux) writeTo(buf []byte, addr net.Addr) (int, error) { return mux.socket.WriteTo(buf, addr) } -func (mux *udpMux) readLoop() { +func (mux *UDPMux) readLoop() { for { select { case <-mux.ctx.Done(): @@ -144,7 +147,7 @@ func (mux *udpMux) readLoop() { } } -func (mux *udpMux) processPacket(buf []byte, addr net.Addr) (processed bool) { +func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) { udpAddr, ok := addr.(*net.UDPAddr) if !ok { log.Errorf("received a non-UDP address: %s", addr) @@ -250,7 +253,7 @@ func (s *udpMuxStorage) removeConnByUfrag(ufrag string, closeConn bool) { } } -func (s *udpMuxStorage) GetOrCreateConn(ufrag string, isIPv6 bool, mux *udpMux, addr net.Addr) (created bool, _ *muxedConnection) { +func (s *udpMuxStorage) GetOrCreateConn(ufrag string, isIPv6 bool, mux *UDPMux, addr net.Addr) (created bool, _ *muxedConnection) { key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6} s.Lock() diff --git a/p2p/transport/webrtc/udpmux/mux_test.go b/p2p/transport/webrtc/udpmux/mux_test.go index f6dc0e2aa3..498be61dcc 100644 --- a/p2p/transport/webrtc/udpmux/mux_test.go +++ b/p2p/transport/webrtc/udpmux/mux_test.go @@ -47,7 +47,7 @@ func (dummyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, nil } -func hasConn(m *udpMux, ufrag string, isIPv6 bool) bool { +func hasConn(m *UDPMux, ufrag string, isIPv6 bool) bool { m.storage.Lock() _, ok := m.storage.ufragMap[ufragConnKey{ufrag: ufrag, isIPv6: isIPv6}] m.storage.Unlock() diff --git a/p2p/transport/webrtc/udpmux/muxed_connection.go b/p2p/transport/webrtc/udpmux/muxed_connection.go index dbfc5d6b75..d51e846b2c 100644 --- a/p2p/transport/webrtc/udpmux/muxed_connection.go +++ b/p2p/transport/webrtc/udpmux/muxed_connection.go @@ -18,12 +18,12 @@ type muxedConnection struct { onClose func() pq *packetQueue addr net.Addr - mux *udpMux + mux *UDPMux } var _ net.PacketConn = (*muxedConnection)(nil) -func newMuxedConnection(mux *udpMux, onClose func(), addr net.Addr) *muxedConnection { +func newMuxedConnection(mux *UDPMux, onClose func(), addr net.Addr) *muxedConnection { ctx, cancel := context.WithCancel(mux.ctx) return &muxedConnection{ ctx: ctx,