Skip to content

Commit c597c63

Browse files
committed
conn: do not allow ReceiveIPvX to race with Close
If Close is called after ReceiveIPvX, then ReceiveIPvX will block on an invalid or potentially reused fd. Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent 29b0477 commit c597c63

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

conn/conn_linux.go

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ import (
1818
"golang.org/x/sys/unix"
1919
)
2020

21-
const (
22-
FD_ERR = -1
23-
)
24-
2521
type IPv4Source struct {
2622
Src [4]byte
2723
Ifindex int32
@@ -63,6 +59,7 @@ type nativeBind struct {
6359
sock4 int
6460
sock6 int
6561
lastMark uint32
62+
closing sync.RWMutex
6663
}
6764

6865
var _ Endpoint = (*NativeEndpoint)(nil)
@@ -129,7 +126,7 @@ func createBind(port uint16) (Bind, uint16, error) {
129126
port = newPort
130127
}
131128

132-
if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
129+
if bind.sock4 == -1 && bind.sock6 == -1 {
133130
return nil, 0, errors.New("ipv4 and ipv6 not supported")
134131
}
135132

@@ -141,6 +138,9 @@ func (bind *nativeBind) LastMark() uint32 {
141138
}
142139

143140
func (bind *nativeBind) SetMark(value uint32) error {
141+
bind.closing.RLock()
142+
defer bind.closing.RUnlock()
143+
144144
if bind.sock6 != -1 {
145145
err := unix.SetsockoptInt(
146146
bind.sock6,
@@ -171,20 +171,26 @@ func (bind *nativeBind) SetMark(value uint32) error {
171171
return nil
172172
}
173173

174-
func closeUnblock(fd int) error {
175-
// shutdown to unblock readers and writers
176-
unix.Shutdown(fd, unix.SHUT_RDWR)
177-
return unix.Close(fd)
178-
}
179-
180174
func (bind *nativeBind) Close() error {
181175
var err1, err2 error
176+
bind.closing.RLock()
177+
if bind.sock6 != -1 {
178+
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
179+
}
180+
if bind.sock4 != -1 {
181+
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
182+
}
183+
bind.closing.RUnlock()
184+
bind.closing.Lock()
182185
if bind.sock6 != -1 {
183-
err1 = closeUnblock(bind.sock6)
186+
err1 = unix.Close(bind.sock6)
187+
bind.sock6 = -1
184188
}
185189
if bind.sock4 != -1 {
186-
err2 = closeUnblock(bind.sock4)
190+
err2 = unix.Close(bind.sock4)
191+
bind.sock4 = -1
187192
}
193+
bind.closing.Unlock()
188194

189195
if err1 != nil {
190196
return err1
@@ -193,6 +199,9 @@ func (bind *nativeBind) Close() error {
193199
}
194200

195201
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
202+
bind.closing.RLock()
203+
defer bind.closing.RUnlock()
204+
196205
var end NativeEndpoint
197206
if bind.sock6 == -1 {
198207
return 0, nil, syscall.EAFNOSUPPORT
@@ -206,6 +215,9 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
206215
}
207216

208217
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
218+
bind.closing.RLock()
219+
defer bind.closing.RUnlock()
220+
209221
var end NativeEndpoint
210222
if bind.sock4 == -1 {
211223
return 0, nil, syscall.EAFNOSUPPORT
@@ -219,6 +231,9 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
219231
}
220232

221233
func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
234+
bind.closing.RLock()
235+
defer bind.closing.RUnlock()
236+
222237
nend := end.(*NativeEndpoint)
223238
if !nend.isV6 {
224239
if bind.sock4 == -1 {
@@ -316,7 +331,7 @@ func create4(port uint16) (int, uint16, error) {
316331
)
317332

318333
if err != nil {
319-
return FD_ERR, 0, err
334+
return -1, 0, err
320335
}
321336

322337
addr := unix.SockaddrInet4{
@@ -338,7 +353,7 @@ func create4(port uint16) (int, uint16, error) {
338353
return unix.Bind(fd, &addr)
339354
}(); err != nil {
340355
unix.Close(fd)
341-
return FD_ERR, 0, err
356+
return -1, 0, err
342357
}
343358

344359
sa, err := unix.Getsockname(fd)
@@ -360,7 +375,7 @@ func create6(port uint16) (int, uint16, error) {
360375
)
361376

362377
if err != nil {
363-
return FD_ERR, 0, err
378+
return -1, 0, err
364379
}
365380

366381
// set sockopts and bind
@@ -392,7 +407,7 @@ func create6(port uint16) (int, uint16, error) {
392407

393408
}(); err != nil {
394409
unix.Close(fd)
395-
return FD_ERR, 0, err
410+
return -1, 0, err
396411
}
397412

398413
sa, err := unix.Getsockname(fd)

0 commit comments

Comments
 (0)