Skip to content

Commit f9c2ac2

Browse files
committed
fix(endpoint): use custom resolver if specified for [IP] tests
1 parent d668a14 commit f9c2ac2

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

client/config.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ type Config struct {
6060
// Expected format is {protocol}://{host}:{port}, e.g. tcp://8.8.8.8:53
6161
DNSResolver string `yaml:"dns-resolver,omitempty"`
6262

63+
DNSResolverConfig *DNSResolverConfig `yaml:"-"`
64+
6365
// OAuth2Config is the OAuth2 configuration used for the client.
6466
//
6567
// If non-nil, the http.Client returned by getHTTPClient will automatically retrieve a token if necessary.
@@ -116,8 +118,10 @@ func (c *Config) ValidateAndSetDefaults() error {
116118
}
117119
if c.HasCustomDNSResolver() {
118120
// Validate the DNS resolver now to make sure it will not return an error later.
119-
if _, err := c.parseDNSResolver(); err != nil {
121+
if resolver, err := c.parseDNSResolver(); err != nil {
120122
return err
123+
} else {
124+
c.DNSResolverConfig = resolver
121125
}
122126
}
123127
if c.HasOAuth2Config() && !c.OAuth2Config.isValid() {

config/endpoint/endpoint.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package endpoint
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/x509"
67
"encoding/json"
78
"errors"
@@ -432,11 +433,28 @@ func (e *Endpoint) getParsedBody() string {
432433
}
433434

434435
func (e *Endpoint) getIP(result *Result) {
435-
if ips, err := net.LookupIP(result.Hostname); err != nil {
436+
437+
resolver := net.DefaultResolver
438+
439+
if e.ClientConfig.HasCustomDNSResolver() {
440+
dnsResolver := e.ClientConfig.DNSResolverConfig
441+
resolver = &net.Resolver{
442+
PreferGo: true,
443+
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
444+
d := net.Dialer{}
445+
return d.DialContext(ctx, dnsResolver.Protocol, dnsResolver.Host+":"+dnsResolver.Port)
446+
},
447+
}
448+
}
449+
450+
addrs, err := resolver.LookupIP(context.Background(), e.ClientConfig.Network, result.Hostname)
451+
if err != nil {
436452
result.AddError(err.Error())
437453
return
438-
} else {
439-
result.IP = ips[0].String()
454+
}
455+
for _, ia := range addrs {
456+
result.IP = ia.String()
457+
return
440458
}
441459
}
442460

0 commit comments

Comments
 (0)