|
4 | 4 | "context"
|
5 | 5 | "errors"
|
6 | 6 | "sync"
|
| 7 | + "sync/atomic" |
7 | 8 | "testing"
|
8 | 9 | "time"
|
9 | 10 |
|
@@ -167,3 +168,149 @@ func TestPreparedStmtReset(t *testing.T) {
|
167 | 168 | t.Fatalf("prepared stmt should be empty")
|
168 | 169 | }
|
169 | 170 | }
|
| 171 | + |
| 172 | +func isUsingClosedConnError(err error) bool { |
| 173 | + // https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717 |
| 174 | + return err.Error() == "sql: statement is closed" |
| 175 | +} |
| 176 | + |
| 177 | +// TestPreparedStmtConcurrentReset test calling reset and executing SQL concurrently |
| 178 | +// this test making sure that the gorm would not get a Segmentation Fault, and the only error cause by this is using a closed Stmt |
| 179 | +func TestPreparedStmtConcurrentReset(t *testing.T) { |
| 180 | + name := "prepared_stmt_concurrent_reset" |
| 181 | + user := *GetUser(name, Config{}) |
| 182 | + createTx := DB.Session(&gorm.Session{}).Create(&user) |
| 183 | + if createTx.Error != nil { |
| 184 | + t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) |
| 185 | + } |
| 186 | + |
| 187 | + // create a new connection to keep away from other tests |
| 188 | + tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) |
| 189 | + if err != nil { |
| 190 | + t.Fatalf("failed to open test connection due to %s", err) |
| 191 | + } |
| 192 | + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) |
| 193 | + if !ok { |
| 194 | + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") |
| 195 | + } |
| 196 | + |
| 197 | + loopCount := 100 |
| 198 | + var wg sync.WaitGroup |
| 199 | + var unexpectedError bool |
| 200 | + writerFinish := make(chan struct{}) |
| 201 | + |
| 202 | + wg.Add(1) |
| 203 | + go func(id uint) { |
| 204 | + defer wg.Done() |
| 205 | + defer close(writerFinish) |
| 206 | + |
| 207 | + for j := 0; j < loopCount; j++ { |
| 208 | + var tmp User |
| 209 | + err := tx.Session(&gorm.Session{}).First(&tmp, id).Error |
| 210 | + if err == nil || isUsingClosedConnError(err) { |
| 211 | + continue |
| 212 | + } |
| 213 | + t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err) |
| 214 | + unexpectedError = true |
| 215 | + break |
| 216 | + } |
| 217 | + }(user.ID) |
| 218 | + |
| 219 | + wg.Add(1) |
| 220 | + go func() { |
| 221 | + defer wg.Done() |
| 222 | + <-writerFinish |
| 223 | + pdb.Reset() |
| 224 | + }() |
| 225 | + |
| 226 | + wg.Wait() |
| 227 | + |
| 228 | + if unexpectedError { |
| 229 | + t.Fatalf("should is a unexpected error") |
| 230 | + } |
| 231 | +} |
| 232 | + |
| 233 | +// TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently |
| 234 | +// for example: one goroutine found error and just close the database, and others are executing SQL |
| 235 | +// this test making sure that the gorm would not get a Segmentation Fault, |
| 236 | +// and the only error cause by this is using a closed Stmt or gorm.ErrInvalidDB |
| 237 | +// and all of the goroutine must got gorm.ErrInvalidDB after database close |
| 238 | +func TestPreparedStmtConcurrentClose(t *testing.T) { |
| 239 | + name := "prepared_stmt_concurrent_close" |
| 240 | + user := *GetUser(name, Config{}) |
| 241 | + createTx := DB.Session(&gorm.Session{}).Create(&user) |
| 242 | + if createTx.Error != nil { |
| 243 | + t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) |
| 244 | + } |
| 245 | + |
| 246 | + // create a new connection to keep away from other tests |
| 247 | + tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) |
| 248 | + if err != nil { |
| 249 | + t.Fatalf("failed to open test connection due to %s", err) |
| 250 | + } |
| 251 | + pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) |
| 252 | + if !ok { |
| 253 | + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") |
| 254 | + } |
| 255 | + |
| 256 | + loopCount := 100 |
| 257 | + var wg sync.WaitGroup |
| 258 | + var lastErr error |
| 259 | + closeValid := make(chan struct{}, loopCount) |
| 260 | + closeStartIdx := loopCount / 2 // close the database at the middle of the execution |
| 261 | + var lastRunIndex int |
| 262 | + var closeFinishedAt int64 |
| 263 | + |
| 264 | + wg.Add(1) |
| 265 | + go func(id uint) { |
| 266 | + defer wg.Done() |
| 267 | + defer close(closeValid) |
| 268 | + for lastRunIndex = 1; lastRunIndex <= loopCount; lastRunIndex++ { |
| 269 | + if lastRunIndex == closeStartIdx { |
| 270 | + closeValid <- struct{}{} |
| 271 | + } |
| 272 | + var tmp User |
| 273 | + now := time.Now().UnixNano() |
| 274 | + err := tx.Session(&gorm.Session{}).First(&tmp, id).Error |
| 275 | + if err == nil { |
| 276 | + closeFinishedAt := atomic.LoadInt64(&closeFinishedAt) |
| 277 | + if (closeFinishedAt != 0) && (now > closeFinishedAt) { |
| 278 | + lastErr = errors.New("must got error after database closed") |
| 279 | + break |
| 280 | + } |
| 281 | + continue |
| 282 | + } |
| 283 | + lastErr = err |
| 284 | + break |
| 285 | + } |
| 286 | + }(user.ID) |
| 287 | + |
| 288 | + wg.Add(1) |
| 289 | + go func() { |
| 290 | + defer wg.Done() |
| 291 | + for range closeValid { |
| 292 | + for i := 0; i < loopCount; i++ { |
| 293 | + pdb.Close() // the Close method must can be call multiple times |
| 294 | + atomic.CompareAndSwapInt64(&closeFinishedAt, 0, time.Now().UnixNano()) |
| 295 | + } |
| 296 | + } |
| 297 | + }() |
| 298 | + |
| 299 | + wg.Wait() |
| 300 | + var tmp User |
| 301 | + err = tx.Session(&gorm.Session{}).First(&tmp, user.ID).Error |
| 302 | + if err != gorm.ErrInvalidDB { |
| 303 | + t.Fatalf("must got a gorm.ErrInvalidDB while execution after db close, got %+v instead", err) |
| 304 | + } |
| 305 | + |
| 306 | + // must be error |
| 307 | + if lastErr != gorm.ErrInvalidDB && !isUsingClosedConnError(lastErr) { |
| 308 | + t.Fatalf("exp error gorm.ErrInvalidDB, got %+v instead", lastErr) |
| 309 | + } |
| 310 | + if lastRunIndex >= loopCount || lastRunIndex < closeStartIdx { |
| 311 | + t.Fatalf("exp loop times between (closeStartIdx %d <=) and (< loopCount %d), got %d instead", closeStartIdx, loopCount, lastRunIndex) |
| 312 | + } |
| 313 | + if pdb.Stmts != nil { |
| 314 | + t.Fatalf("stmts must be nil") |
| 315 | + } |
| 316 | +} |
0 commit comments