Skip to content

Commit 2c58592

Browse files
authored
Fix: data race (#317)
1 parent 6c907b7 commit 2c58592

File tree

4 files changed

+69
-44
lines changed

4 files changed

+69
-44
lines changed

common/io.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ package common
33
import (
44
"io"
55
"net"
6+
"sync"
67

78
"github.com/p4gefau1t/trojan-go/log"
89
)
910

1011
type RewindReader struct {
12+
mu sync.Mutex
1113
rawReader io.Reader
1214
buf []byte
1315
bufReadIdx int
@@ -17,13 +19,16 @@ type RewindReader struct {
1719
}
1820

1921
func (r *RewindReader) Read(p []byte) (int, error) {
22+
r.mu.Lock()
23+
defer r.mu.Unlock()
24+
2025
if r.rewound {
2126
if len(r.buf) > r.bufReadIdx {
2227
n := copy(p, r.buf[r.bufReadIdx:])
2328
r.bufReadIdx += n
2429
return n, nil
2530
}
26-
r.rewound = false //all buffering content has been read
31+
r.rewound = false // all buffering content has been read
2732
}
2833
n, err := r.rawReader.Read(p)
2934
if r.buffering {
@@ -59,19 +64,24 @@ func (r *RewindReader) Discard(n int) (int, error) {
5964
}
6065

6166
func (r *RewindReader) Rewind() {
67+
r.mu.Lock()
6268
if r.bufferSize == 0 {
6369
panic("no buffer")
6470
}
6571
r.rewound = true
6672
r.bufReadIdx = 0
73+
r.mu.Unlock()
6774
}
6875

6976
func (r *RewindReader) StopBuffering() {
77+
r.mu.Lock()
7078
r.buffering = false
79+
r.mu.Unlock()
7180
}
7281

7382
func (r *RewindReader) SetBufferSize(size int) {
74-
if size == 0 { //disable buffering
83+
r.mu.Lock()
84+
if size == 0 { // disable buffering
7585
if !r.buffering {
7686
panic("reader is disabled")
7787
}
@@ -88,6 +98,7 @@ func (r *RewindReader) SetBufferSize(size int) {
8898
r.bufferSize = size
8999
r.buf = make([]byte, 0, size)
90100
}
101+
r.mu.Unlock()
91102
}
92103

93104
type RewindConn struct {

test/util/util.go

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,33 @@ import (
1212

1313
// CheckConn checks if two netConn were connected and work properly
1414
func CheckConn(a net.Conn, b net.Conn) bool {
15-
payload1 := [1024]byte{}
16-
payload2 := [1024]byte{}
17-
rand.Reader.Read(payload1[:])
18-
rand.Reader.Read(payload2[:])
15+
payload1 := make([]byte, 1024)
16+
payload2 := make([]byte, 1024)
17+
18+
result1 := make([]byte, 1024)
19+
result2 := make([]byte, 1024)
20+
21+
rand.Reader.Read(payload1)
22+
rand.Reader.Read(payload2)
1923

20-
result1 := [1024]byte{}
21-
result2 := [1024]byte{}
2224
wg := sync.WaitGroup{}
2325
wg.Add(2)
26+
2427
go func() {
25-
a.Write(payload1[:])
26-
a.Read(result2[:])
28+
a.Write(payload1)
29+
a.Read(result2)
2730
wg.Done()
2831
}()
32+
2933
go func() {
30-
b.Read(result1[:])
31-
b.Write(payload2[:])
34+
b.Read(result1)
35+
b.Write(payload2)
3236
wg.Done()
3337
}()
38+
3439
wg.Wait()
35-
if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) {
36-
return false
37-
}
38-
return true
40+
41+
return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2)
3942
}
4043

4144
// CheckPacketOverConn checks if two PacketConn streaming over a connection work properly
@@ -45,55 +48,54 @@ func CheckPacketOverConn(a, b net.PacketConn) bool {
4548
IP: net.ParseIP("127.0.0.1"),
4649
Port: port,
4750
}
48-
payload1 := [1024]byte{}
49-
payload2 := [1024]byte{}
50-
rand.Reader.Read(payload1[:])
51-
rand.Reader.Read(payload2[:])
5251

53-
result1 := [1024]byte{}
54-
result2 := [1024]byte{}
52+
payload1 := make([]byte, 1024)
53+
payload2 := make([]byte, 1024)
54+
55+
result1 := make([]byte, 1024)
56+
result2 := make([]byte, 1024)
5557

56-
common.Must2(a.WriteTo(payload1[:], addr))
57-
_, addr1, err := b.ReadFrom(result1[:])
58+
rand.Reader.Read(payload1)
59+
rand.Reader.Read(payload2)
60+
61+
common.Must2(a.WriteTo(payload1, addr))
62+
_, addr1, err := b.ReadFrom(result1)
5863
common.Must(err)
5964
if addr1.String() != addr.String() {
6065
return false
6166
}
6267

63-
common.Must2(a.WriteTo(payload2[:], addr))
64-
_, addr2, err := b.ReadFrom(result2[:])
68+
common.Must2(a.WriteTo(payload2, addr))
69+
_, addr2, err := b.ReadFrom(result2)
6570
common.Must(err)
6671
if addr2.String() != addr.String() {
6772
return false
6873
}
69-
if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) {
70-
return false
71-
}
72-
return true
74+
75+
return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2)
7376
}
7477

7578
func CheckPacket(a, b net.PacketConn) bool {
76-
payload1 := [1024]byte{}
77-
payload2 := [1024]byte{}
78-
rand.Reader.Read(payload1[:])
79-
rand.Reader.Read(payload2[:])
79+
payload1 := make([]byte, 1024)
80+
payload2 := make([]byte, 1024)
8081

81-
result1 := [1024]byte{}
82-
result2 := [1024]byte{}
82+
result1 := make([]byte, 1024)
83+
result2 := make([]byte, 1024)
8384

84-
_, err := a.WriteTo(payload1[:], b.LocalAddr())
85+
rand.Reader.Read(payload1)
86+
rand.Reader.Read(payload2)
87+
88+
_, err := a.WriteTo(payload1, b.LocalAddr())
8589
common.Must(err)
86-
_, _, err = b.ReadFrom(result1[:])
90+
_, _, err = b.ReadFrom(result1)
8791
common.Must(err)
8892

89-
_, err = b.WriteTo(payload2[:], a.LocalAddr())
93+
_, err = b.WriteTo(payload2, a.LocalAddr())
9094
common.Must(err)
91-
_, _, err = a.ReadFrom(result2[:])
95+
_, _, err = a.ReadFrom(result2)
9296
common.Must(err)
93-
if !bytes.Equal(payload1[:], result1[:]) || !bytes.Equal(payload2[:], result2[:]) {
94-
return false
95-
}
96-
return true
97+
98+
return bytes.Equal(payload1, result1) && bytes.Equal(payload2, result2)
9799
}
98100

99101
func GetTestAddr() string {

tunnel/adapter/server.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package adapter
33
import (
44
"context"
55
"net"
6+
"sync"
67

78
"github.com/p4gefau1t/trojan-go/common"
89
"github.com/p4gefau1t/trojan-go/config"
@@ -18,6 +19,7 @@ type Server struct {
1819
udpListener net.PacketConn
1920
socksConn chan tunnel.Conn
2021
httpConn chan tunnel.Conn
22+
socksLock sync.RWMutex
2123
nextSocks bool
2224
ctx context.Context
2325
cancel context.CancelFunc
@@ -45,7 +47,9 @@ func (s *Server) acceptConnLoop() {
4547
log.Error(common.NewError("failed to detect proxy protocol type").Base(err))
4648
continue
4749
}
50+
s.socksLock.RLock()
4851
if buf[0] == 5 && s.nextSocks {
52+
s.socksLock.RUnlock()
4953
log.Debug("socks5 connection")
5054
s.socksConn <- &freedom.Conn{
5155
Conn: rewindConn,
@@ -68,7 +72,9 @@ func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) {
6872
return nil, common.NewError("adapter closed")
6973
}
7074
} else if _, ok := overlay.(*socks.Tunnel); ok {
75+
s.socksLock.Lock()
7176
s.nextSocks = true
77+
s.socksLock.Unlock()
7278
select {
7379
case conn := <-s.socksConn:
7480
return conn, nil

tunnel/transport/server.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"os/exec"
1010
"strconv"
11+
"sync"
1112
"time"
1213

1314
"github.com/p4gefau1t/trojan-go/common"
@@ -22,6 +23,7 @@ type Server struct {
2223
cmd *exec.Cmd
2324
connChan chan tunnel.Conn
2425
wsChan chan tunnel.Conn
26+
httpLock sync.RWMutex
2527
nextHTTP bool
2628
ctx context.Context
2729
cancel context.CancelFunc
@@ -50,7 +52,9 @@ func (s *Server) acceptLoop() {
5052

5153
go func(tcpConn net.Conn) {
5254
log.Info("tcp connection from", tcpConn.RemoteAddr())
55+
s.httpLock.RLock()
5356
if s.nextHTTP { // plaintext mode enabled
57+
s.httpLock.RUnlock()
5458
// we use real http header parser to mimic a real http server
5559
rewindConn := common.NewRewindConn(tcpConn)
5660
rewindConn.SetBufferSize(512)
@@ -84,7 +88,9 @@ func (s *Server) acceptLoop() {
8488
func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) {
8589
// TODO fix import cycle
8690
if overlay != nil && (overlay.Name() == "WEBSOCKET" || overlay.Name() == "HTTP") {
91+
s.httpLock.Lock()
8792
s.nextHTTP = true
93+
s.httpLock.Unlock()
8894
select {
8995
case conn := <-s.wsChan:
9096
return conn, nil

0 commit comments

Comments
 (0)