diff --git a/gather.go b/gather.go index 15536ce4..7469424b 100644 --- a/gather.go +++ b/gather.go @@ -16,7 +16,9 @@ import ( ) const ( - stunGatherTimeout = time.Second * 5 + stunGatherTimeout = time.Second * 5 + maxIPv4LookupsPerRelay = 5 + maxIPv6LookupsPerRelay = 5 ) type closeable interface { @@ -377,7 +379,7 @@ func (a *Agent) gatherCandidatesSrflxUDPMux(ctx context.Context, urls []*URL, ne go func(url URL, network string, isIPv6 bool) { defer wg.Done() - hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) + hostPort := url.HostPortString() serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) if err != nil { a.log.Warnf("failed to resolve stun host: %s: %v", hostPort, err) @@ -444,7 +446,7 @@ func (a *Agent) gatherCandidatesSrflx(ctx context.Context, urls []*URL, networkT go func(url URL, network string) { defer wg.Done() - hostPort := fmt.Sprintf("%s:%d", url.Host, url.Port) + hostPort := url.HostPortString() serverAddr, err := a.net.ResolveUDPAddr(network, hostPort) if err != nil { a.log.Warnf("failed to resolve stun host: %s: %v", hostPort, err) @@ -508,7 +510,6 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli var wg sync.WaitGroup defer wg.Wait() - network := NetworkTypeUDP4.String() for i := range urls { switch { case urls[i].Scheme != SchemeTypeTURN && urls[i].Scheme != SchemeTypeTURNS: @@ -521,11 +522,17 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli return } - wg.Add(1) - go func(url URL) { + generateCandidate := func(url URL, ipv6 bool) { defer wg.Done() - TURNServerAddr := fmt.Sprintf("%s:%d", url.Host, url.Port) + var ( + TURNServerAddr = url.HostPortString() + + udpNetworkType string + tcpNetworkType string + localAddress string + network string + locConn net.PacketConn err error RelAddr string @@ -533,9 +540,20 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli relayProtocol string ) + if ipv6 { + udpNetworkType = NetworkTypeUDP6.String() + tcpNetworkType = NetworkTypeTCP6.String() + localAddress = ":" + } else { + udpNetworkType = NetworkTypeUDP4.String() + tcpNetworkType = NetworkTypeTCP4.String() + localAddress = "0.0.0.0:0" + } + switch { case url.Proto == ProtoTypeUDP && url.Scheme == SchemeTypeTURN: - if locConn, err = a.net.ListenPacket(network, "0.0.0.0:0"); err != nil { + network = udpNetworkType + if locConn, err = a.net.ListenPacket(network, localAddress); err != nil { a.log.Warnf("Failed to listen %s: %v", network, err) return } @@ -545,7 +563,8 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli relayProtocol = udp case a.proxyDialer != nil && url.Proto == ProtoTypeTCP && (url.Scheme == SchemeTypeTURN || url.Scheme == SchemeTypeTURNS): - conn, connectErr := a.proxyDialer.Dial(NetworkTypeTCP4.String(), TURNServerAddr) + network = tcpNetworkType + conn, connectErr := a.proxyDialer.Dial(network, TURNServerAddr) if connectErr != nil { a.log.Warnf("Failed to Dial TCP Addr %s via proxy dialer: %v", TURNServerAddr, connectErr) return @@ -561,13 +580,14 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli locConn = turn.NewSTUNConn(conn) case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURN: - tcpAddr, connectErr := net.ResolveTCPAddr(NetworkTypeTCP4.String(), TURNServerAddr) + network = tcpNetworkType + tcpAddr, connectErr := net.ResolveTCPAddr(network, TURNServerAddr) if connectErr != nil { a.log.Warnf("Failed to resolve TCP Addr %s: %v", TURNServerAddr, connectErr) return } - conn, connectErr := net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr) + conn, connectErr := net.DialTCP(network, nil, tcpAddr) if connectErr != nil { a.log.Warnf("Failed to Dial TCP Addr %s: %v", TURNServerAddr, connectErr) return @@ -578,6 +598,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli relayProtocol = tcp locConn = turn.NewSTUNConn(conn) case url.Proto == ProtoTypeUDP && url.Scheme == SchemeTypeTURNS: + network = udpNetworkType udpAddr, connectErr := net.ResolveUDPAddr(network, TURNServerAddr) if connectErr != nil { a.log.Warnf("Failed to resolve UDP Addr %s: %v", TURNServerAddr, connectErr) @@ -598,7 +619,8 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli relayProtocol = "dtls" locConn = &fakePacketConn{conn} case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURNS: - conn, connectErr := tls.Dial(NetworkTypeTCP4.String(), TURNServerAddr, &tls.Config{ + network = tcpNetworkType + conn, connectErr := tls.Dial(network, TURNServerAddr, &tls.Config{ InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec }) if connectErr != nil { @@ -621,6 +643,7 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli Password: url.Password, LoggerFactory: a.loggerFactory, Net: a.net, + IPv6: ipv6, }) if err != nil { closeConnAndLog(locConn, a.log, fmt.Sprintf("Failed to build new turn.Client %s %s", TURNServerAddr, err)) @@ -676,6 +699,44 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli } a.log.Warnf("Failed to append to localCandidates and run onCandidateHdlr: %v", err) } - }(*urls[i]) + } + + url := *urls[i] + + var targetIPs []net.IP + var err error + if hostIP := net.ParseIP(url.Host); hostIP != nil { + // literal IP provided + targetIPs = []net.IP{hostIP} + } else { + // hostname provided, perform DNS lookup + targetIPs, err = net.LookupIP(url.Host) + if err != nil { + a.log.Warnf("Failed to lookup host IPs: %v", err) + continue + } + } + + ipv4Lookups := 0 + ipv6Lookups := 0 + for _, ip := range targetIPs { + if ipv4 := ip.To4(); ipv4 != nil && ipv4Lookups < maxIPv4LookupsPerRelay { + ipv4Lookups += 1 + var urlCopy URL = url + urlCopy.Host = ipv4.String() + + wg.Add(1) + go generateCandidate(urlCopy, false) + } else if ipv6 := ip.To16(); ipv6 != nil && ipv6Lookups < maxIPv6LookupsPerRelay { + ipv6Lookups += 1 + var urlCopy URL = url + urlCopy.Host = ipv6.String() + + wg.Add(1) + go generateCandidate(urlCopy, true) + } else if ipv6Lookups >= maxIPv6LookupsPerRelay && ipv4Lookups >= maxIPv4LookupsPerRelay { + break + } + } } } diff --git a/url.go b/url.go index 33082cd6..f6b39584 100644 --- a/url.go +++ b/url.go @@ -213,6 +213,12 @@ func parseProto(raw string) (ProtoType, error) { return proto, nil } +// HostPortString returns a string in the format "host:port" or "[host]:port" if +// a literal IPv6 address is given +func (u URL) HostPortString() string { + return net.JoinHostPort(u.Host, strconv.Itoa(u.Port)) +} + func (u URL) String() string { rawURL := u.Scheme.String() + ":" + net.JoinHostPort(u.Host, strconv.Itoa(u.Port)) if u.Scheme == SchemeTypeTURN || u.Scheme == SchemeTypeTURNS {