Skip to content

Commit 1ce8ae6

Browse files
committed
badtls: Support uTLS and TLS ECH for read waiter
1 parent 11bec79 commit 1ce8ae6

File tree

3 files changed

+116
-22
lines changed

3 files changed

+116
-22
lines changed

common/badtls/read_wait.go

+54-22
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ package badtls
44

55
import (
66
"bytes"
7+
"context"
8+
"net"
79
"os"
810
"reflect"
911
"sync"
@@ -18,20 +20,32 @@ import (
1820
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
1921

2022
type ReadWaitConn struct {
21-
*tls.STDConn
22-
halfAccess *sync.Mutex
23-
rawInput *bytes.Buffer
24-
input *bytes.Reader
25-
hand *bytes.Buffer
26-
readWaitOptions N.ReadWaitOptions
23+
tls.Conn
24+
halfAccess *sync.Mutex
25+
rawInput *bytes.Buffer
26+
input *bytes.Reader
27+
hand *bytes.Buffer
28+
readWaitOptions N.ReadWaitOptions
29+
tlsReadRecord func() error
30+
tlsHandlePostHandshakeMessage func() error
2731
}
2832

2933
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
30-
stdConn, isSTDConn := conn.(*tls.STDConn)
31-
if !isSTDConn {
34+
var (
35+
loaded bool
36+
tlsReadRecord func() error
37+
tlsHandlePostHandshakeMessage func() error
38+
)
39+
for _, tlsCreator := range tlsRegistry {
40+
loaded, tlsReadRecord, tlsHandlePostHandshakeMessage = tlsCreator(conn)
41+
if loaded {
42+
break
43+
}
44+
}
45+
if !loaded {
3246
return nil, os.ErrInvalid
3347
}
34-
rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
48+
rawConn := reflect.Indirect(reflect.ValueOf(conn))
3549
rawHalfConn := rawConn.FieldByName("in")
3650
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
3751
return nil, E.New("badtls: invalid half conn")
@@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
5771
}
5872
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
5973
return &ReadWaitConn{
60-
STDConn: stdConn,
61-
halfAccess: halfAccess,
62-
rawInput: rawInput,
63-
input: input,
64-
hand: hand,
74+
Conn: conn,
75+
halfAccess: halfAccess,
76+
rawInput: rawInput,
77+
input: input,
78+
hand: hand,
79+
tlsReadRecord: tlsReadRecord,
80+
tlsHandlePostHandshakeMessage: tlsHandlePostHandshakeMessage,
6581
}, nil
6682
}
6783

@@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
7187
}
7288

7389
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
74-
err = c.Handshake()
90+
err = c.HandshakeContext(context.Background())
7591
if err != nil {
7692
return
7793
}
7894
c.halfAccess.Lock()
7995
defer c.halfAccess.Unlock()
8096
for c.input.Len() == 0 {
81-
err = tlsReadRecord(c.STDConn)
97+
err = c.tlsReadRecord()
8298
if err != nil {
8399
return
84100
}
85101
for c.hand.Len() > 0 {
86-
err = tlsHandlePostHandshakeMessage(c.STDConn)
102+
err = c.tlsHandlePostHandshakeMessage()
87103
if err != nil {
88104
return
89105
}
@@ -100,16 +116,32 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
100116
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
101117
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
102118
c.rawInput.Bytes()[0] == 21 {
103-
_ = tlsReadRecord(c.STDConn)
119+
_ = c.tlsReadRecord()
104120
// return n, err // will be io.EOF on closeNotify
105121
}
106122

107123
c.readWaitOptions.PostReturn(buffer)
108124
return
109125
}
110126

111-
//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
112-
func tlsReadRecord(c *tls.STDConn) error
127+
var tlsRegistry []func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error)
128+
129+
func init() {
130+
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
131+
tlsConn, loaded := conn.(*tls.STDConn)
132+
if !loaded {
133+
return
134+
}
135+
return true, func() error {
136+
return stdTLSReadRecord(tlsConn)
137+
}, func() error {
138+
return stdTLSHandlePostHandshakeMessage(tlsConn)
139+
}
140+
})
141+
}
142+
143+
//go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
144+
func stdTLSReadRecord(c *tls.STDConn) error
113145

114-
//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
115-
func tlsHandlePostHandshakeMessage(c *tls.STDConn) error
146+
//go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
147+
func stdTLSHandlePostHandshakeMessage(c *tls.STDConn) error

common/badtls/read_wait_ech.go

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//go:build go1.21 && !without_badtls && with_ech
2+
3+
package badtls
4+
5+
import (
6+
"net"
7+
_ "unsafe"
8+
9+
"github.com/sagernet/cloudflare-tls"
10+
"github.com/sagernet/sing/common"
11+
)
12+
13+
func init() {
14+
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
15+
tlsConn, loaded := common.Cast[*tls.Conn](conn)
16+
if !loaded {
17+
return
18+
}
19+
return true, func() error {
20+
return echReadRecord(tlsConn)
21+
}, func() error {
22+
return echHandlePostHandshakeMessage(tlsConn)
23+
}
24+
})
25+
}
26+
27+
//go:linkname echReadRecord github.com/sagernet/cloudflare-tls.(*Conn).readRecord
28+
func echReadRecord(c *tls.Conn) error
29+
30+
//go:linkname echHandlePostHandshakeMessage github.com/sagernet/cloudflare-tls.(*Conn).handlePostHandshakeMessage
31+
func echHandlePostHandshakeMessage(c *tls.Conn) error

common/badtls/read_wait_utls.go

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//go:build go1.21 && !without_badtls && with_utls
2+
3+
package badtls
4+
5+
import (
6+
"net"
7+
_ "unsafe"
8+
9+
"github.com/sagernet/sing/common"
10+
"github.com/sagernet/utls"
11+
)
12+
13+
func init() {
14+
tlsRegistry = append(tlsRegistry, func(conn net.Conn) (loaded bool, tlsReadRecord func() error, tlsHandlePostHandshakeMessage func() error) {
15+
tlsConn, loaded := common.Cast[*tls.UConn](conn)
16+
if !loaded {
17+
return
18+
}
19+
return true, func() error {
20+
return utlsReadRecord(tlsConn.Conn)
21+
}, func() error {
22+
return utlsHandlePostHandshakeMessage(tlsConn.Conn)
23+
}
24+
})
25+
}
26+
27+
//go:linkname utlsReadRecord github.com/sagernet/utls.(*Conn).readRecord
28+
func utlsReadRecord(c *tls.Conn) error
29+
30+
//go:linkname utlsHandlePostHandshakeMessage github.com/sagernet/utls.(*Conn).handlePostHandshakeMessage
31+
func utlsHandlePostHandshakeMessage(c *tls.Conn) error

0 commit comments

Comments
 (0)