diff --git a/p2p/transport/webrtc/udpmux/mux.go b/p2p/transport/webrtc/udpmux/mux.go index 64b9eb5212..4dd0bf78c2 100644 --- a/p2p/transport/webrtc/udpmux/mux.go +++ b/p2p/transport/webrtc/udpmux/mux.go @@ -40,8 +40,11 @@ type Candidate struct { type UDPMux struct { socket net.PacketConn - storage *udpMuxStorage - queue chan Candidate + queue chan Candidate + + mx sync.Mutex + ufragMap map[ufragConnKey]*muxedConnection + addrMap map[string]*muxedConnection // the context controls the lifecycle of the mux wg sync.WaitGroup @@ -54,11 +57,12 @@ var _ ice.UDPMux = &UDPMux{} func NewUDPMux(socket net.PacketConn) *UDPMux { ctx, cancel := context.WithCancel(context.Background()) mux := &UDPMux{ - ctx: ctx, - cancel: cancel, - socket: socket, - storage: newUDPMuxStorage(), - queue: make(chan Candidate, 32), + ctx: ctx, + cancel: cancel, + socket: socket, + ufragMap: make(map[ufragConnKey]*muxedConnection), + addrMap: make(map[string]*muxedConnection), + queue: make(chan Candidate, 32), } return mux @@ -86,8 +90,14 @@ func (mux *UDPMux) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) if !ok { return nil, fmt.Errorf("unexpected address type: %T", addr) } - isIPv6 := ok && a.IP.To4() == nil - return mux.getOrCreateConn(ufrag, isIPv6, addr) + select { + case <-mux.ctx.Done(): + return nil, io.ErrClosedPipe + default: + isIPv6 := ok && a.IP.To4() == nil + _, conn := mux.getOrCreateConn(ufrag, isIPv6, mux, addr) + return conn, nil + } } // Close implements ice.UDPMux @@ -103,23 +113,6 @@ func (mux *UDPMux) Close() error { return nil } -// RemoveConnByUfrag implements ice.UDPMux -func (mux *UDPMux) RemoveConnByUfrag(ufrag string) { - if ufrag != "" { - mux.storage.RemoveConnByUfrag(ufrag) - } -} - -func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, addr net.Addr) (net.PacketConn, error) { - select { - case <-mux.ctx.Done(): - return nil, io.ErrClosedPipe - default: - _, conn := mux.storage.GetOrCreateConn(ufrag, isIPv6, mux, addr) - return conn, nil - } -} - // writeTo writes a packet to the underlying net.PacketConn func (mux *UDPMux) writeTo(buf []byte, addr net.Addr) (int, error) { return mux.socket.WriteTo(buf, addr) @@ -160,7 +153,10 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) { // Connections are indexed by remote address. We first // check if the remote address has a connection associated // with it. If yes, we push the received packet to the connection - if conn, ok := mux.storage.GetConnByAddr(udpAddr); ok { + mux.mx.Lock() + conn, ok := mux.addrMap[addr.String()] + mux.mx.Unlock() + if ok { if err := conn.Push(buf); err != nil { log.Debugf("could not push packet: %v", err) return false @@ -189,7 +185,7 @@ func (mux *UDPMux) processPacket(buf []byte, addr net.Addr) (processed bool) { return false } - connCreated, conn := mux.storage.GetOrCreateConn(ufrag, isIPv6, mux, udpAddr) + connCreated, conn := mux.getOrCreateConn(ufrag, isIPv6, mux, udpAddr) if connCreated { select { case mux.queue <- Candidate{Addr: udpAddr, Ufrag: ufrag}: @@ -244,53 +240,36 @@ func ufragFromSTUNMessage(msg *stun.Message) (string, error) { return string(attr[index+1:]), nil } -type udpMuxStorage struct { - sync.Mutex - - ufragMap map[ufragConnKey]*muxedConnection - addrMap map[string]*muxedConnection -} - -func newUDPMuxStorage() *udpMuxStorage { - return &udpMuxStorage{ - ufragMap: make(map[ufragConnKey]*muxedConnection), - addrMap: make(map[string]*muxedConnection), +func (mux *UDPMux) RemoveConnByUfrag(ufrag string) { + if ufrag == "" { + return } -} -func (s *udpMuxStorage) RemoveConnByUfrag(ufrag string) { - s.Lock() - defer s.Unlock() + mux.mx.Lock() + defer mux.mx.Unlock() for _, isIPv6 := range [...]bool{true, false} { key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6} - if conn, ok := s.ufragMap[key]; ok { - delete(s.ufragMap, key) - delete(s.addrMap, conn.RemoteAddr().String()) + if conn, ok := mux.ufragMap[key]; ok { + delete(mux.ufragMap, key) + delete(mux.addrMap, conn.RemoteAddr().String()) } } } -func (s *udpMuxStorage) GetOrCreateConn(ufrag string, isIPv6 bool, mux *UDPMux, addr net.Addr) (created bool, _ *muxedConnection) { +func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, _ *UDPMux, addr net.Addr) (created bool, _ *muxedConnection) { key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6} - s.Lock() - defer s.Unlock() + mux.mx.Lock() + defer mux.mx.Unlock() - if conn, ok := s.ufragMap[key]; ok { + if conn, ok := mux.ufragMap[key]; ok { return false, conn } - conn := newMuxedConnection(mux, func() { s.RemoveConnByUfrag(ufrag) }, addr) - s.ufragMap[key] = conn - s.addrMap[addr.String()] = conn + conn := newMuxedConnection(mux, func() { mux.RemoveConnByUfrag(ufrag) }, addr) + mux.ufragMap[key] = conn + mux.addrMap[addr.String()] = conn return true, conn } - -func (s *udpMuxStorage) GetConnByAddr(addr *net.UDPAddr) (*muxedConnection, bool) { - s.Lock() - conn, ok := s.addrMap[addr.String()] - s.Unlock() - return conn, ok -}