Skip to content

Commit

Permalink
Refactor and add support for the client-subnet option
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Feb 9, 2024
1 parent 27e217b commit 6a377c9
Show file tree
Hide file tree
Showing 14 changed files with 222 additions and 150 deletions.
4 changes: 4 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ func (c *Client) ExchangeWithResponseCheck(ctx context.Context, transport Transp
return nil, E.New("DNS query loopback in transport[", contextTransport, "]")
}
ctx = contextWithTransportName(ctx, transport.Name())
clientSubnet, loaded := ClientSubnetFromContext(ctx)
if loaded {
SetClientSubnet(message, clientSubnet, true)
}
response, err := transport.Exchange(ctx, message)
if err != nil {
return nil, err
Expand Down
55 changes: 55 additions & 0 deletions extension_edns0_subnet.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package dns

import (
"context"
"net/netip"

"github.com/miekg/dns"
)

type edns0SubnetTransportWrapper struct {
Transport
clientSubnet netip.Addr
}

func (t *edns0SubnetTransportWrapper) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) {
SetClientSubnet(message, t.clientSubnet, false)
return t.Transport.Exchange(ctx, message)
}

func SetClientSubnet(message *dns.Msg, clientSubnet netip.Addr, override bool) {
var subnetOption *dns.EDNS0_SUBNET
findExists:
for _, record := range message.Extra {
if optRecord, isOPTRecord := record.(*dns.OPT); isOPTRecord {
for _, option := range optRecord.Option {
var isEDNS0Subnet bool
subnetOption, isEDNS0Subnet = option.(*dns.EDNS0_SUBNET)
if isEDNS0Subnet {
if !override {
return
}
break findExists
}
}
}
}
if subnetOption == nil {
subnetOption = new(dns.EDNS0_SUBNET)
message.Extra = append(message.Extra, &dns.OPT{
Hdr: dns.RR_Header{
Name: ".",
Rrtype: dns.TypeOPT,
},
Option: []dns.EDNS0{subnetOption},
})
}
subnetOption.Code = dns.EDNS0SUBNET
if clientSubnet.Is4() {
subnetOption.Family = 1
} else {
subnetOption.Family = 2
}
subnetOption.SourceNetmask = uint8(clientSubnet.BitLen())
subnetOption.Address = clientSubnet.AsSlice()
}
56 changes: 56 additions & 0 deletions extensions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package dns

import (
"context"
"net/netip"
)

type disableCacheKey struct{}

func ContextWithDisableCache(ctx context.Context, val bool) context.Context {
return context.WithValue(ctx, (*disableCacheKey)(nil), val)
}

func DisableCacheFromContext(ctx context.Context) bool {
val := ctx.Value((*disableCacheKey)(nil))
if val == nil {
return false
}
return val.(bool)
}

type rewriteTTLKey struct{}

func ContextWithRewriteTTL(ctx context.Context, val uint32) context.Context {
return context.WithValue(ctx, (*rewriteTTLKey)(nil), val)
}

func RewriteTTLFromContext(ctx context.Context) (uint32, bool) {
val := ctx.Value((*rewriteTTLKey)(nil))
if val == nil {
return 0, false
}
return val.(uint32), true
}

type transportKey struct{}

func contextWithTransportName(ctx context.Context, transportName string) context.Context {
return context.WithValue(ctx, transportKey{}, transportName)
}

func transportNameFromContext(ctx context.Context) (string, bool) {
value, loaded := ctx.Value(transportKey{}).(string)
return value, loaded
}

type clientSubnetKey struct{}

func ContextWithClientSubnet(ctx context.Context, clientSubnet netip.Addr) context.Context {
return context.WithValue(ctx, clientSubnetKey{}, clientSubnet)
}

func ClientSubnetFromContext(ctx context.Context) (netip.Addr, bool) {
clientSubnet, ok := ctx.Value(clientSubnetKey{}).(netip.Addr)
return clientSubnet, ok
}
14 changes: 0 additions & 14 deletions loopback.go

This file was deleted.

31 changes: 0 additions & 31 deletions options.go

This file was deleted.

33 changes: 15 additions & 18 deletions quic/transport_http3.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/sagernet/quic-go/http3"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"

Expand All @@ -24,16 +23,9 @@ import (
var _ dns.Transport = (*HTTP3Transport)(nil)

func init() {
dns.RegisterTransport([]string{"h3"}, CreateHTTP3Transport)
}

func CreateHTTP3Transport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
linkURL, err := url.Parse(link)
if err != nil {
return nil, err
}
linkURL.Scheme = "https"
return NewHTTP3Transport(name, dialer, linkURL.String()), nil
dns.RegisterTransport([]string{"h3"}, func(options dns.TransportOptions) (dns.Transport, error) {
return NewHTTP3Transport(options)
})
}

type HTTP3Transport struct {
Expand All @@ -42,24 +34,29 @@ type HTTP3Transport struct {
transport *http3.RoundTripper
}

func NewHTTP3Transport(name string, dialer N.Dialer, serverURL string) *HTTP3Transport {
func NewHTTP3Transport(options dns.TransportOptions) (*HTTP3Transport, error) {
serverURL, err := url.Parse(options.Address)
if err != nil {
return nil, err
}
serverURL.Scheme = "https"
return &HTTP3Transport{
name: name,
destination: serverURL,
name: options.Name,
destination: options.Address,
transport: &http3.RoundTripper{
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
destinationAddr := M.ParseSocksaddr(addr)
conn, err := dialer.DialContext(ctx, N.NetworkUDP, destinationAddr)
if err != nil {
return nil, err
conn, dialErr := options.Dialer.DialContext(ctx, N.NetworkUDP, destinationAddr)
if dialErr != nil {
return nil, dialErr
}
return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg)
},
TLSClientConfig: &tls.Config{
NextProtos: []string{"dns"},
},
},
}
}, nil
}

func (t *HTTP3Transport) Name() string {
Expand Down
29 changes: 23 additions & 6 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/miekg/dns"
)

type TransportConstructor = func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error)
type TransportConstructor = func(options TransportOptions) (Transport, error)

type Transport interface {
Name() string
Expand All @@ -24,6 +24,15 @@ type Transport interface {
Lookup(ctx context.Context, domain string, strategy DomainStrategy) ([]netip.Addr, error)
}

type TransportOptions struct {
Context context.Context
Logger logger.ContextLogger
Name string
Dialer N.Dialer
Address string
ClientSubnet netip.Addr
}

var transports map[string]TransportConstructor

func RegisterTransport(schemes []string, constructor TransportConstructor) {
Expand All @@ -35,18 +44,26 @@ func RegisterTransport(schemes []string, constructor TransportConstructor) {
}
}

func CreateTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, address string) (Transport, error) {
constructor := transports[address]
func CreateTransport(options TransportOptions) (Transport, error) {
constructor := transports[options.Address]
if constructor == nil {
serverURL, _ := url.Parse(address)
serverURL, _ := url.Parse(options.Address)
var scheme string
if serverURL != nil {
scheme = serverURL.Scheme
}
constructor = transports[scheme]
}
if constructor == nil {
return nil, E.New("unknown DNS server format: " + address)
return nil, E.New("unknown DNS server format: " + options.Address)
}
options.Context = contextWithTransportName(options.Context, options.Name)
transport, err := constructor(options)
if err != nil {
return nil, err
}
if options.ClientSubnet.IsValid() {
transport = &edns0SubnetTransportWrapper{transport, options.ClientSubnet}
}
return constructor(name, contextWithTransportName(ctx, name), logger, dialer, address)
return transport, nil
}
10 changes: 6 additions & 4 deletions transport_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,21 @@ type myTransportAdapter struct {
cancel context.CancelFunc
dialer N.Dialer
serverAddr M.Socksaddr
clientAddr netip.Addr
handler myTransportHandler
access sync.Mutex
conn *dnsConnection
}

func newAdapter(name string, ctx context.Context, dialer N.Dialer, serverAddr M.Socksaddr) myTransportAdapter {
ctx, cancel := context.WithCancel(ctx)
func newAdapter(options TransportOptions, serverAddr M.Socksaddr) myTransportAdapter {
ctx, cancel := context.WithCancel(options.Context)
return myTransportAdapter{
name: name,
name: options.Name,
ctx: ctx,
cancel: cancel,
dialer: dialer,
dialer: options.Dialer,
serverAddr: serverAddr,
clientAddr: options.ClientSubnet,
}
}

Expand Down
18 changes: 7 additions & 11 deletions transport_https.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ import (
"net/netip"
"os"

"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"

"github.com/miekg/dns"
)
Expand All @@ -28,21 +26,19 @@ type HTTPSTransport struct {
}

func init() {
RegisterTransport([]string{"https"}, CreateHTTPSTransport)
RegisterTransport([]string{"https"}, func(options TransportOptions) (Transport, error) {
return NewHTTPSTransport(options), nil
})
}

func CreateHTTPSTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) {
return NewHTTPSTransport(name, dialer, link), nil
}

func NewHTTPSTransport(name string, dialer N.Dialer, serverURL string) *HTTPSTransport {
func NewHTTPSTransport(options TransportOptions) *HTTPSTransport {
return &HTTPSTransport{
name: name,
destination: serverURL,
name: options.Name,
destination: options.Address,
transport: &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
return options.Dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSClientConfig: &tls.Config{
NextProtos: []string{"dns"},
Expand Down
15 changes: 6 additions & 9 deletions transport_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,16 @@ import (
"sort"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"

"github.com/miekg/dns"
)

func init() {
RegisterTransport([]string{"local"}, CreateLocalTransport)
}

func CreateLocalTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (Transport, error) {
return NewLocalTransport(name, dialer), nil
RegisterTransport([]string{"local"}, func(options TransportOptions) (Transport, error) {
return NewLocalTransport(options), nil
})
}

var _ Transport = (*LocalTransport)(nil)
Expand All @@ -30,12 +27,12 @@ type LocalTransport struct {
resolver net.Resolver
}

func NewLocalTransport(name string, dialer N.Dialer) *LocalTransport {
func NewLocalTransport(options TransportOptions) *LocalTransport {
return &LocalTransport{
name: name,
name: options.Name,
resolver: net.Resolver{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.DialContext(ctx, N.NetworkName(network), M.ParseSocksaddr(address))
return options.Dialer.DialContext(ctx, N.NetworkName(network), M.ParseSocksaddr(address))
},
},
}
Expand Down
Loading

0 comments on commit 6a377c9

Please sign in to comment.