Skip to content

Commit 6271dc1

Browse files
committed
fix: Guard outbound reloading with locks
1 parent 46654b7 commit 6271dc1

File tree

4 files changed

+54
-14
lines changed

4 files changed

+54
-14
lines changed

config/outbound.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ type Outbound struct {
1010
Minecraft *MinecraftService `json:",omitempty"`
1111
SocketOptions *network.OutboundSocketOptions `json:",omitempty"`
1212
ProxyProtocolVersion int8 `json:",omitempty"`
13-
ProxyOptions outbound `json:",omitempty"`
13+
ProxyOptions proxyOptions `json:",omitempty"`
1414
}

config/service.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ type Service struct {
1313
Minecraft *MinecraftService `json:",omitempty"`
1414
TLSSniffing *tlsSniffing `json:",omitempty"`
1515
SocketOptions *network.InboundSocketOptions `json:",omitempty"`
16-
Outbound outbound `json:",omitempty"`
16+
Outbound proxyOptions `json:",omitempty"`
1717
}
1818

1919
type access struct {
@@ -52,8 +52,8 @@ type tlsSniffing struct {
5252
SNIAllowListTags []string `json:",omitempty"`
5353
}
5454

55-
type outbound struct {
56-
Type string
55+
type proxyOptions struct {
56+
Type string `json:",omitempty"`
5757
Network string `json:",omitempty"`
5858
Address string `json:",omitempty"`
5959
}

protocol/minecraft/outbound.go

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/url"
1111
"strconv"
1212
"strings"
13+
"sync"
1314
"sync/atomic"
1415

1516
"github.com/layou233/zbproxy/v3/adapter"
@@ -32,6 +33,7 @@ import (
3233
var minecraftSRV = &adapter.SRVMetadata{ServiceName: "minecraft"}
3334

3435
type Outbound struct {
36+
access sync.RWMutex
3537
logger *log.Logger
3638
config *config.Outbound
3739
router adapter.Router
@@ -58,11 +60,13 @@ func NewOutbound(logger *log.Logger, newConfig *config.Outbound) (*Outbound, err
5860
return outbound, nil
5961
}
6062

61-
func (o *Outbound) Name() string {
63+
func (o *Outbound) Name() (name string) {
64+
o.access.RLock()
6265
if o.config != nil {
63-
return o.config.Name
66+
name = o.config.Name
6467
}
65-
return ""
68+
o.access.RUnlock()
69+
return
6670
}
6771

6872
func (o *Outbound) PostInitialize(router adapter.Router, provider adapter.RouteResourceProvider) error {
@@ -176,6 +180,8 @@ func (o *Outbound) PostInitialize(router adapter.Router, provider adapter.RouteR
176180
}
177181

178182
func (o *Outbound) Reload(options adapter.OutboundReloadOptions) error {
183+
o.access.Lock()
184+
defer o.access.Unlock()
179185
o.config = options.Config
180186
o.hostnameAccessLists = nil
181187
o.nameAccessLists = nil
@@ -222,16 +228,19 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
222228
if metadata.Minecraft == nil {
223229
return errors.New("require Minecraft metadata")
224230
}
231+
o.access.RLock()
232+
225233
if o.config.Minecraft.HostnameAccess.Mode != access.DefaultMode {
226234
hostnameClean := metadata.Minecraft.CleanOriginDestination()
227235
if o.config.Minecraft.HostnameAccess.LowerCase {
228236
hostnameClean = strings.ToLower(hostnameClean)
229237
}
230238
if !access.Check(o.hostnameAccessLists, o.config.Minecraft.HostnameAccess.Mode, hostnameClean) {
231-
conn.Conn.(*net.TCPConn).SetLinger(0)
232-
conn.Close()
233-
return common.Cause("hostname "+o.config.Minecraft.HostnameAccess.Mode+
239+
err := common.Cause("hostname "+o.config.Minecraft.HostnameAccess.Mode+
234240
" mode, request="+url.QueryEscape(hostnameClean)+": ", access.ErrRejected)
241+
o.access.RUnlock()
242+
conn.Conn.(*net.TCPConn).SetLinger(0)
243+
return err
235244
}
236245
}
237246
if metadata.Minecraft.SniffPosition >= 0 {
@@ -242,13 +251,15 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
242251
// skip Status Request packet
243252
_, err := conn.Peek(2)
244253
if err != nil {
254+
o.access.RUnlock()
245255
return common.Cause("skip status request: ", err)
246256
}
247257
if o.config.Minecraft.MotdFavicon == "" && o.config.Minecraft.MotdDescription == "" {
248258
// directly proxy MOTD from server
249259
var remoteConn net.Conn
250260
remoteConn, err = o.connectServer(ctx, metadata)
251261
if err != nil {
262+
o.access.RUnlock()
252263
return common.Cause("request remote MOTD: ", err)
253264
}
254265
//remoteConn.(*net.TCPConn).SetLinger(0) // for some reason
@@ -291,8 +302,10 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
291302
_, err = remoteConn.Write(buffer.Bytes())
292303
buffer.Release()
293304
if err != nil {
305+
o.access.RUnlock()
294306
return common.Cause("request remote MOTD: ", err)
295307
}
308+
o.access.RUnlock()
296309
return bufio.CopyConn(remoteConn, conn)
297310
} else {
298311
motd := generateMOTD(metadata.Minecraft.ProtocolVersion, o.config, &o.onlineCount)
@@ -307,6 +320,8 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
307320
}
308321
err = clientMC.WriteVectorizedPacket(buffer, motd)
309322
if err != nil {
323+
o.access.RUnlock()
324+
buffer.Release()
310325
return common.Cause("respond MOTD: ", err)
311326
}
312327

@@ -319,21 +334,25 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
319334
err = clientMC.WritePacket(buffer)
320335
buffer.Release()
321336
if err != nil {
337+
o.access.RUnlock()
322338
return common.Cause("respond 0ms ping: ", err)
323339
}
324340
default:
325341
err = clientMC.ReadLimitedPacket(buffer, 9)
326342
if err != nil {
343+
o.access.RUnlock()
327344
buffer.Release()
328345
return common.Cause("read ping request: ", err)
329346
}
330347
err = clientMC.WritePacket(buffer)
331348
buffer.Release()
332349
if err != nil {
350+
o.access.RUnlock()
333351
return common.Cause("respond ping request: ", err)
334352
}
335353
}
336354
o.logger.Info().Str("id", metadata.ConnectionID).Str("outbound", o.config.Name).Msg("Responded MOTD")
355+
o.access.RUnlock()
337356
return nil
338357
}
339358

@@ -348,18 +367,21 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
348367
if !access.Check(o.nameAccessLists, o.config.Minecraft.NameAccess.Mode, name) {
349368
msg, err := generateKickMessage(o.config, metadata.Minecraft.PlayerName).MarshalJSON()
350369
if err != nil { // almost impossible
370+
o.access.RUnlock()
351371
buffer.Release()
352372
return common.Cause("generate kick message: ", err)
353373
}
354374
buffer.WriteByte(0) // Client bound : Disconnect (login)
355375
mcprotocol.VarInt(len(msg)).WriteToBuffer(buffer)
356376
err = mcprotocol.Conn{Writer: common.UnwrapWriter(conn)}.WriteVectorizedPacket(buffer, msg)
357377
if err != nil {
378+
o.access.RUnlock()
358379
buffer.Release()
359380
return common.Cause("send kick packet: ", err)
360381
}
361382
o.logger.Warn().Str("id", metadata.ConnectionID).Str("outbound", o.config.Name).
362383
Str("player", metadata.Minecraft.PlayerName).Msg("Kicked by name access control")
384+
o.access.RUnlock()
363385
conn.Conn.(*net.TCPConn).SetLinger(10)
364386
buffer.Release()
365387
return nil
@@ -369,25 +391,29 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
369391
o.config.Minecraft.OnlineCount.Max <= o.onlineCount.Load() {
370392
msg, err := generatePlayerNumberLimitExceededMessage(o.config, metadata.Minecraft.PlayerName).MarshalJSON()
371393
if err != nil {
394+
o.access.RUnlock()
372395
buffer.Release()
373396
return common.Cause("generate player number limit exceeded packet: ", err)
374397
}
375398
buffer.WriteByte(0)
376399
mcprotocol.VarInt(len(msg)).WriteToBuffer(buffer)
377400
err = mcprotocol.Conn{Writer: common.UnwrapWriter(conn)}.WriteVectorizedPacket(buffer, msg)
378401
if err != nil {
402+
o.access.RUnlock()
379403
buffer.Release()
380404
return common.Cause("send player number limit exceeded packet: ", err)
381405
}
382406
o.logger.Warn().Str("id", metadata.ConnectionID).Str("outbound", o.config.Name).
383407
Str("player", metadata.Minecraft.PlayerName).Msg("Kicked by player number limiter")
408+
o.access.RUnlock()
384409
conn.Conn.(*net.TCPConn).SetLinger(10)
385410
buffer.Release()
386411
return nil
387412
}
388413

389414
serverConn, err := o.connectServer(ctx, metadata)
390415
if err != nil {
416+
o.access.RUnlock()
391417
buffer.Release()
392418
return common.Cause("connect server: ", err)
393419
}
@@ -420,24 +446,28 @@ func (o *Outbound) InjectConnection(ctx context.Context, conn *bufio.CachedConn,
420446
_, err = vector.WriteTo(serverConn)
421447
buffer.Release()
422448
if err != nil {
449+
o.access.RUnlock()
423450
serverConn.Close()
424451
return common.Cause("server handshake: ", err)
425452
}
426453
cache.Advance(cache.Len()) // all written
427454
o.logger.Info().Str("id", metadata.ConnectionID).Str("outbound", o.config.Name).
428455
Str("player", metadata.Minecraft.PlayerName).Msg("Created Minecraft connection")
456+
o.access.RUnlock()
429457
o.onlineCount.Add(1)
430458
err = bufio.CopyConn(serverConn, conn)
431459
o.onlineCount.Add(-1)
432460
return err
433461

434462
case mcprotocol.IntentTransfer:
435463
// TODO: Minecraft transfer support
464+
o.access.RUnlock()
436465
conn.Conn.(*net.TCPConn).SetLinger(0)
437466
return conn.Close()
438467

439468
default:
440-
return errors.New("unknown next state")
469+
o.access.RUnlock()
470+
return errors.New("unknown intent")
441471
}
442472
}
443473

protocol/outbound.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net"
88
"net/netip"
99
"os"
10+
"sync"
1011

1112
"github.com/layou233/zbproxy/v3/adapter"
1213
"github.com/layou233/zbproxy/v3/common"
@@ -34,6 +35,7 @@ func NewOutbound(logger *log.Logger, newConfig *config.Outbound) (adapter.Outbou
3435
}
3536

3637
type Plain struct {
38+
access sync.RWMutex
3739
logger *log.Logger
3840
config *config.Outbound
3941
router adapter.Router
@@ -46,11 +48,13 @@ var (
4648
_ network.Dialer = (*Plain)(nil)
4749
)
4850

49-
func (o *Plain) Name() string {
51+
func (o *Plain) Name() (name string) {
52+
o.access.RLock()
5053
if o.config != nil {
51-
return o.config.Name
54+
name = o.config.Name
5255
}
53-
return ""
56+
o.access.RUnlock()
57+
return
5458
}
5559

5660
func (o *Plain) PostInitialize(router adapter.Router, provider adapter.RouteResourceProvider) error {
@@ -87,15 +91,21 @@ func (o *Plain) PostInitialize(router adapter.Router, provider adapter.RouteReso
8791
}
8892

8993
func (o *Plain) Reload(options adapter.OutboundReloadOptions) error {
94+
o.access.Lock()
95+
defer o.access.Unlock()
9096
o.config = options.Config
9197
return o.PostInitialize(o.router, &options)
9298
}
9399

94100
func (o *Plain) DialContext(ctx context.Context, network string, address string) (net.Conn, error) {
101+
o.access.RLock()
102+
defer o.access.RUnlock()
95103
return o.dialer.DialContext(ctx, network, address)
96104
}
97105

98106
func (o *Plain) DialContextWithMetadata(ctx context.Context, network string, address string, metadata *adapter.Metadata) (net.Conn, error) {
107+
o.access.RLock()
108+
defer o.access.RUnlock()
99109
conn, err := adapter.DialContextWithMetadata(o.dialer, ctx, network, address, metadata)
100110
if err != nil {
101111
return nil, err

0 commit comments

Comments
 (0)