Skip to content

Commit 0dbfda5

Browse files
ivilaZehui Chen
and
Zehui Chen
authored
fix memory leaks in PrepareStatementDB (#7142)
* fix memory leaks in PrepareStatementDB * Fix CR: 1) Fix potential Segmentation Fault in Reset function 2) Setting db.Stmts to nil map when Close to avoid further using * Add Test: 1) TestPreparedStmtConcurrentReset 2) TestPreparedStmtConcurrentClose * Fix test, create new connection to keep away from other tests --------- Co-authored-by: Zehui Chen <[email protected]>
1 parent 4a50b36 commit 0dbfda5

File tree

2 files changed

+175
-16
lines changed

2 files changed

+175
-16
lines changed

prepare_stmt.go

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,16 @@ type Stmt struct {
1717
}
1818

1919
type PreparedStmtDB struct {
20-
Stmts map[string]*Stmt
21-
PreparedSQL []string
22-
Mux *sync.RWMutex
20+
Stmts map[string]*Stmt
21+
Mux *sync.RWMutex
2322
ConnPool
2423
}
2524

2625
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
2726
return &PreparedStmtDB{
28-
ConnPool: connPool,
29-
Stmts: make(map[string]*Stmt),
30-
Mux: &sync.RWMutex{},
31-
PreparedSQL: make([]string, 0, 100),
27+
ConnPool: connPool,
28+
Stmts: make(map[string]*Stmt),
29+
Mux: &sync.RWMutex{},
3230
}
3331
}
3432

@@ -48,22 +46,32 @@ func (db *PreparedStmtDB) Close() {
4846
db.Mux.Lock()
4947
defer db.Mux.Unlock()
5048

51-
for _, query := range db.PreparedSQL {
52-
if stmt, ok := db.Stmts[query]; ok {
53-
delete(db.Stmts, query)
54-
go stmt.Close()
55-
}
49+
for _, stmt := range db.Stmts {
50+
go func(s *Stmt) {
51+
// make sure the stmt must finish preparation first
52+
<-s.prepared
53+
if s.Stmt != nil {
54+
_ = s.Close()
55+
}
56+
}(stmt)
5657
}
58+
// setting db.Stmts to nil to avoid further using
59+
db.Stmts = nil
5760
}
5861

5962
func (sdb *PreparedStmtDB) Reset() {
6063
sdb.Mux.Lock()
6164
defer sdb.Mux.Unlock()
6265

6366
for _, stmt := range sdb.Stmts {
64-
go stmt.Close()
67+
go func(s *Stmt) {
68+
// make sure the stmt must finish preparation first
69+
<-s.prepared
70+
if s.Stmt != nil {
71+
_ = s.Close()
72+
}
73+
}(stmt)
6574
}
66-
sdb.PreparedSQL = make([]string, 0, 100)
6775
sdb.Stmts = make(map[string]*Stmt)
6876
}
6977

@@ -93,7 +101,12 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
93101

94102
return *stmt, nil
95103
}
96-
104+
// check db.Stmts first to avoid Segmentation Fault(setting value to nil map)
105+
// which cause by calling Close and executing SQL concurrently
106+
if db.Stmts == nil {
107+
db.Mux.Unlock()
108+
return Stmt{}, ErrInvalidDB
109+
}
97110
// cache preparing stmt first
98111
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
99112
db.Stmts[query] = &cacheStmt
@@ -118,7 +131,6 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
118131

119132
db.Mux.Lock()
120133
cacheStmt.Stmt = stmt
121-
db.PreparedSQL = append(db.PreparedSQL, query)
122134
db.Mux.Unlock()
123135

124136
return cacheStmt, nil

tests/prepared_stmt_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"sync"
7+
"sync/atomic"
78
"testing"
89
"time"
910

@@ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) {
167168
t.Fatalf("prepared stmt should be empty")
168169
}
169170
}
171+
172+
func isUsingClosedConnError(err error) bool {
173+
// https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717
174+
return err.Error() == "sql: statement is closed"
175+
}
176+
177+
// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently
178+
// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt
179+
func TestPreparedStmtConcurrentReset(t *testing.T) {
180+
name := "prepared_stmt_concurrent_reset"
181+
user := *GetUser(name, Config{})
182+
createTx := DB.Session(&gorm.Session{}).Create(&user)
183+
if createTx.Error != nil {
184+
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
185+
}
186+
187+
// create a new connection to keep away from other tests
188+
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
189+
if err != nil {
190+
t.Fatalf("failed to open test connection due to %s", err)
191+
}
192+
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
193+
if !ok {
194+
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
195+
}
196+
197+
loopCount := 100
198+
var wg sync.WaitGroup
199+
var unexpectedError bool
200+
writerFinish := make(chan struct{})
201+
202+
wg.Add(1)
203+
go func(id uint) {
204+
defer wg.Done()
205+
defer close(writerFinish)
206+
207+
for j := 0; j < loopCount; j++ {
208+
var tmp User
209+
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
210+
if err == nil || isUsingClosedConnError(err) {
211+
continue
212+
}
213+
t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err)
214+
unexpectedError = true
215+
break
216+
}
217+
}(user.ID)
218+
219+
wg.Add(1)
220+
go func() {
221+
defer wg.Done()
222+
<-writerFinish
223+
pdb.Reset()
224+
}()
225+
226+
wg.Wait()
227+
228+
if unexpectedError {
229+
t.Fatalf("should is a unexpected error")
230+
}
231+
}
232+
233+
// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently
234+
// for example: one goroutine found error and just close the database, and others are executing SQL
235+
// this test making sure that the gorm would not get a Segmentation Fault,
236+
// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB
237+
// and all of the goroutine must got gorm.ErrInvalidDB after database close
238+
func TestPreparedStmtConcurrentClose(t *testing.T) {
239+
name := "prepared_stmt_concurrent_close"
240+
user := *GetUser(name, Config{})
241+
createTx := DB.Session(&gorm.Session{}).Create(&user)
242+
if createTx.Error != nil {
243+
t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error)
244+
}
245+
246+
// create a new connection to keep away from other tests
247+
tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true})
248+
if err != nil {
249+
t.Fatalf("failed to open test connection due to %s", err)
250+
}
251+
pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
252+
if !ok {
253+
t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode")
254+
}
255+
256+
loopCount := 100
257+
var wg sync.WaitGroup
258+
var lastErr error
259+
closeValid := make(chan struct{}, loopCount)
260+
closeStartIdx := loopCount / 2 // close the database at the middle of the execution
261+
var lastRunIndex int
262+
var closeFinishedAt int64
263+
264+
wg.Add(1)
265+
go func(id uint) {
266+
defer wg.Done()
267+
defer close(closeValid)
268+
for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ {
269+
if lastRunIndex == closeStartIdx {
270+
closeValid <- struct{}{}
271+
}
272+
var tmp User
273+
now := time.Now().UnixNano()
274+
err := tx.Session(&gorm.Session{}).First(&tmp, id).Error
275+
if err == nil {
276+
closeFinishedAt := atomic.LoadInt64(&closeFinishedAt)
277+
if (closeFinishedAt != 0) && (now > closeFinishedAt) {
278+
lastErr = errors.New("must got error after database closed")
279+
break
280+
}
281+
continue
282+
}
283+
lastErr = err
284+
break
285+
}
286+
}(user.ID)
287+
288+
wg.Add(1)
289+
go func() {
290+
defer wg.Done()
291+
for range closeValid {
292+
for i := 0; i < loopCount; i++ {
293+
pdb.Close() // the Close method must can be call multiple times
294+
atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano())
295+
}
296+
}
297+
}()
298+
299+
wg.Wait()
300+
var tmp User
301+
err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error
302+
if err != gorm.ErrInvalidDB {
303+
t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err)
304+
}
305+
306+
// must be error
307+
if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) {
308+
t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr)
309+
}
310+
if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx {
311+
t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex)
312+
}
313+
if pdb.Stmts != nil {
314+
t.Fatalf("stmts must be nil")
315+
}
316+
}

0 commit comments

Comments
 (0)