Skip to content

Commit 9ae4518

Browse files
committed
Migrate to udpnat2 / Add PrepareConnection
1 parent dc5d3e8 commit 9ae4518

8 files changed

+126
-89
lines changed

go.mod

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ go 1.20
55
require (
66
github.com/go-ole/go-ole v1.3.0
77
github.com/sagernet/fswatch v0.1.1
8-
github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc
8+
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3
99
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
1010
github.com/sagernet/nftables v0.3.0-beta.4
11-
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a
11+
github.com/sagernet/sing v0.5.0-rc.4.0.20241021153852-cf58af1a4627
1212
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
1313
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
1414
golang.org/x/net v0.26.0

go.sum

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8Ku
1616
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
1717
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
1818
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
19-
github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc h1:IvmeRstYX63O0QpLGJgVOaaM21ZIG0frJi6MT29Irtk=
20-
github.com/sagernet/gvisor v0.0.0-20241019061641-46bad1ee6ecc/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
19+
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 h1:RxEz7LhPNiF/gX/Hg+OXr5lqsM9iVAgmaK1L1vzlDRM=
20+
github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
2121
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
2222
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
2323
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
2424
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
25-
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a h1:6qlFfBvLZT/MhDpUr4cKY6RxYTnaCcFgOrJEnf/0+io=
26-
github.com/sagernet/sing v0.5.0-rc.4.0.20241020060022-1270938dd44a/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
25+
github.com/sagernet/sing v0.5.0-rc.4.0.20241021153852-cf58af1a4627 h1:wWRmqHPHfyWRPUIGsjAmYshvXF+pC/csl9pAmo/vGpo=
26+
github.com/sagernet/sing v0.5.0-rc.4.0.20241021153852-cf58af1a4627/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
2727
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
2828
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
2929
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=

stack.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/binary"
66
"net"
77
"net/netip"
8+
"time"
89

910
"github.com/sagernet/sing/common/control"
1011
E "github.com/sagernet/sing/common/exceptions"
@@ -23,7 +24,7 @@ type StackOptions struct {
2324
Tun Tun
2425
TunOptions Options
2526
EndpointIndependentNat bool
26-
UDPTimeout int64
27+
UDPTimeout time.Duration
2728
Handler Handler
2829
Logger logger.Logger
2930
ForwarderBindInterface bool

stack_gvisor.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package tun
55
import (
66
"context"
77
"net/netip"
8+
"os"
89
"time"
910

1011
"github.com/sagernet/gvisor/pkg/tcpip"
@@ -32,7 +33,7 @@ type GVisor struct {
3233
ctx context.Context
3334
tun GVisorTun
3435
endpointIndependentNat bool
35-
udpTimeout int64
36+
udpTimeout time.Duration
3637
broadcastAddr netip.Addr
3738
handler Handler
3839
logger logger.Logger
@@ -85,13 +86,18 @@ func (t *GVisor) Start() error {
8586
localAddr: source.TCPAddr(),
8687
remoteAddr: destination.TCPAddr(),
8788
}
89+
pErr := t.handler.PrepareConnection(source, destination)
90+
if pErr != nil {
91+
r.Complete(gWriteUnreachable(t.stack, r.Packet(), err) == os.ErrInvalid)
92+
return
93+
}
8894
go t.handler.NewConnectionEx(t.ctx, conn, source, destination, nil)
8995
})
9096
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
9197
if !t.endpointIndependentNat {
92-
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
98+
udpForwarder := udp.NewForwarder(ipStack, func(r *udp.ForwarderRequest) {
9399
var wq waiter.Queue
94-
endpoint, err := request.CreateEndpoint(&wq)
100+
endpoint, err := r.CreateEndpoint(&wq)
95101
if err != nil {
96102
return
97103
}
@@ -102,9 +108,15 @@ func (t *GVisor) Start() error {
102108
endpoint.Abort()
103109
return
104110
}
111+
source := M.SocksaddrFromNet(lAddr)
112+
destination := M.SocksaddrFromNet(rAddr)
113+
pErr := t.handler.PrepareConnection(source, destination)
114+
if pErr != nil {
115+
gWriteUnreachable(t.stack, r.Packet(), pErr)
116+
r.Packet().DecRef()
117+
return
118+
}
105119
go func() {
106-
source := M.SocksaddrFromNet(lAddr)
107-
destination := M.SocksaddrFromNet(rAddr)
108120
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(udpConn, destination), time.Duration(t.udpTimeout)*time.Second)
109121
t.handler.NewPacketConnectionEx(ctx, conn, source, destination, nil)
110122
}()

stack_gvisor_udp.go

+39-36
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"net/netip"
99
"os"
1010
"sync"
11+
"time"
12+
_ "unsafe"
1113

1214
"github.com/sagernet/gvisor/pkg/buffer"
1315
"github.com/sagernet/gvisor/pkg/tcpip"
@@ -19,59 +21,60 @@ import (
1921
E "github.com/sagernet/sing/common/exceptions"
2022
M "github.com/sagernet/sing/common/metadata"
2123
N "github.com/sagernet/sing/common/network"
22-
"github.com/sagernet/sing/common/udpnat"
24+
"github.com/sagernet/sing/common/udpnat2"
2325
)
2426

2527
type UDPForwarder struct {
26-
ctx context.Context
27-
stack *stack.Stack
28-
udpNat *udpnat.Service[netip.AddrPort]
29-
30-
// cache
31-
cacheProto tcpip.NetworkProtocolNumber
32-
cacheID stack.TransportEndpointID
28+
ctx context.Context
29+
stack *stack.Stack
30+
handler Handler
31+
udpNat *udpnat.Service
3332
}
3433

35-
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
36-
return &UDPForwarder{
37-
ctx: ctx,
38-
stack: stack,
39-
udpNat: udpnat.NewEx[netip.AddrPort](udpTimeout, handler),
34+
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder {
35+
forwarder := &UDPForwarder{
36+
ctx: ctx,
37+
stack: stack,
38+
handler: handler,
4039
}
40+
forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout)
41+
return forwarder
4142
}
4243

4344
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
4445
source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
4546
destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
46-
if source.IsIPv4() {
47-
f.cacheProto = header.IPv4ProtocolNumber
48-
} else {
49-
f.cacheProto = header.IPv6ProtocolNumber
50-
}
51-
gBuffer := pkt.Data().ToBuffer()
52-
sBuffer := buf.NewSize(int(gBuffer.Size()))
53-
gBuffer.Apply(func(view *buffer.View) {
54-
sBuffer.Write(view.AsSlice())
47+
bufferRange := pkt.Data().AsRange()
48+
bufferSlices := make([][]byte, bufferRange.Size())
49+
rangeIterate(bufferRange, func(view *buffer.View) {
50+
bufferSlices = append(bufferSlices, view.AsSlice())
5551
})
56-
f.cacheID = id
57-
f.udpNat.NewPacketEx(
58-
f.ctx,
59-
source.AddrPort(),
60-
sBuffer,
61-
source,
62-
destination,
63-
f.newUDPConn,
64-
)
52+
f.udpNat.NewPacket(bufferSlices, source, destination, pkt)
6553
return true
6654
}
6755

68-
func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
69-
return &UDPBackWriter{
56+
//go:linkname rangeIterate github.com/sagernet/gvisor/pkg/tcpip/stack.Range.iterate
57+
func rangeIterate(r stack.Range, fn func(*buffer.View))
58+
59+
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
60+
pErr := f.handler.PrepareConnection(source, destination)
61+
if pErr != nil {
62+
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer), pErr)
63+
return false, nil, nil, nil
64+
}
65+
var sourceNetwork tcpip.NetworkProtocolNumber
66+
if source.Addr.Is4() {
67+
sourceNetwork = header.IPv4ProtocolNumber
68+
} else {
69+
sourceNetwork = header.IPv6ProtocolNumber
70+
}
71+
writer := &UDPBackWriter{
7072
stack: f.stack,
71-
source: f.cacheID.RemoteAddress,
72-
sourcePort: f.cacheID.RemotePort,
73-
sourceNetwork: f.cacheProto,
73+
source: AddressFromAddr(source.Addr),
74+
sourcePort: source.Port,
75+
sourceNetwork: sourceNetwork,
7476
}
77+
return true, f.ctx, writer, nil
7578
}
7679

7780
type UDPBackWriter struct {

stack_system.go

+51-38
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"github.com/sagernet/sing/common/logger"
1616
M "github.com/sagernet/sing/common/metadata"
1717
N "github.com/sagernet/sing/common/network"
18-
"github.com/sagernet/sing/common/udpnat"
18+
"github.com/sagernet/sing/common/udpnat2"
1919
)
2020

2121
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
@@ -34,13 +34,13 @@ type System struct {
3434
inet6ServerAddress netip.Addr
3535
inet6Address netip.Addr
3636
broadcastAddr netip.Addr
37-
udpTimeout int64
37+
udpTimeout time.Duration
3838
tcpListener net.Listener
3939
tcpListener6 net.Listener
4040
tcpPort uint16
4141
tcpPort6 uint16
4242
tcpNat *TCPNat
43-
udpNat *udpnat.Service[netip.AddrPort]
43+
udpNat *udpnat.Service
4444
bindInterface bool
4545
interfaceFinder control.InterfaceFinder
4646
frontHeadroom int
@@ -151,8 +151,8 @@ func (s *System) start() error {
151151
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
152152
go s.acceptLoop(tcpListener)
153153
}
154-
s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
155-
s.udpNat = udpnat.NewEx[netip.AddrPort](s.udpTimeout, s.handler)
154+
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
155+
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout)
156156
return nil
157157
}
158158

@@ -354,7 +354,11 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
354354
packet.SetDestinationIP(session.Source.Addr())
355355
header.SetDestinationPort(session.Source.Port())
356356
} else {
357-
natPort := s.tcpNat.Lookup(source, destination)
357+
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
358+
if err != nil {
359+
// TODO: implement rejects
360+
return nil
361+
}
358362
packet.SetSourceIP(s.inet4Address)
359363
header.SetSourcePort(natPort)
360364
packet.SetDestinationIP(s.inet4ServerAddress)
@@ -385,7 +389,11 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
385389
packet.SetDestinationIP(session.Source.Addr())
386390
header.SetDestinationPort(session.Source.Port())
387391
} else {
388-
natPort := s.tcpNat.Lookup(source, destination)
392+
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
393+
if err != nil {
394+
// TODO: implement rejects
395+
return nil
396+
}
389397
packet.SetSourceIP(s.inet6Address)
390398
header.SetSourcePort(natPort)
391399
packet.SetDestinationIP(s.inet6ServerAddress)
@@ -409,56 +417,61 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
409417
if !header.Valid() {
410418
return E.New("ipv4: udp: invalid packet")
411419
}
412-
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
413-
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
414-
if !destination.Addr().IsGlobalUnicast() {
420+
source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
421+
destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
422+
if !destination.Addr.IsGlobalUnicast() {
415423
return nil
416424
}
417-
data := buf.As(header.Payload())
418-
if data.Len() == 0 {
425+
s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
426+
return nil
427+
}
428+
429+
func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
430+
if !header.Valid() {
431+
return E.New("ipv6: udp: invalid packet")
432+
}
433+
source := M.SocksaddrFrom(packet.SourceIP(), header.SourcePort())
434+
destination := M.SocksaddrFrom(packet.DestinationIP(), header.DestinationPort())
435+
if !destination.Addr.IsGlobalUnicast() {
419436
return nil
420437
}
421-
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter {
438+
s.udpNat.NewPacket([][]byte{header.Payload()}, source, destination, packet)
439+
return nil
440+
}
441+
442+
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
443+
pErr := s.handler.PrepareConnection(source, destination)
444+
if pErr != nil {
445+
// TODO: implement ICMP port unreachable
446+
return false, nil, nil, nil
447+
}
448+
var writer N.PacketWriter
449+
if source.IsIPv4() {
450+
packet := userData.(clashtcpip.IPv4Packet)
422451
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
423452
headerCopy := make([]byte, headerLen)
424453
copy(headerCopy, packet[:headerLen])
425-
return &systemUDPPacketWriter4{
454+
writer = &systemUDPPacketWriter4{
426455
s.tun,
427456
s.frontHeadroom + PacketOffset,
428457
headerCopy,
429-
source,
458+
source.AddrPort(),
430459
s.txChecksumOffload,
431460
}
432-
})
433-
return nil
434-
}
435-
436-
func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
437-
if !header.Valid() {
438-
return E.New("ipv6: udp: invalid packet")
439-
}
440-
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
441-
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
442-
if !destination.Addr().IsGlobalUnicast() {
443-
return nil
444-
}
445-
data := buf.As(header.Payload())
446-
if data.Len() == 0 {
447-
return nil
448-
}
449-
s.udpNat.NewPacketEx(s.ctx, source, data.ToOwned(), M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), func(natConn N.PacketConn) N.PacketWriter {
450-
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
461+
} else {
462+
packet := userData.(clashtcpip.IPv6Packet)
463+
headerLen := len(packet) - int(packet.PayloadLength()) + clashtcpip.UDPHeaderSize
451464
headerCopy := make([]byte, headerLen)
452465
copy(headerCopy, packet[:headerLen])
453-
return &systemUDPPacketWriter6{
466+
writer = &systemUDPPacketWriter6{
454467
s.tun,
455468
s.frontHeadroom + PacketOffset,
456469
headerCopy,
457-
source,
470+
source.AddrPort(),
458471
s.txChecksumOffload,
459472
}
460-
})
461-
return nil
473+
}
474+
return true, s.ctx, writer, nil
462475
}
463476

464477
func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {

0 commit comments

Comments
 (0)