@@ -4,6 +4,8 @@ package badtls
4
4
5
5
import (
6
6
"bytes"
7
+ "context"
8
+ "net"
7
9
"os"
8
10
"reflect"
9
11
"sync"
@@ -18,20 +20,32 @@ import (
18
20
var _ N.ReadWaiter = (* ReadWaitConn )(nil )
19
21
20
22
type ReadWaitConn struct {
21
- * tls.STDConn
22
- halfAccess * sync.Mutex
23
- rawInput * bytes.Buffer
24
- input * bytes.Reader
25
- hand * bytes.Buffer
26
- readWaitOptions N.ReadWaitOptions
23
+ tls.Conn
24
+ halfAccess * sync.Mutex
25
+ rawInput * bytes.Buffer
26
+ input * bytes.Reader
27
+ hand * bytes.Buffer
28
+ readWaitOptions N.ReadWaitOptions
29
+ tlsReadRecord func () error
30
+ tlsHandlePostHandshakeMessage func () error
27
31
}
28
32
29
33
func NewReadWaitConn (conn tls.Conn ) (tls.Conn , error ) {
30
- stdConn , isSTDConn := conn .(* tls.STDConn )
31
- if ! isSTDConn {
34
+ var (
35
+ loaded bool
36
+ tlsReadRecord func () error
37
+ tlsHandlePostHandshakeMessage func () error
38
+ )
39
+ for _ , tlsCreator := range tlsRegistry {
40
+ loaded , tlsReadRecord , tlsHandlePostHandshakeMessage = tlsCreator (conn )
41
+ if loaded {
42
+ break
43
+ }
44
+ }
45
+ if ! loaded {
32
46
return nil , os .ErrInvalid
33
47
}
34
- rawConn := reflect .Indirect (reflect .ValueOf (stdConn ))
48
+ rawConn := reflect .Indirect (reflect .ValueOf (conn ))
35
49
rawHalfConn := rawConn .FieldByName ("in" )
36
50
if ! rawHalfConn .IsValid () || rawHalfConn .Kind () != reflect .Struct {
37
51
return nil , E .New ("badtls: invalid half conn" )
@@ -57,11 +71,13 @@ func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
57
71
}
58
72
hand := (* bytes .Buffer )(unsafe .Pointer (rawHand .UnsafeAddr ()))
59
73
return & ReadWaitConn {
60
- STDConn : stdConn ,
61
- halfAccess : halfAccess ,
62
- rawInput : rawInput ,
63
- input : input ,
64
- hand : hand ,
74
+ Conn : conn ,
75
+ halfAccess : halfAccess ,
76
+ rawInput : rawInput ,
77
+ input : input ,
78
+ hand : hand ,
79
+ tlsReadRecord : tlsReadRecord ,
80
+ tlsHandlePostHandshakeMessage : tlsHandlePostHandshakeMessage ,
65
81
}, nil
66
82
}
67
83
@@ -71,19 +87,19 @@ func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy
71
87
}
72
88
73
89
func (c * ReadWaitConn ) WaitReadBuffer () (buffer * buf.Buffer , err error ) {
74
- err = c .Handshake ( )
90
+ err = c .HandshakeContext ( context . Background () )
75
91
if err != nil {
76
92
return
77
93
}
78
94
c .halfAccess .Lock ()
79
95
defer c .halfAccess .Unlock ()
80
96
for c .input .Len () == 0 {
81
- err = tlsReadRecord ( c . STDConn )
97
+ err = c . tlsReadRecord ( )
82
98
if err != nil {
83
99
return
84
100
}
85
101
for c .hand .Len () > 0 {
86
- err = tlsHandlePostHandshakeMessage ( c . STDConn )
102
+ err = c . tlsHandlePostHandshakeMessage ( )
87
103
if err != nil {
88
104
return
89
105
}
@@ -100,16 +116,32 @@ func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
100
116
if n != 0 && c .input .Len () == 0 && c .rawInput .Len () > 0 &&
101
117
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
102
118
c .rawInput .Bytes ()[0 ] == 21 {
103
- _ = tlsReadRecord ( c . STDConn )
119
+ _ = c . tlsReadRecord ( )
104
120
// return n, err // will be io.EOF on closeNotify
105
121
}
106
122
107
123
c .readWaitOptions .PostReturn (buffer )
108
124
return
109
125
}
110
126
111
- //go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
112
- func tlsReadRecord (c * tls.STDConn ) error
127
+ var tlsRegistry []func (conn net.Conn ) (loaded bool , tlsReadRecord func () error , tlsHandlePostHandshakeMessage func () error )
128
+
129
+ func init () {
130
+ tlsRegistry = append (tlsRegistry , func (conn net.Conn ) (loaded bool , tlsReadRecord func () error , tlsHandlePostHandshakeMessage func () error ) {
131
+ tlsConn , loaded := conn .(* tls.STDConn )
132
+ if ! loaded {
133
+ return
134
+ }
135
+ return true , func () error {
136
+ return stdTLSReadRecord (tlsConn )
137
+ }, func () error {
138
+ return stdTLSHandlePostHandshakeMessage (tlsConn )
139
+ }
140
+ })
141
+ }
142
+
143
+ //go:linkname stdTLSReadRecord crypto/tls.(*Conn).readRecord
144
+ func stdTLSReadRecord (c * tls.STDConn ) error
113
145
114
- //go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
115
- func tlsHandlePostHandshakeMessage (c * tls.STDConn ) error
146
+ //go:linkname stdTLSHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
147
+ func stdTLSHandlePostHandshakeMessage (c * tls.STDConn ) error
0 commit comments