Skip to content

Commit 319700d

Browse files
committed
Add synchronization to Connection.net.Conn to prevent data races and panics
1 parent ce959a4 commit 319700d

File tree

1 file changed

+39
-25
lines changed

1 file changed

+39
-25
lines changed

connection.go

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ type Connection struct {
7474
idleDeadline time.Time
7575

7676
// connection object
77-
conn net.Conn
77+
conn net.Conn
78+
// connection mutex
79+
connMu sync.RWMutex
7880
totalReceived int64
7981

8082
// histogram to adjust the buff size to optimal value over time
@@ -96,8 +98,6 @@ type Connection struct {
9698
// LimitReader is used to avoid that problem.
9799
limitReader *io.LimitedReader
98100

99-
closer sync.Once
100-
101101
// Used to track the last time the connection was used. This is used to determine
102102
// if the connection is idle/timeout and should be closed.
103103
salvageConnection bool
@@ -119,7 +119,7 @@ func errToAerospikeErr(conn *Connection, err error) (aerr Error) {
119119
if terr, ok := err.(net.Error); ok {
120120
if terr.Timeout() {
121121
if conn != nil {
122-
if conn.node != nil {
122+
if conn.node != nil {
123123
conn.node.stats.ConnectionsTimeoutErrors.IncrementAndGet()
124124
}
125125
if errors.Is(terr, os.ErrDeadlineExceeded) {
@@ -137,7 +137,7 @@ func errToAerospikeErr(conn *Connection, err error) (aerr Error) {
137137

138138
// set node if exists
139139
if conn != nil {
140-
aerr.setNode(conn.node)
140+
_ = aerr.setNode(conn.node)
141141
}
142142

143143
return aerr
@@ -224,6 +224,9 @@ func NewConnection(policy *ClientPolicy, host *Host) (*Connection, Error) {
224224

225225
// Write writes the slice to the connection buffer.
226226
func (ctn *Connection) Write(buf []byte) (total int, aerr Error) {
227+
ctn.connMu.RLock()
228+
defer ctn.connMu.RUnlock()
229+
227230
var err error
228231

229232
// make sure all bytes are written
@@ -252,6 +255,9 @@ func (ctn *Connection) Write(buf []byte) (total int, aerr Error) {
252255

253256
// Read reads from connection buffer to the provided slice.
254257
func (ctn *Connection) Read(buf []byte, length int) (total int, aerr Error) {
258+
ctn.connMu.RLock()
259+
defer ctn.connMu.RUnlock()
260+
255261
var err error
256262

257263
// if all bytes are not read, retry until successful
@@ -302,6 +308,9 @@ func (ctn *Connection) Read(buf []byte, length int) (total int, aerr Error) {
302308

303309
// IsConnected returns true if the connection is not closed yet.
304310
func (ctn *Connection) IsConnected() bool {
311+
ctn.connMu.RLock()
312+
defer ctn.connMu.RUnlock()
313+
305314
return ctn.conn != nil
306315
}
307316

@@ -375,29 +384,34 @@ func (ctn *Connection) SetTimeout(deadline time.Time, socketTimeout time.Duratio
375384
return nil
376385
}
377386

378-
// Close closes the connection
387+
// Close closes the connection.
379388
func (ctn *Connection) Close() {
380-
ctn.closer.Do(func() {
381-
if ctn != nil && ctn.conn != nil {
382-
// deregister
383-
if ctn.node != nil {
384-
ctn.node.connectionCount.DecrementAndGet()
385-
ctn.node.stats.ConnectionsClosed.IncrementAndGet()
386-
}
389+
if ctn == nil {
390+
return
391+
}
387392

388-
if err := ctn.conn.Close(); err != nil {
389-
logger.Logger.Warn("%s", err.Error())
390-
}
391-
ctn.conn = nil
393+
ctn.connMu.Lock()
394+
defer ctn.connMu.Unlock()
392395

393-
// put the data buffer back in the pool in case it gets used again
394-
buffPool.Put(ctn.dataBuffer)
396+
if ctn.conn != nil {
397+
// deregister
398+
if ctn.node != nil {
399+
ctn.node.connectionCount.DecrementAndGet()
400+
ctn.node.stats.ConnectionsClosed.IncrementAndGet()
401+
}
395402

396-
ctn.dataBuffer = nil
397-
ctn.origDataBuffer = nil
398-
ctn.node = nil
403+
if err := ctn.conn.Close(); err != nil {
404+
logger.Logger.Warn("%s", err.Error())
399405
}
400-
})
406+
ctn.conn = nil
407+
408+
// put the data buffer back in the pool in case it gets used again
409+
buffPool.Put(ctn.dataBuffer)
410+
411+
ctn.dataBuffer = nil
412+
ctn.origDataBuffer = nil
413+
ctn.node = nil
414+
}
401415
}
402416

403417
// Login will send authentication information to the server.
@@ -501,7 +515,7 @@ func (ctn *Connection) refresh() {
501515
now := time.Now()
502516
ctn.idleDeadline = now.Add(ctn.idleTimeout)
503517
if ctn.inflater != nil {
504-
ctn.inflater.Close()
518+
_ = ctn.inflater.Close()
505519
}
506520
ctn.compressed = false
507521
ctn.inflater = nil
@@ -552,4 +566,4 @@ func KeepConnection(err Error) bool {
552566
types.SCAN_ABORT,
553567
types.QUERY_ABORTED,
554568
types.TIMEOUT)
555-
}
569+
}

0 commit comments

Comments
 (0)